Showcase a simple 3D resnet—built in pytorch and fastai—for MR image synthesis, which is the task of taking a specific MR image contrast and making it look like another MR image contrast (e.g., T1-weighted to FLAIR).
from pathlib import PosixPath
import os
import sys
import fastai.vision as faiv
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchvision
Support in-notebook plotting
%matplotlib inline
Report versions
# Report versions of the numerical / deep-learning stack so the results
# below are reproducible.
print('numpy version: {}'.format(np.__version__))
from matplotlib import __version__ as mplver
print('matplotlib version: {}'.format(mplver))
print(f'fastai version: {faiv.__version__}')
print(f'pytorch version: {torch.__version__}')
print(f'torchvision version: {torchvision.__version__}')
numpy version: 1.15.4 matplotlib version: 3.0.2 fastai version: 1.0.39 pytorch version: 1.0.0 torchvision version: 0.2.1
# Also record the Python interpreter version.
pv = sys.version_info
print('python version: {}.{}.{}'.format(pv.major, pv.minor, pv.micro))
python version: 3.7.1
Automatically reload packages when their source changes (useful during package development)
%load_ext autoreload
%autoreload 2
Check GPU
!nvidia-smi
Thu Jan 3 12:51:33 2019 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 396.54 Driver Version: 396.54 | |-------------------------------+----------------------+----------------------+ | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | |===============================+======================+======================| | 0 Tesla M40 24GB Off | 00000000:02:00.0 Off | 0 | | N/A 28C P8 17W / 250W | 11MiB / 22945MiB | 0% Default | +-------------------------------+----------------------+----------------------+ | 1 Tesla M40 24GB Off | 00000000:03:00.0 Off | 0 | | N/A 32C P8 17W / 250W | 11MiB / 22945MiB | 0% Default | +-------------------------------+----------------------+----------------------+ +-----------------------------------------------------------------------------+ | Processes: GPU Memory | | GPU PID Type Process name Usage | |=============================================================================| | No running processes found | +-----------------------------------------------------------------------------+
data_dir = PosixPath('/iacl/pg19/jacobr/zs/blog/')
!ls {data_dir/'t1/train'}
KKI2009-20-MPRAGE_zscore.nii.gz KKI2009-27-MPRAGE_zscore.nii.gz KKI2009-21-MPRAGE_zscore.nii.gz KKI2009-29-MPRAGE_zscore.nii.gz KKI2009-22-MPRAGE_zscore.nii.gz KKI2009-40-MPRAGE_zscore.nii.gz KKI2009-24-MPRAGE_zscore.nii.gz KKI2009-41-MPRAGE_zscore.nii.gz KKI2009-25-MPRAGE_zscore.nii.gz KKI2009-42-MPRAGE_zscore.nii.gz KKI2009-26-MPRAGE_zscore.nii.gz
# Select the compute device, preferring CUDA when a GPU is available.
if torch.cuda.is_available():
    dev = "cuda"
else:
    dev = "cpu"
print(dev)
device = torch.device(dev)
cuda
import nibabel as nib
def open_nii(fn:str) -> faiv.Image:
    """ Return fastai `Image` object created from NIfTI image in file `fn`.

    Uses `get_fdata()` instead of the deprecated `get_data()` (removed in
    nibabel 5.0); `torch.Tensor` converts the array to float32 either way,
    so the resulting Image is unchanged.
    """
    x = nib.load(str(fn)).get_fdata()
    return faiv.Image(torch.Tensor(x))
class NiftiItemList(faiv.ImageItemList):
    """ custom item list for nifti files """
    # Override `open` so items are loaded via nibabel (open_nii) rather
    # than the default image loader.
    def open(self, fn:faiv.PathOrStr)->faiv.Image: return open_nii(fn)
class NiftiNiftiList(NiftiItemList):
    """ item list suitable for synthesis tasks """
    # Targets are themselves NIfTI volumes (image-to-image synthesis),
    # so labels are opened with the same NiftiItemList machinery.
    _label_cls = NiftiItemList
from functools import singledispatch
@faiv.TfmPixel
@singledispatch
def crop(x, pct, axis:int) -> torch.Tensor:
    """ Crop a 3d image along `axis`, keeping the fractional range
    pct=(start, stop) of that dimension, and prepend a channel axis. """
    s = x.shape
    i0, i1 = int(s[axis]*pct[0]), int(s[axis]*pct[1])
    # Build the slicing index programmatically instead of enumerating the
    # three axes in a chained conditional (also fixes the original
    # quadruple-quote docstring typo).
    idx = [slice(None)] * 3
    idx[axis] = slice(i0, i1)
    return x[(np.newaxis, *idx)].contiguous()
tfms = [crop(pct=(0.20,0.80),axis=2)]
def get_y_fn(x):
    """ Map a T1 image path `x` to the corresponding FLAIR target path. """
    split = 'train' if 'train' in str(x) else 'valid'
    # First 10 characters of the stem are the subject ID (e.g. KKI2009-20).
    subject = str(x.stem)[:10]
    return data_dir/'flair'/split/f'{subject}-FLAIR_reg_zscore.nii.gz'
