Self-supervised learning in 3D images

Use the context encoder-style method as described in [1]

[1] D. Pathak et al. "Context encoders: Feature learning by inpainting."
    CVPR. 2016.

Setup notebook

In [1]:
from typing import Callable, List, Optional, Tuple, Union

from glob import glob
import math
import os
import random
import sys

gpu_id = 1
os.environ["CUDA_VISIBLE_DEVICES"] = f'{gpu_id}'

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import seaborn as sns

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision

from selfsupervised3d import *

Support in-notebook plotting

In [2]:
%matplotlib inline

Report versions

In [3]:
# Report library versions for reproducibility (consistent f-string style throughout)
print(f'numpy version: {np.__version__}')
from matplotlib import __version__ as mplver
print(f'matplotlib version: {mplver}')
print(f'pytorch version: {torch.__version__}')
print(f'torchvision version: {torchvision.__version__}')
numpy version: 1.17.2
matplotlib version: 3.1.1
pytorch version: 1.5.0
torchvision version: 0.6.0a0+82fd1c8
In [4]:
# Report the Python interpreter version
pv = sys.version_info
print(f'python version: {pv.major}.{pv.minor}.{pv.micro}')
python version: 3.7.7

Reload packages automatically when their contents change (useful during package development)

In [5]:
%load_ext autoreload
%autoreload 2

Check GPU(s)

In [6]:
!nvidia-smi
Fri May 22 20:28:25 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.40       Driver Version: 430.40       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla M40 24GB      Off  | 00000000:02:00.0 Off |                    0 |
| N/A   28C    P8    16W / 250W |      0MiB / 22945MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla M40 24GB      Off  | 00000000:03:00.0 Off |                    0 |
| N/A   39C    P8    16W / 250W |      0MiB / 22945MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
In [7]:
# Fail fast if no GPU is visible (CUDA_VISIBLE_DEVICES was restricted above)
assert torch.cuda.is_available()
device = torch.device('cuda')
# let cuDNN benchmark conv algorithms; appropriate here since patch sizes are fixed
torch.backends.cudnn.benchmark = True

Set seeds for better reproducibility. See the PyTorch notes on randomness and dataloader workers before using multiprocessing.

In [8]:
# Seed every RNG in play (Python, NumPy, PyTorch CPU and current GPU).
# NOTE(review): cuDNN benchmark mode (enabled above) remains a source of
# nondeterminism, so runs are only approximately reproducible.
seed = 1339
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

Setup training and validation data

Get the location of the training (and validation) data

In [9]:
# NOTE(review): hardcoded absolute path to the IXI "HH" subset — update for your system
train_dir = '/iacl/pg20/jacobr/ixi/subsets/hh/'
t1_dir = os.path.join(train_dir, 't1')
t2_dir = os.path.join(train_dir, 't2')
In [10]:
# Collect NIfTI filenames (.nii or .nii.gz). glob() order is filesystem-dependent,
# so sort both lists to make runs reproducible and keep T1/T2 pairing deterministic.
t1_fns = sorted(glob(os.path.join(t1_dir, '*.nii*')))
t2_fns = sorted(glob(os.path.join(t2_dir, '*.nii*')))
assert len(t1_fns) == len(t2_fns) and len(t1_fns) != 0, \
    f'found {len(t1_fns)} T1 and {len(t2_fns)} T2 images; need equal, nonzero counts'

Look at example training dataset

Look at an axial view of the source T1-weighted (T1-w) and target T2-weighted (T2-w) images.

In [11]:
def imshow(x, ax, title, n_rot=3):
    """Display 2D array `x` in grayscale on axes `ax`, rotated 90 deg `n_rot` times."""
    rotated = np.rot90(x, k=n_rot)
    ax.imshow(rotated, cmap='gray', aspect='equal')
    ax.set_title(title, fontsize=22)
    ax.axis('off')
