import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchvision.transforms as torch_tfms
Support in-notebook plotting
# Render matplotlib figures inline, directly below the notebook cell that creates them.
%matplotlib inline
Report versions
# Record numeric/plotting library versions so results can be reproduced later.
from matplotlib import __version__ as mplver

print(f'numpy version: {np.__version__}')
print(f'matplotlib version: {mplver}')
numpy version: 1.15.3 matplotlib version: 3.0.1
# Report the running interpreter version as major.minor.micro.
pv = sys.version_info
print('python version: ' + '.'.join(str(part) for part in pv[:3]))
python version: 3.7.0
Automatically reload packages whose source changes during package development
# Load IPython's autoreload extension. Mode 2 reloads modules before running
# code, so edits to packages under development are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2
# Root of the NIfTI test data. Set the NN_TEST_DATA_DIR environment variable
# to run this notebook on another machine; the default is the original path.
data_root = os.environ.get('NN_TEST_DATA_DIR', '/Users/jcreinhold/Research/data/nn_test/real')
# The trailing '' keeps a trailing separator, so code that still does
# `train_dir + 't1'` continues to produce a valid path.
train_dir = os.path.join(data_root, 'train', '')
val_dir = os.path.join(data_root, 'test', '')  # the 'test' split is used for validation
from niftidataset import *
patch_sz = 128
# Presumably crops a random 2D patch of size patch_sz, then converts it to a tensor.
tfms = torch_tfms.Compose([RandomCrop2D(patch_sz, None), ToTensor()])
# Paired datasets: first directory gives the source (t1), second the target (flair).
tds = NiftiDataset(os.path.join(train_dir, 't1'), os.path.join(train_dir, 'flair'), tfms)
vds = NiftiDataset(os.path.join(val_dir, 't1'), os.path.join(val_dir, 'flair'), tfms)
src, tgt = tds[0]
plt.imshow(np.rot90(src), cmap='gist_gray')
plt.axis('off');
import fastai as fai
import fastai.vision as faiv
import torchvision
from torch.utils.data import DataLoader
# Report the deep-learning stack versions in one call (one line per library).
print('\n'.join([
    f'fastai version: {fai.__version__}',
    f'pytorch version: {torch.__version__}',
    f'torchvision version: {torchvision.__version__}',
]))
fastai version: 1.0.15 pytorch version: 1.0.0.dev20181014 torchvision version: 0.2.1
patch_sz = 128
# Same pipeline as before, plus ToFastaiImage so each sample can be used with
# fastai's Image API (x.flip_lr() below relies on that).
tfms = torch_tfms.Compose([
    RandomCrop2D(patch_sz, None),
    ToTensor(),
    ToFastaiImage(),
])
# Build the training dataset first, then the validation one, from matching t1/flair dirs.
tds, vds = [NiftiDataset(root + 't1', root + 'flair', tfms) for root in (train_dir, val_dir)]
x, y = tds[0]
x.flip_lr()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
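MR image intensities fall far outside the range imshow expects for RGB data ([0, 1] for floats; the next cell shows a maximum above 1000). As a result, imshow clips the values and the image appears washed out in this plot.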
# Peak intensity of the sample: far above 1.0, which is why imshow clips the plot.
torch.max(x.data)
tensor(1203.0923)
# fastai augmentations applied on the fly to training samples.
tfms = [
    faiv.flip_lr(p=0.5),
    faiv.rotate(degrees=(-45., 45.), p=0.5),
    faiv.zoom(scale=(0.5, 1.2), p=0.8),
]
# ds_tfms is (train_tfms, valid_tfms): no augmentation for validation.
# tfm_y=True asks fastai to apply the same transform to the target so the
# source/target pair stays aligned (per fastai v1 docs).
idb = faiv.ImageDataBunch.create(tds, vds, bs=2, ds_tfms=(tfms, []), num_workers=1, tfm_y=True)
# Inspect one training batch: print each tensor's shape and plot channel 0 of
# the first item for both the source and target tensors.
sample = next(iter(idb.train_dl))
for x in sample:
    print(x.shape)
    # The DataBunch may move batches to the GPU; numpy needs a CPU tensor.
    plt.imshow(np.rot90(x[0, 0].cpu().numpy()), cmap='gist_gray')
    plt.axis('off')
    plt.show()
torch.Size([2, 3, 128, 128])
torch.Size([2, 3, 128, 128])