# Build the DataBunch: T1 inputs from t1/{train,valid}, FLAIR targets
# resolved by `get_y_fn`, the same crop applied to x and y, batch size 2.
# NOTE: `extensions` must be a collection of suffixes -- ('.gz') is just
# the string '.gz' (missing comma), so use a real one-element tuple.
idb = (NiftiNiftiList.from_folder(data_dir/'t1', extensions=('.gz',))
       .split_by_folder()
       .label_from_func(get_y_fn)
       .transform((tfms,tfms), tfm_y=True)
       .databunch(bs=2))
# Short aliases for the torch weight-reparameterization wrappers used in
# `conv3d` below.
spectral_norm = nn.utils.spectral_norm
weight_norm = nn.utils.weight_norm
def conv3d(ni:int, nf:int, ks:int=3, stride:int=1, pad:int=1, norm='batch') -> nn.Sequential:
    """ Conv3d -> ReLU [-> BatchNorm3d] layer stack.

    norm: 'batch' appends BatchNorm3d and disables the conv bias;
    'spectral'/'weight' wrap the conv in the corresponding
    reparameterization; any other value gives a plain conv + ReLU.
    """
    bias = norm != 'batch'  # BatchNorm supplies its own shift term
    conv = faiv.init_default(nn.Conv3d(ni, nf, ks, stride, pad, bias=bias))
    if norm == 'spectral':
        conv = spectral_norm(conv)
    elif norm == 'weight':
        conv = weight_norm(conv)
    # NOTE(review): ReLU is appended unconditionally, so even a norm=None
    # "output" layer is clamped to non-negative values -- confirm this is
    # intended for z-scored targets. inplace due to memory constraints.
    layers = [conv, nn.ReLU(inplace=True)]
    if norm == 'batch':
        layers.append(nn.BatchNorm3d(nf))
    return nn.Sequential(*layers)