In [12]:
j = 100  # axial slice index to display
# get_fdata() replaces get_data(), which is deprecated and removed in nibabel 5.0
t1_ex = nib.load(t1_fns[0]).get_fdata()
t2_ex = nib.load(t2_fns[0]).get_fdata()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 9))
imshow(t1_ex[..., j], ax1, 'T1', 1)
imshow(t2_ex[..., j], ax2, 'T2', 1)
In [13]:
mask = create_multiblock_mask(t1_ex > t1_ex.mean(), 15, 20)
In [14]:
# Visualize the block mask and the corresponding "context" image (blocks zeroed out)
j = 100
fig,(ax1,ax2) = plt.subplots(1,2,figsize=(16,9))
imshow(mask[...,j], ax1, 'Mask', 1)
imshow(t1_ex[...,j] * (1-mask[...,j]), ax2, 'Masked Img.', 1)

Setup training

Hyperparameters, optimizers, logging, etc.

In [15]:
data_dirs = [t1_dir]
In [16]:
# system setup
load_model = False  # resume from a saved checkpoint if True

# logging setup
log_rate = 5  # print losses every log_rate epochs
version = 'context_v1'  # naming scheme of model to load
save_rate = 10   # save models every save_rate epochs

# dataset params
block_size = 20   # side length (voxels) of each cut-out block
n_blocks = 15     # number of blocks removed per patch
patch_size = 128  # side length of the cubic training patch

# model, optimizer, loss, and training parameters
alpha = (0.99, 0.01)  # weight for regression & discriminator loss, resp.
valid_split = 0.1  # fraction of images held out for validation
batch_size = 2
n_jobs = batch_size  # dataloader worker processes
n_epochs = 50
input_channels = len(data_dirs)
use_adam = True
opt_kwargs = dict(lr=1e-3, betas=(0.9,0.99), weight_decay=1e-6) if use_adam else \
             dict(lr=5e-3, momentum=0.9)
use_scheduler = True
scheduler_kwargs = dict(step_size=10, gamma=0.5)  # halve the LR every 10 epochs
In [17]:
def init_fn(worker_id):
    """Re-seed Python and NumPy RNGs in each dataloader worker so the random
    block masks differ across workers while staying reproducible per run."""
    worker_seed = (torch.initial_seed() + worker_id) % (2 ** 32)
    random.seed(worker_seed)
    np.random.seed(worker_seed)
In [18]:
# setup training and validation dataloaders
dataset = ContextDataset(data_dirs, n_blocks=n_blocks, size=block_size, patch_size=patch_size)
num_train = len(dataset)
indices = list(range(num_train))
split = int(valid_split * num_train)
# random, non-overlapping train/validation split over image indices
valid_idx = np.random.choice(indices, size=split, replace=False)
train_idx = list(set(indices) - set(valid_idx))
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
# init_fn (defined above) gives each worker its own RNG seed
train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=batch_size, 
                          worker_init_fn=init_fn, num_workers=n_jobs,
                          pin_memory=True, collate_fn=context_collate)
valid_loader = DataLoader(dataset, sampler=valid_sampler, batch_size=batch_size, 
                          worker_init_fn=init_fn, num_workers=n_jobs,
                          pin_memory=True, collate_fn=context_collate)
