Train an inpainting network with the context-encoder-style self-supervised method described in [1].
[1] D. Pathak et al. "Context encoders: Feature learning by inpainting."
CVPR. 2016.
from typing import Callable, List, Optional, Tuple, Union
from glob import glob
import math
import os
import random
import sys
# Expose only one GPU to this process. The variable is set here, before `import torch`
# below, so that it is in place before CUDA is initialized; the selected GPU is then
# addressed as the default 'cuda' device.
gpu_id = 1
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision
from selfsupervised3d import *
Support in-notebook plotting
%matplotlib inline
Report versions
# Print the versions of the main numerical / plotting / DL libraries in use.
from matplotlib import __version__ as mplver
for _lib_name, _lib_version in (('numpy', np.__version__),
                                ('matplotlib', mplver),
                                ('pytorch', torch.__version__),
                                ('torchvision', torchvision.__version__)):
    print(f'{_lib_name} version: {_lib_version}')
numpy version: 1.17.2 matplotlib version: 3.1.1 pytorch version: 1.5.0 torchvision version: 0.6.0a0+82fd1c8
# Print the running interpreter's version as major.minor.micro.
pv = sys.version_info
print(f'python version: {pv.major}.{pv.minor}.{pv.micro}')
python version: 3.7.7
Automatically reload packages whose source has changed (useful during package development)
%load_ext autoreload
%autoreload 2
Check GPU(s)
!nvidia-smi
Fri May 22 20:28:25 2020 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 430.40 Driver Version: 430.40 CUDA Version: 10.1 | |-------------------------------+----------------------+----------------------+ | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | |===============================+======================+======================| | 0 Tesla M40 24GB Off | 00000000:02:00.0 Off | 0 | | N/A 28C P8 16W / 250W | 0MiB / 22945MiB | 0% Default | +-------------------------------+----------------------+----------------------+ | 1 Tesla M40 24GB Off | 00000000:03:00.0 Off | 0 | | N/A 39C P8 16W / 250W | 0MiB / 22945MiB | 0% Default | +-------------------------------+----------------------+----------------------+ +-----------------------------------------------------------------------------+ | Processes: GPU Memory | | GPU PID Type Process name Usage | |=============================================================================| | No running processes found | +-----------------------------------------------------------------------------+
# This notebook trains on the GPU only. Raise explicitly instead of using `assert`,
# which is removed when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError(
        'CUDA is not available; this notebook requires a GPU '
        f'(CUDA_VISIBLE_DEVICES={os.environ.get("CUDA_VISIBLE_DEVICES")!r}).')
device = torch.device('cuda')
# Let cuDNN benchmark and pick convolution algorithms. NOTE: this can make results
# non-deterministic, which works against the seeding done below.
torch.backends.cudnn.benchmark = True
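Set seeds for better reproducibility. Seeding the main process is not sufficient once multiprocessing is used (e.g., DataLoader worker processes); see the PyTorch reproducibility notes before relying on it.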
# A single seed shared by Python's `random`, NumPy, and torch's CPU and CUDA generators.
seed = 1339
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(seed)
del _seed_rng
Get the location of the training (and validation) data
# Root directory of the image subset; T1-w and T2-w volumes live in parallel
# 't1' and 't2' subdirectories.
train_dir = '/iacl/pg20/jacobr/ixi/subsets/hh/'
t1_dir = os.path.join(train_dir, 't1')
t2_dir = os.path.join(train_dir, 't2')
# glob() returns matches in arbitrary (filesystem) order. Sort both lists so that
# t1_fns[i] and t2_fns[i] are paired by filename. This assumes the two directories
# use matching file names per subject; confirm for new datasets.
t1_fns = sorted(glob(os.path.join(t1_dir, '*.nii*')))
t2_fns = sorted(glob(os.path.join(t2_dir, '*.nii*')))
if not t1_fns:
    raise FileNotFoundError(f'No NIfTI files found in {t1_dir}')
if len(t1_fns) != len(t2_fns):
    raise ValueError(f'Found {len(t1_fns)} T1-w images in {t1_dir} but '
                     f'{len(t2_fns)} T2-w images in {t2_dir}')
Look at an axial view of the source T1-weighted (T1-w) and target T2-weighted (T2-w) images.
def imshow(x, ax, title, n_rot=3):
    """Draw an image array on a matplotlib axis in grayscale, without axis decorations.

    Args:
        x: array to display (a 2D slice in this notebook).
        ax: matplotlib Axes to draw on.
        title: axis title, drawn at font size 22.
        n_rot: number of 90-degree rotations applied with ``np.rot90`` before drawing.
    """
    rotated = np.rot90(x, n_rot)
    ax.imshow(rotated, aspect='equal', cmap='gray')
    ax.set_title(title, fontsize=22)
    ax.axis('off')
# Index of the slice (along the last array axis) to display.
j = 100
# `get_data()` is deprecated in nibabel and raises from nibabel 5.0 onwards;
# `get_fdata()` returns the (scaled) image data as a floating-point array.
t1_ex = nib.load(t1_fns[0]).get_fdata()
t2_ex = nib.load(t2_fns[0]).get_fdata()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 9))
imshow(t1_ex[..., j], ax1, 'T1', 1)
imshow(t2_ex[..., j], ax2, 'T2', 1)
# Build a multi-block mask restricted to voxels brighter than the volume mean, then show
# it next to the masked image. The literals 15 and 20 match the n_blocks / block_size
# hyperparameters used for training.
mask = create_multiblock_mask(t1_ex > t1_ex.mean(), 15, 20)
j = 100
mask_slice = mask[..., j]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 9))
imshow(mask_slice, ax1, 'Mask', 1)
imshow(t1_ex[..., j] * (1 - mask_slice), ax2, 'Masked Img.', 1)
Hyperparameters, optimizers, logging, etc.
# ---- data ----
data_dirs = [t1_dir]            # image directories; one input channel per directory
input_channels = len(data_dirs)

# ---- system ----
# When True, resume from recon_model_{version}.pth / discriminator_{version}.pth.
load_model = False

# ---- logging ----
log_rate = 5                    # print mean losses every `log_rate` epochs
save_rate = 10                  # checkpoint both networks every `save_rate` epochs
version = 'context_v1'          # tag used in checkpoint and loss file names

# ---- dataset (passed to ContextDataset as n_blocks / size / patch_size) ----
n_blocks = 15
block_size = 20
patch_size = 128

# ---- model, optimizer, loss, and training ----
alpha = (0.99, 0.01)            # weight for regression & discriminator loss, resp.
valid_split = 0.1               # fraction of images held out for validation
batch_size = 2
n_jobs = batch_size             # DataLoader worker processes
n_epochs = 50
use_adam = True                 # AdamW when True, otherwise SGD with momentum
if use_adam:
    opt_kwargs = dict(lr=1e-3, betas=(0.9, 0.99), weight_decay=1e-6)
else:
    opt_kwargs = dict(lr=5e-3, momentum=0.9)
use_scheduler = True
scheduler_kwargs = dict(step_size=10, gamma=0.5)   # StepLR: halve the LR every 10 epochs
def init_fn(worker_id):
    """Seed Python's and NumPy's global RNGs inside a DataLoader worker.

    The seed is torch's initial seed for the current process plus ``worker_id``,
    reduced modulo 2**32 because ``np.random.seed`` only accepts 32-bit seeds.
    """
    worker_seed = (torch.initial_seed() + worker_id) % 2 ** 32
    for seeder in (random.seed, np.random.seed):
        seeder(worker_seed)
# setup training and validation dataloaders
dataset = ContextDataset(data_dirs, n_blocks=n_blocks, size=block_size, patch_size=patch_size)
num_train = len(dataset)
indices = list(range(num_train))
split = int(valid_split * num_train)
valid_idx = np.random.choice(indices, size=split, replace=False)
train_idx = list(set(indices) - set(valid_idx))
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=batch_size,
worker_init_fn=init_fn, num_workers=n_jobs,
pin_memory=True, collate_fn=context_collate)
valid_loader = DataLoader(dataset, sampler=valid_sampler, batch_size=batch_size,
worker_init_fn=init_fn, num_workers=n_jobs,
pin_memory=True, collate_fn=context_collate)
print(f'Number of training images: {num_train-split}')
print(f'Number of validation images: {split}')
Number of training images: 121 Number of validation images: 13
# Inspect one training sample as returned by ContextDataset (src, tgt, mask).
src, tgt, mask = dataset[2]
src = src.detach().cpu().numpy().squeeze()
tgt = tgt.detach().cpu().numpy().squeeze()
mask = mask.detach().cpu().numpy().squeeze()
# Middle slice along axis 1, derived from patch_size (64 for patch_size=128) instead of a
# hard-coded index. NOTE(review): assumes the squeezed sample spans patch_size voxels
# along that axis — confirm if ContextDataset produces non-cubic patches.
mid = patch_size // 2
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 9))
imshow(src[:, mid, :], ax1, 'Src.')
imshow(tgt[:, mid, :], ax2, 'Tgt.')
imshow(mask[:, mid, :], ax3, 'Mask')
# Reconstruction (inpainting) network fed `input_channels` input channels.
# NOTE(review): nc=32 presumably sets the base channel width — confirm in FrankUNet.
recon_model = FrankUNet(ic=input_channels, nc=32)
# Discriminator that scores both `tgt` and `recon` in the training loop.
# NOTE(review): the meaning of the positional args (1, 72) is not visible here; 1 is
# presumably the number of input channels — confirm against PatchDiscriminator.
discriminator = PatchDiscriminator(1, 72)
def num_params(model):
    """Return the total element count of ``model``'s parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Report the trainable-parameter count of each network.
for _label, _net in (('reconstruction model', recon_model), ('discriminator', discriminator)):
    print(f'Number of trainable parameters in {_label}: {num_params(_net)}')
Number of trainable parameters in reconstruction model: 13887425 Number of trainable parameters in discriminator: 13978153
# Optionally resume from the un-numbered checkpoints written after training.
if load_model:
    recon_model.load_state_dict(torch.load(f'recon_model_{version}.pth'))
    discriminator.load_state_dict(torch.load(f'discriminator_{version}.pth'))

# Move both networks to the GPU before building their optimizers.
for _net in (recon_model, discriminator):
    _net.to(device)

# One optimizer per network: AdamW when use_adam is set, otherwise SGD.
optim_cls = torch.optim.AdamW if use_adam else torch.optim.SGD
optG, optD = (optim_cls(net.parameters(), **opt_kwargs) for net in (recon_model, discriminator))
criterion = InpaintLoss(alpha=alpha)

# Per-epoch step-decay of both learning rates.
if use_scheduler:
    recon_scheduler, disc_scheduler = (
        torch.optim.lr_scheduler.StepLR(opt, **scheduler_kwargs) for opt in (optG, optD))
# Train and validate for n_epochs. train_losses / valid_losses collect, per epoch, the
# list of per-batch loss values (later flattened by tidy_losses()).
train_losses, valid_losses = [], []
n_batches = len(train_loader)  # training batches per epoch (not used in this cell)
for t in range(1, n_epochs + 1):
    # training
    t_losses = []
    recon_model.train()
    discriminator.train()
    for i, (src, tgt, mask) in enumerate(train_loader):
        src, tgt, mask = src.to(device), tgt.to(device), mask.to(device)
        optG.zero_grad()
        optD.zero_grad()
        # Reconstruct from src, then score the real target and the reconstruction.
        # disc_fake is computed on `recon` without detach(), so gradients of the loss
        # flow through the discriminator back into the reconstruction model.
        recon = recon_model(src)
        disc_real = discriminator(tgt)
        disc_fake = discriminator(recon)
        # alpha weights the regression and discriminator terms of InpaintLoss.
        loss = criterion(recon, tgt, mask, disc_real, disc_fake)
        t_losses.append(loss.item())
        # NOTE(review): one backward pass on a single combined loss, then BOTH optimizers
        # step, so the discriminator descends the same objective as the reconstruction
        # model rather than an opposing one. Whether that is a valid adversarial setup
        # depends on how InpaintLoss uses disc_real / disc_fake (not visible here) — confirm.
        loss.backward()
        optG.step()
        optD.step()
    train_losses.append(t_losses)
    # validation
    v_losses = []
    recon_model.eval()
    discriminator.eval()
    # Same forward pass and criterion as training, without gradient tracking.
    with torch.no_grad():
        for i, (src, tgt, mask) in enumerate(valid_loader):
            src, tgt, mask = src.to(device), tgt.to(device), mask.to(device)
            recon = recon_model(src)
            disc_real = discriminator(tgt)
            disc_fake = discriminator(recon)
            loss = criterion(recon, tgt, mask, disc_real, disc_fake)
            v_losses.append(loss.item())
    valid_losses.append(v_losses)
    # log, step scheduler, and save results from epoch
    # Abort on a non-finite training loss (checked after this epoch's validation has run).
    if not np.all(np.isfinite(t_losses)):
        raise RuntimeError('NaN or Inf in training loss, cannot recover. Exiting.')
    # NOTE(review): np.mean(v_losses) is NaN if the validation split is empty.
    if t % log_rate == 0:
        log = (f'Epoch: {t} - TL: {np.mean(t_losses):.2e}, VL: {np.mean(v_losses):.2e}')
        print(log)
    # Schedulers are stepped once per epoch, after the optimizer steps.
    if use_scheduler:
        recon_scheduler.step()
        disc_scheduler.step()
    # Epoch-numbered checkpoints, separate from the final un-numbered save.
    if t % save_rate == 0:
        torch.save(recon_model.state_dict(), f'recon_model_{version}_{t}.pth')
        torch.save(discriminator.state_dict(), f'discriminator_{version}_{t}.pth')
Epoch: 5 - TL: 1.71e-01, VL: 1.60e-01 Epoch: 10 - TL: 1.40e-01, VL: 1.49e-01 Epoch: 15 - TL: 1.25e-01, VL: 1.42e-01 Epoch: 20 - TL: 1.22e-01, VL: 1.35e-01 Epoch: 25 - TL: 1.11e-01, VL: 1.22e-01 Epoch: 30 - TL: 1.11e-01, VL: 1.13e-01 Epoch: 35 - TL: 1.07e-01, VL: 1.01e-01 Epoch: 40 - TL: 1.08e-01, VL: 1.31e-01 Epoch: 45 - TL: 1.06e-01, VL: 1.14e-01 Epoch: 50 - TL: 1.03e-01, VL: 1.09e-01
# Write the final weights under the un-numbered names that `load_model` reads back.
save_model = True
if save_model:
    for _state, _path in ((recon_model.state_dict(), f'recon_model_{version}.pth'),
                          (discriminator.state_dict(), f'discriminator_{version}.pth')):
        torch.save(_state, _path)
def tidy_losses(train, valid):
    """Flatten per-epoch loss lists into a long-format DataFrame.

    Args:
        train: sequence (one entry per epoch) of per-batch training losses.
        valid: sequence (one entry per epoch) of per-batch validation losses.
            Epochs are paired with ``zip``, so extra epochs in either input are dropped.

    Returns:
        DataFrame with columns ``epoch`` (1-based), ``type`` (always ``'loss'``),
        ``value`` and ``phase`` (``'train'`` or ``'valid'``). Within each epoch all
        training rows precede the validation rows.
    """
    rows = []
    for epoch, (epoch_train, epoch_valid) in enumerate(zip(train, valid), 1):
        rows.extend((epoch, 'loss', value, 'train') for value in epoch_train)
        rows.extend((epoch, 'loss', value, 'valid') for value in epoch_valid)
    return pd.DataFrame(rows, columns=['epoch', 'type', 'value', 'phase'])
# Per-epoch loss curves by phase, with a standard-deviation band, on a log y-axis.
losses = tidy_losses(train_losses, valid_losses)
f, ax1 = plt.subplots(1, 1, figsize=(12, 8), sharey=True)
sns.lineplot(data=losses, x='epoch', y='value', hue='phase', ci='sd', lw=3, ax=ax1);
ax1.set_yscale('log');
ax1.set_title('Losses');

# Optionally persist the figure and the tidy loss table.
save_losses = False
if save_losses:
    f.savefig(f'losses_{version}.pdf')
    losses.to_csv(f'losses_{version}.csv')
def _first_volume(x):
    """Return ``x`` as a 3D numpy array, keeping only index 0 of any leading dims.

    Tensors are detached and copied to host memory first. Arrays that are already
    3D (e.g., when this cell is re-run) are returned unchanged.
    """
    if torch.is_tensor(x):
        x = x.detach().cpu().numpy()
    x = np.asarray(x)
    # Leading batch/channel dims: take the first entry until 3 dims remain. Unlike
    # squeeze(), this also works when the last batch holds more than one sample.
    while x.ndim > 3:
        x = x[0]
    return x


# Show the first sample of the last validation batch: input, target, reconstruction
# and mask, at slice j along the last axis.
j = 100
src, tgt, recon, mask = (_first_volume(v) for v in (src, tgt, recon, mask))
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(16, 9))
imshow(src[..., j], ax1, 'Src.', 1)
imshow(tgt[..., j], ax2, 'Tgt.', 1)
imshow(recon[..., j], ax3, 'Recon.', 1)
imshow(mask[..., j], ax4, 'Mask', 1)
The hyperparameters above can probably be modified to improve this result. It would also be better to use a brain mask as an additional input to the ContextDataset (see its `mask_dir` keyword argument): a brain mask would keep the random blocks inside the brain, which would force the network to learn how to fill in brain tissue only, rather than neck and other tissue.