def res3d_block(ni, nf, ks=3, norm='batch', dense=False):
    """ 3d Resnet block of `nf` features """
    conv1 = conv3d(ni, nf, ks, pad=ks//2, norm=norm)
    conv2 = conv3d(nf, nf, ks, pad=ks//2, norm=norm)
    # MergeLayer adds (or, if dense, concatenates) the block input back in.
    merge = faiv.MergeLayer(dense)
    return faiv.SequentialEx(conv1, conv2, merge)
norm = 'batch'
# One dense-merge block (1 input channel concatenated with 15 features
# -> 16 channels), four 16-channel residual blocks, then a 1x1x1 conv
# down to a single output channel.
layers = ([res3d_block(1,15,7,norm=norm,dense=True)] +
          [res3d_block(16,16,norm=norm) for _ in range(4)] +
          [conv3d(16,1,ks=1,pad=0,norm=None)])
model = nn.Sequential(*layers)
# Voxel-wise MSE between synthesized and true FLAIR.
# NOTE(review): conv3d always appends a ReLU, so the final layer cannot
# emit negative values even though z-scored targets contain them -- verify
# this is intended.
loss = nn.MSELoss()
learner = faiv.Learner(idb, model, loss_func=loss)
# LR range test (50 iterations); the plot guides the max LR chosen for
# the one-cycle schedule below.
learner.lr_find(num_it=50)
learner.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
# Train 100 epochs with the 1cycle policy (max LR 1e-2), logging per-epoch
# losses to a 'history' CSV via CSVLogger.
cbs = [faiv.callbacks.CSVLogger(learner, 'history')]
learner.fit_one_cycle(100, 1e-2, callbacks=cbs)
epoch | train_loss | valid_loss |
---|---|---|
1 | 1.626134 | 0.717183 |
2 | 1.209787 | 0.706559 |
3 | 1.014005 | 0.675807 |
4 | 0.902402 | 0.639397 |
5 | 0.826438 | 0.677543 |
6 | 0.768815 | 0.646400 |
7 | 0.723306 | 0.657138 |
8 | 0.689584 | 0.888098 |
9 | 0.660112 | 0.718608 |
10 | 0.640162 | 0.586548 |
11 | 0.623497 | 0.624218 |
12 | 0.606652 | 0.525309 |
13 | 0.590163 | 0.632344 |
14 | 0.576321 | 0.543100 |
15 | 0.565003 | 0.574600 |
16 | 0.555771 | 0.521884 |
17 | 0.546644 | 0.833132 |
18 | 0.540892 | 0.667426 |
19 | 0.536418 | 1.121620 |
20 | 0.531980 | 1.295221 |
21 | 0.525822 | 1.168827 |
22 | 0.517676 | 0.479367 |
23 | 0.511407 | 0.531983 |
24 | 0.505855 | 0.480099 |
25 | 0.499813 | 0.534779 |
26 | 0.492547 | 0.545909 |
27 | 0.486357 | 0.612363 |
28 | 0.482153 | 0.454708 |
29 | 0.476633 | 0.597723 |
30 | 0.471887 | 0.598259 |
31 | 0.469454 | 0.633815 |
32 | 0.466655 | 0.484067 |
33 | 0.461408 | 0.566282 |
34 | 0.454736 | 0.533734 |
35 | 0.450344 | 0.475311 |
36 | 0.447882 | 0.569890 |
37 | 0.444179 | 0.624154 |
38 | 0.444034 | 0.751465 |
39 | 0.443096 | 0.515025 |
40 | 0.439492 | 0.469929 |
41 | 0.436208 | 0.434505 |
42 | 0.432940 | 0.432327 |
43 | 0.429206 | 0.423704 |
44 | 0.424807 | 0.414768 |
45 | 0.419453 | 0.439060 |
46 | 0.415468 | 0.413525 |
47 | 0.410099 | 0.416840 |
48 | 0.405724 | 0.405314 |
49 | 0.401803 | 0.446054 |
50 | 0.397422 | 0.417251 |
51 | 0.393300 | 0.427032 |
52 | 0.390211 | 0.482145 |
53 | 0.387707 | 0.398383 |
54 | 0.384359 | 0.407789 |
55 | 0.382120 | 0.498587 |
56 | 0.379317 | 0.398915 |
57 | 0.377551 | 0.390387 |
58 | 0.374351 | 0.393330 |
59 | 0.370720 | 0.393716 |
60 | 0.368153 | 0.386991 |
61 | 0.366070 | 0.374463 |
62 | 0.363586 | 0.399881 |
63 | 0.361705 | 0.376201 |
64 | 0.359122 | 0.379452 |
65 | 0.357344 | 0.404680 |
66 | 0.355696 | 0.372394 |
67 | 0.353458 | 0.406953 |
68 | 0.351903 | 0.411071 |
69 | 0.349445 | 0.391066 |
70 | 0.348607 | 0.371273 |
71 | 0.346937 | 0.374092 |
72 | 0.345092 | 0.369752 |
73 | 0.344296 | 0.369881 |
74 | 0.342566 | 0.367897 |
75 | 0.340336 | 0.372588 |
76 | 0.339304 | 0.381993 |
77 | 0.337429 | 0.404921 |
78 | 0.336967 | 0.368794 |
79 | 0.335238 | 0.393927 |
80 | 0.334824 | 0.370238 |
81 | 0.333511 | 0.400015 |
82 | 0.331918 | 0.366955 |
83 | 0.330290 | 0.366379 |
84 | 0.329549 | 0.385348 |
85 | 0.328735 | 0.361374 |
86 | 0.327901 | 0.364950 |
87 | 0.328000 | 0.367197 |
88 | 0.327162 | 0.359711 |
89 | 0.326404 | 0.361960 |
90 | 0.325041 | 0.358708 |
91 | 0.324648 | 0.360368 |
92 | 0.323696 | 0.360453 |
93 | 0.323285 | 0.358685 |
94 | 0.321750 | 0.357828 |
95 | 0.321404 | 0.357338 |
96 | 0.320532 | 0.357547 |
97 | 0.320772 | 0.357742 |
98 | 0.320467 | 0.357825 |
99 | 0.319309 | 0.357854 |
100 | 0.318576 | 0.358211 |
learner.save('test')
import nibabel as nib
# Synthesize a FLAIR from a held-out test T1 and save the result as NIfTI.
obj = nib.load(str(data_dir/'t1/test/KKI2009-11-MPRAGE_zscore.nii.gz'))
# get_fdata() replaces the deprecated get_data() (removed in nibabel 5.0).
test = torch.Tensor(obj.get_fdata()).to(device)
# Call the module directly (not .forward()) so hooks run, and disable
# gradient tracking for inference to cut memory use on a full volume.
with torch.no_grad():
    res = learner.model(test[None, None, ...]).cpu().numpy()
plt.figure(figsize=(8,8))
plt.imshow(np.rot90(np.squeeze(res)[:,:,150], 3), cmap='gray')
plt.axis('off')
nib.Nifti1Image(res, obj.affine, obj.header).to_filename('test.nii.gz')
!ls {data_dir/'flair/test'}
KKI2009-11-FLAIR_reg_zscore.nii.gz KKI2009-19-FLAIR_reg_zscore.nii.gz KKI2009-17-FLAIR_reg_zscore.nii.gz
# Compare the T1 input, true FLAIR, and synthesized FLAIR side by side.
flair = nib.load(str(data_dir/'flair/test/KKI2009-11-FLAIR_reg_zscore.nii.gz'))
i = 150  # slice index to display -- assumes axis 2 indexes slices, TODO confirm

def imp(data, ax, i, v, t=''):
    """ Show rotated slice `i` of `data` on `ax` with window `v`, title `t`. """
    ax.imshow(np.rot90(data[:,:,i], 3), cmap='gray', vmin=v[0], vmax=v[1])
    ax.axis('off')
    ax.set_title(t)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16,12))
imp(obj.get_fdata(), ax1, i, (None,None), 'T1')
imp(flair.get_fdata(), ax2, i, (0,3.5), 'FLAIR')
imp(res.squeeze(), ax3, i, (0,3.5), 'Syn')
# matplotlib does not expand '~' -- expand it explicitly so the file is
# written to the real home directory rather than a literal './~' path.
plt.savefig(os.path.expanduser('~/Downloads/blg.png'), dpi=200)