In [19]:
# report the train/validation split sizes
print(f'Number of training images: {num_train-split}')
print(f'Number of validation images: {split}')
Number of training images: 121
Number of validation images: 13
In [20]:
# sanity-check one dataset sample: masked source, full target, and the block mask
src, tgt, mask = dataset[2]
fig,(ax1,ax2,ax3) = plt.subplots(1,3,figsize=(16,9))
src = src.detach().cpu().numpy().squeeze()
tgt = tgt.detach().cpu().numpy().squeeze()
mask = mask.detach().cpu().numpy().squeeze()
imshow(src[:,64,:], ax1, 'Src.')
imshow(tgt[:,64,:], ax2, 'Tgt.')
imshow(mask[:,64,:], ax3, 'Mask')
In [21]:
# inpainting network and patch-based discriminator
# NOTE(review): nc=32 is the U-Net base channel count; the meaning of 72 for
# PatchDiscriminator is not visible here — confirm against selfsupervised3d
recon_model = FrankUNet(ic=input_channels, nc=32)
discriminator = PatchDiscriminator(1, 72) 
In [22]:
def num_params(model):
    """Count the trainable (requires_grad) parameters of a PyTorch module."""
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total
In [23]:
# report model sizes (trainable parameter counts)
print(f'Number of trainable parameters in reconstruction model: {num_params(recon_model)}')
print(f'Number of trainable parameters in discriminator: {num_params(discriminator)}')
Number of trainable parameters in reconstruction model: 13887425
Number of trainable parameters in discriminator: 13978153
In [24]:
if load_model:
    # map_location='cpu' lets checkpoints saved on a different GPU load cleanly;
    # the models are moved to `device` in the next cell
    recon_model.load_state_dict(torch.load(f'recon_model_{version}.pth', map_location='cpu'))
    discriminator.load_state_dict(torch.load(f'discriminator_{version}.pth', map_location='cpu'))
In [25]:
recon_model.to(device)
discriminator.to(device)
# AdamW (decoupled weight decay) when use_adam, otherwise plain SGD with momentum
optim_cls = torch.optim.AdamW if use_adam else torch.optim.SGD
optG = optim_cls(recon_model.parameters(), **opt_kwargs)   # inpainter ("generator")
optD = optim_cls(discriminator.parameters(), **opt_kwargs) # discriminator
# combined reconstruction + adversarial loss, weighted by alpha = (0.99, 0.01)
criterion = InpaintLoss(alpha=alpha)
if use_scheduler:
    recon_scheduler = torch.optim.lr_scheduler.StepLR(optG, **scheduler_kwargs)
    disc_scheduler = torch.optim.lr_scheduler.StepLR(optD, **scheduler_kwargs)

Train model

In [26]:
# per-epoch lists of per-batch loss values
train_losses, valid_losses = [], []
n_batches = len(train_loader)
In [27]:
# Main training loop: one pass of training, one of validation, then
# logging / LR stepping / periodic checkpointing per epoch.
for t in range(1, n_epochs + 1):
    # training
    t_losses = []
    recon_model.train()
    discriminator.train()
    for i, (src, tgt, mask) in enumerate(train_loader):
        src, tgt, mask = src.to(device), tgt.to(device), mask.to(device)
        optG.zero_grad()
        optD.zero_grad()
        recon = recon_model(src)
        disc_real = discriminator(tgt)
        # NOTE(review): recon is NOT detached here, and a single combined loss
        # steps BOTH optimizers below — a nonstandard (joint) GAN update.
        # Verify that InpaintLoss is designed for this scheme; a conventional
        # setup alternates generator/discriminator updates with recon.detach()
        # for the discriminator pass.
        disc_fake = discriminator(recon)
        loss = criterion(recon, tgt, mask, disc_real, disc_fake)
        t_losses.append(loss.item())
        loss.backward()
        optG.step()
        optD.step()
    train_losses.append(t_losses)

    # validation (no gradients; eval mode fixes dropout/batch-norm behavior)
    v_losses = []
    recon_model.eval()
    discriminator.eval()
    with torch.no_grad():
        for i, (src, tgt, mask) in enumerate(valid_loader):
            src, tgt, mask = src.to(device), tgt.to(device), mask.to(device)
            recon = recon_model(src)
            disc_real = discriminator(tgt)
            disc_fake = discriminator(recon)
            loss = criterion(recon, tgt, mask, disc_real, disc_fake)
            v_losses.append(loss.item())
        valid_losses.append(v_losses)

    # log, step scheduler, and save results from epoch
    # abort on NaN/Inf — training cannot recover once the loss diverges
    if not np.all(np.isfinite(t_losses)): 
        raise RuntimeError('NaN or Inf in training loss, cannot recover. Exiting.')
    if t % log_rate == 0:
        log = (f'Epoch: {t} - TL: {np.mean(t_losses):.2e}, VL: {np.mean(v_losses):.2e}')
        print(log)
    if use_scheduler:
        recon_scheduler.step()
        disc_scheduler.step()
    if t % save_rate == 0:
        # periodic checkpoints, tagged with the epoch number
        torch.save(recon_model.state_dict(), f'recon_model_{version}_{t}.pth')
        torch.save(discriminator.state_dict(), f'discriminator_{version}_{t}.pth')
Epoch: 5 - TL: 1.71e-01, VL: 1.60e-01
Epoch: 10 - TL: 1.40e-01, VL: 1.49e-01
Epoch: 15 - TL: 1.25e-01, VL: 1.42e-01
Epoch: 20 - TL: 1.22e-01, VL: 1.35e-01
Epoch: 25 - TL: 1.11e-01, VL: 1.22e-01
Epoch: 30 - TL: 1.11e-01, VL: 1.13e-01
Epoch: 35 - TL: 1.07e-01, VL: 1.01e-01
Epoch: 40 - TL: 1.08e-01, VL: 1.31e-01
Epoch: 45 - TL: 1.06e-01, VL: 1.14e-01
Epoch: 50 - TL: 1.03e-01, VL: 1.09e-01
In [28]:
# persist the final (end-of-training) weights, without an epoch tag
save_model = True
if save_model:
    torch.save(recon_model.state_dict(), f'recon_model_{version}.pth')
    torch.save(discriminator.state_dict(), f'discriminator_{version}.pth')

Analyze training

In [29]:
def tidy_losses(train, valid):
    """Flatten per-epoch lists of batch losses into a long-format DataFrame
    (columns: epoch, type, value, phase) suitable for seaborn plotting."""
    records = []
    for epoch, (train_epoch, valid_epoch) in enumerate(zip(train, valid), 1):
        records.extend(
            {'epoch': epoch, 'type': 'loss', 'value': v, 'phase': 'train'}
            for v in train_epoch
        )
        records.extend(
            {'epoch': epoch, 'type': 'loss', 'value': v, 'phase': 'valid'}
            for v in valid_epoch
        )
    # explicit columns keep the empty-input case identical to the dict-based form
    return pd.DataFrame(records, columns=['epoch', 'type', 'value', 'phase'])
In [30]:
losses = tidy_losses(train_losses, valid_losses)
In [31]:
# loss curves on a log scale; the band is +/- 1 sd over batches within an epoch
f, ax1 = plt.subplots(1,1,figsize=(12, 8),sharey=True)
sns.lineplot(x='epoch',y='value',hue='phase',data=losses,ci='sd',ax=ax1,lw=3);
ax1.set_yscale('log');
ax1.set_title('Losses');
In [32]:
# optionally save the loss figure and the raw loss values
save_losses = False
if save_losses:
    f.savefig(f'losses_{version}.pdf')
    losses.to_csv(f'losses_{version}.csv')
In [33]:
# Show the last batch held from the loops above: source, target, reconstruction, mask
j = 100
fig,(ax1,ax2,ax3,ax4) = plt.subplots(1,4,figsize=(16,9))
try:
    # tensors -> numpy arrays; the AttributeError guard makes this cell
    # idempotent (re-running after conversion leaves the arrays untouched)
    src   = src.squeeze().cpu().detach().numpy()
    tgt   = tgt.squeeze().cpu().detach().numpy()
    recon = recon.squeeze().cpu().detach().numpy()
    mask  = mask.squeeze().cpu().detach().numpy()
except AttributeError:
    pass
imshow(src[...,j], ax1, 'Src.', 1)
imshow(tgt[...,j], ax2, 'Tgt.', 1)
imshow(recon[...,j], ax3, 'Recon.', 1)
imshow(mask[...,j], ax4, 'Mask', 1)

The hyperparameters above can probably be modified to improve this result. Also, it would be better to use a brain mask as an additional input to the ContextDataset (see the mask_dir keyword argument); a brain mask would keep the random blocks inside the brain, which would force the network to learn only how to fill in brain tissue rather than neck and other tissue.