Self-supervised learning in 3D images

Use the heatmap-based self-supervised method proposed in [1].

[1] M. Blendowski et al. "How to Learn from Unlabeled Volume Data:
    Self-supervised 3D Context Feature Learning." MICCAI. 2019.

Set up the notebook

In [1]:
from typing import Callable, List, Optional, Tuple, Union

from glob import glob
import math
import os
import random
import sys

gpu_id = 0
# Restrict the process to a single GPU; this must happen before CUDA is initialized,
# which is why it sits above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import seaborn as sns

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision

from selfsupervised3d import *

Support in-notebook plotting

In [2]:
# render matplotlib figures directly below the cell that creates them
%matplotlib inline

Report versions

In [3]:
# Record library versions alongside the outputs so results can be reproduced later.
from matplotlib import __version__ as mplver
for lib_name, lib_version in (('numpy', np.__version__),
                              ('matplotlib', mplver),
                              ('pytorch', torch.__version__),
                              ('torchvision', torchvision.__version__)):
    print(f'{lib_name} version: {lib_version}')
numpy version: 1.17.2
matplotlib version: 3.1.1
pytorch version: 1.5.0
torchvision version: 0.6.0a0+82fd1c8
In [4]:
pv = sys.version_info
# e.g. "python version: 3.7.7"
print('python version: ' + '.'.join(str(part) for part in (pv.major, pv.minor, pv.micro)))
python version: 3.7.7

Automatically reload modules whose source has changed (useful during package development)

In [5]:
# re-import all modules before each cell runs, so edits to the package (e.g.
# selfsupervised3d) take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

Check GPU(s)

In [6]:
# shell out to list GPUs, their memory use, and running processes
!nvidia-smi
Fri May  1 14:09:36 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.40       Driver Version: 430.40       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla M40 24GB      Off  | 00000000:02:00.0 Off |                    0 |
| N/A   39C    P0    58W / 250W |   1401MiB / 22945MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla M40 24GB      Off  | 00000000:03:00.0 Off |                    0 |
| N/A   57C    P0   147W / 250W |  18818MiB / 22945MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     28995      C   ...r/miniconda3/envs/synthtorch/bin/python  1390MiB |
|    1     10555      C   python3                                    18807MiB |
+-----------------------------------------------------------------------------+
In [7]:
# This notebook trains on the GPU only; fail fast when CUDA is unavailable.
assert torch.cuda.is_available()
# Let cuDNN benchmark and cache the fastest kernels (pays off when input sizes
# do not change between batches).
torch.backends.cudnn.benchmark = True
device = torch.device('cuda')

Set seeds for better reproducibility. See the PyTorch reproducibility notes before relying on these seeds with multiprocessing (e.g., DataLoader worker processes).

In [8]:
seed = 1336
# Seed every RNG the notebook touches: Python, NumPy, torch (CPU) and torch (current CUDA device).
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)

Set up training and validation data

Get the location of the training (and validation) data

In [9]:
# Root of the IXI training subset; it must contain `t1/` and `t2/` subdirectories.
# The default is a machine-specific absolute path, so it can be overridden with the
# IXI_TRAIN_DIR environment variable when running elsewhere.
train_dir = os.environ.get('IXI_TRAIN_DIR', '/iacl/pg20/jacobr/ixi/subsets/hh/')
t1_dir = os.path.join(train_dir, 't1')
t2_dir = os.path.join(train_dir, 't2')
In [10]:
# glob() returns files in arbitrary (filesystem-dependent) order, but the example below
# pairs t1_fns[0] with t2_fns[0]. Sorting makes the order deterministic and lines the
# lists up by subject, assuming T1/T2 filenames share a subject prefix (verify for new data).
t1_fns = sorted(glob(os.path.join(t1_dir, '*.nii*')))
t2_fns = sorted(glob(os.path.join(t2_dir, '*.nii*')))
assert len(t1_fns) == len(t2_fns) and len(t1_fns) != 0, \
    f'expected matching, non-empty T1/T2 file lists; got {len(t1_fns)} T1 and {len(t2_fns)} T2'

Look at an example training image

Look at an axial view of the source T1-weighted (T1-w) and target T2-weighted (T2-w) images.

In [11]:
def imshow(x, ax, title, n_rot=3, **kwargs):
    """Show a 2D image on a matplotlib axis with the notebook's default styling.

    Args:
        x: 2D array-like image.
        ax: matplotlib ``Axes`` to draw on.
        title: axis title (drawn with font size 22).
        n_rot: number of quarter turns applied with ``np.rot90`` before display.
        **kwargs: forwarded to ``ax.imshow`` (e.g. ``vmin``/``vmax``). ``cmap``
            defaults to ``'gray'`` and ``aspect`` to ``'equal'``; both may be
            overridden (previously passing ``cmap`` raised a duplicate-keyword TypeError).
    """
    kwargs.setdefault('cmap', 'gray')
    kwargs.setdefault('aspect', 'equal')
    ax.imshow(np.rot90(x, n_rot), **kwargs)
    ax.set_title(title, fontsize=22)
    ax.axis('off')
In [12]:
j = 100  # axial slice index to display
# `get_data()` is deprecated in nibabel (and removed in nibabel 5). `np.asanyarray(img.dataobj)`
# is its documented replacement: it returns the same (scaled) array without forcing a
# cast to float64 the way `get_fdata()` does.
t1_ex = np.asanyarray(nib.load(t1_fns[0]).dataobj)
t2_ex = np.asanyarray(nib.load(t2_fns[0]).dataobj)
fig,(ax1,ax2) = plt.subplots(1,2,figsize=(16,9))
imshow(t1_ex[...,j], ax1, 'T1', 1)
imshow(t2_ex[...,j], ax2, 'T2', 1)
In [13]:
# wrap the volume as a tensor and prepend a singleton dimension
# (presumably the channel axis blendowski_patches expects — confirm)
x = torch.from_numpy(t1_ex)[None]
In [14]:
# Draw one (center, query) patch pair plus its targets: `dp_goal` (presumably the
# in-plane offset between the patches) and the heatmap `hm_goal`.
(ctr, qry), (dp_goal, hm_goal) = blendowski_patches(
    x,
    min_off_inplane=0.,
    max_off_inplane=0.7,
    throughplane_axis=1,
)
In [15]:
# convert the patches to numpy arrays (singleton dims dropped) for plotting
ctr, qry = (patch.squeeze().cpu().detach().numpy() for patch in (ctr, qry))
hm_goal = hm_goal.squeeze()
In [16]:
dx, dy = dp_goal
print('dx: {:0.3f}, dy: {:0.3f}'.format(dx, dy))
dx: -0.427, dy: 0.501
In [17]:
print(*(patch.shape for patch in (ctr, qry)))
(3, 42, 42) (3, 42, 42)
In [18]:
# Show the middle of the 3 stacked slices of each patch next to the target heatmap.
# (The former `j = 12` was never used and only clobbered the slice index from above.)
fig,(ax1,ax2,ax3) = plt.subplots(1,3,figsize=(16,9))
imshow(ctr[1,...], ax1, 'CTR', 0)
imshow(qry[1,...], ax2, 'QRY', 0)
imshow(hm_goal, ax3, 'HM', 0)

Set up training

Hyperparameters, optimizers, logging, etc.

In [19]:
# image directories fed to the dataset; input_channels below scales with len(data_dirs)
data_dirs = [t1_dir]
In [20]:
# --- system setup ---
load_model = False  # True -> restore weights from the `*_{version}.pth` checkpoints

# --- logging setup ---
log_rate = 10               # print mean train/valid losses every `log_rate` epochs
version = 'blendowski_v1'   # tag used in checkpoint and loss file names
save_rate = 100             # checkpoint both models every `save_rate` epochs

# --- data, model, optimizer, loss, and training parameters ---
valid_split = 0.1           # fraction of images held out for validation
batch_size = 8
n_jobs = 8                  # DataLoader worker processes
n_epochs = 500
stack_dim = 3               # slices per stacked patch (presumably adjacent through-plane slices)
input_channels = len(data_dirs) * stack_dim
descriptor_size = 128
use_adam = True             # True -> AdamW, False -> SGD with momentum
if use_adam:
    opt_kwargs = dict(lr=1e-3, betas=(0.9, 0.99), weight_decay=1e-6)
else:
    opt_kwargs = dict(lr=5e-3, momentum=0.9)
use_scheduler = True
scheduler_kwargs = dict(step_size=100, gamma=0.5)  # StepLR stepped per epoch: halve the LR every 100 epochs
In [21]:
def init_fn(worker_id):
    """Seed Python's and NumPy's global RNGs inside a DataLoader worker.

    The seed is torch's initial seed plus ``worker_id``, reduced modulo 2**32
    so that it is a valid NumPy seed.
    """
    worker_seed = (worker_id + torch.initial_seed()) % 2**32
    for seed_rng in (random.seed, np.random.seed):
        seed_rng(worker_seed)
In [22]:
# setup training and validation dataloaders
dataset = BlendowskiDataset(data_dirs, stack_dim=stack_dim, throughplane_axis=1)
num_train = len(dataset)
indices = list(range(num_train))
split = int(valid_split * num_train)
# hold out a random subset of images for validation; the remainder is for training
valid_idx = np.random.choice(indices, size=split, replace=False)
train_idx = list(set(indices) - set(valid_idx))
train_sampler, valid_sampler = SubsetRandomSampler(train_idx), SubsetRandomSampler(valid_idx)
# both loaders wrap the same `dataset` object and differ only in their sampler
loader_kwargs = dict(batch_size=batch_size, worker_init_fn=init_fn, num_workers=n_jobs,
                     pin_memory=True, collate_fn=blendowski_collate)
train_loader = DataLoader(dataset, sampler=train_sampler, **loader_kwargs)
valid_loader = DataLoader(dataset, sampler=valid_sampler, **loader_kwargs)
In [23]:
for split_name, split_count in (('training', num_train - split), ('validation', split)):
    print(f'Number of {split_name} images: {split_count}')
Number of training images: 121
Number of validation images: 13
In [24]:
# Embedding network: maps a stacked-slice patch to a `descriptor_size` feature descriptor.
# Keep this construction order: weight initialization presumably draws from torch's
# global RNG, so swapping the two lines would change the seeded initial weights.
embedding_model = D2DConvNet(input_channels=input_channels, descriptor_size=descriptor_size)
# Decoder: takes the (center, query) descriptors and predicts the heatmap (see training loop).
decoder_model = HeatNet(descriptor_size=descriptor_size)
In [25]:
def num_params(model):
    """Return the number of trainable (``requires_grad``) scalar parameters in ``model``."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(torch.Tensor.numel, trainable))
In [26]:
for net_name, net in (('embedding', embedding_model), ('decoder', decoder_model)):
    print(f'Number of trainable parameters in {net_name} model: {num_params(net)}')
Number of trainable parameters in embedding model: 629184
Number of trainable parameters in decoder model: 111993
In [27]:
if load_model:
    # map_location='cpu' lets checkpoints saved from GPU tensors load on any machine;
    # the models are moved to `device` in the next cell anyway.
    for net, prefix in ((embedding_model, 'embedding_model'), (decoder_model, 'decoder_model')):
        net.load_state_dict(torch.load(f'{prefix}_{version}.pth', map_location='cpu'))
In [28]:
for net in (embedding_model, decoder_model):
    net.to(device)  # nn.Module.to moves the parameters in place
optim_cls = torch.optim.SGD if not use_adam else torch.optim.AdamW
embedding_opt, decoder_opt = (optim_cls(net.parameters(), **opt_kwargs)
                              for net in (embedding_model, decoder_model))
criterion = nn.MSELoss()  # regression loss between predicted and target heatmaps
if use_scheduler:
    embedding_scheduler, decoder_scheduler = (
        torch.optim.lr_scheduler.StepLR(opt, **scheduler_kwargs)
        for opt in (embedding_opt, decoder_opt)
    )

Train model

In [29]:
train_losses = []  # one list of per-batch training losses per epoch
valid_losses = []  # one list of per-batch validation losses per epoch
n_batches = len(train_loader)  # training batches per epoch (not used below)
In [30]:
# Per-epoch curriculum for the in-plane offset range: the lower bound shrinks
# from 0.25 to 0.0 while the upper bound grows from 0.30 to 0.7.
min_off_inplane, max_off_inplane = (np.linspace(lo, hi, n_epochs)
                                    for lo, hi in ((0.25, 0.0), (0.30, 0.7)))
In [31]:
for t in range(1, n_epochs + 1):
    # Expand the in-plane offset range as per the paper. This is applied at the START
    # of the epoch so that epoch t samples patches (train and valid share `dataset`)
    # with schedule entry t-1: the first epoch uses the first entry and the last epoch
    # reaches the last one. The loaders do not use persistent workers, so each epoch's
    # fresh workers pick up these attribute changes.
    dataset.min_off_inplane = min_off_inplane[t-1]
    dataset.max_off_inplane = max_off_inplane[t-1]

    # training
    t_losses = []
    embedding_model.train()
    decoder_model.train()
    for (ctr, qry), (_, goal) in train_loader:
        ctr, qry, goal = ctr.to(device), qry.to(device), goal.to(device)
        embedding_opt.zero_grad()
        decoder_opt.zero_grad()
        # one shared embedding network encodes both patches; the decoder predicts the heatmap
        ctr_f = embedding_model(ctr)
        qry_f = embedding_model(qry)
        out = decoder_model(ctr_f, qry_f)
        loss = criterion(out, goal)
        t_losses.append(loss.item())
        loss.backward()
        embedding_opt.step()
        decoder_opt.step()
    train_losses.append(t_losses)
    # fail fast: there is no point validating once the training loss has diverged
    if not np.all(np.isfinite(t_losses)):
        raise RuntimeError('NaN or Inf in training loss, cannot recover. Exiting.')

    # validation (eval mode, no gradient tracking)
    v_losses = []
    embedding_model.eval()
    decoder_model.eval()
    with torch.no_grad():
        for (ctr, qry), (_, goal) in valid_loader:
            ctr, qry, goal = ctr.to(device), qry.to(device), goal.to(device)
            ctr_f = embedding_model(ctr)
            qry_f = embedding_model(qry)
            out = decoder_model(ctr_f, qry_f)
            loss = criterion(out, goal)
            v_losses.append(loss.item())
    valid_losses.append(v_losses)

    # log, step scheduler, and save results from epoch
    if t % log_rate == 0:
        print(f'Epoch: {t} - TL: {np.mean(t_losses):.2e}, VL: {np.mean(v_losses):.2e}')
    if use_scheduler:
        embedding_scheduler.step()
        decoder_scheduler.step()
    if t % save_rate == 0:
        torch.save(embedding_model.state_dict(), f'embedding_model_{version}_{t}.pth')
        torch.save(decoder_model.state_dict(), f'decoder_model_{version}_{t}.pth')
Epoch: 10 - TL: 1.66e+00, VL: 1.81e+00
Epoch: 20 - TL: 1.69e+00, VL: 1.70e+00
Epoch: 30 - TL: 1.49e+00, VL: 1.73e+00
Epoch: 40 - TL: 1.35e+00, VL: 1.32e+00
Epoch: 50 - TL: 1.49e+00, VL: 1.57e+00
Epoch: 60 - TL: 1.31e+00, VL: 1.11e+00
Epoch: 70 - TL: 1.10e+00, VL: 1.53e+00
Epoch: 80 - TL: 9.98e-01, VL: 9.29e-01
Epoch: 90 - TL: 9.39e-01, VL: 6.90e-01
Epoch: 100 - TL: 6.67e-01, VL: 6.35e-01
Epoch: 110 - TL: 3.99e-01, VL: 2.24e-01
Epoch: 120 - TL: 4.65e-01, VL: 6.83e-01
Epoch: 130 - TL: 5.81e-01, VL: 1.07e+00
Epoch: 140 - TL: 5.59e-01, VL: 3.85e-01
Epoch: 150 - TL: 4.78e-01, VL: 9.67e-01
Epoch: 160 - TL: 6.60e-01, VL: 4.83e-01
Epoch: 170 - TL: 8.24e-01, VL: 9.62e-01
Epoch: 180 - TL: 5.35e-01, VL: 5.47e-01
Epoch: 190 - TL: 6.28e-01, VL: 4.52e-01
Epoch: 200 - TL: 5.28e-01, VL: 4.18e-01
Epoch: 210 - TL: 4.18e-01, VL: 3.12e-01
Epoch: 220 - TL: 4.54e-01, VL: 4.38e-01
Epoch: 230 - TL: 4.47e-01, VL: 5.75e-01
Epoch: 240 - TL: 4.17e-01, VL: 3.26e-01
Epoch: 250 - TL: 4.96e-01, VL: 4.66e-01
Epoch: 260 - TL: 4.82e-01, VL: 5.64e-01
Epoch: 270 - TL: 4.84e-01, VL: 3.90e-01
Epoch: 280 - TL: 4.43e-01, VL: 4.28e-01
Epoch: 290 - TL: 5.33e-01, VL: 3.92e-01
Epoch: 300 - TL: 4.07e-01, VL: 4.18e-01
Epoch: 310 - TL: 4.79e-01, VL: 7.54e-01
Epoch: 320 - TL: 4.52e-01, VL: 2.68e-01
Epoch: 330 - TL: 4.03e-01, VL: 5.31e-01
Epoch: 340 - TL: 4.07e-01, VL: 4.47e-01
Epoch: 350 - TL: 4.04e-01, VL: 3.60e-01
Epoch: 360 - TL: 4.05e-01, VL: 4.04e-01
Epoch: 370 - TL: 5.12e-01, VL: 4.04e-01
Epoch: 380 - TL: 4.13e-01, VL: 5.28e-01
Epoch: 390 - TL: 4.43e-01, VL: 3.59e-01
Epoch: 400 - TL: 3.66e-01, VL: 6.18e-01
Epoch: 410 - TL: 3.58e-01, VL: 5.61e-01
Epoch: 420 - TL: 5.20e-01, VL: 4.96e-01
Epoch: 430 - TL: 3.63e-01, VL: 4.08e-01
Epoch: 440 - TL: 3.77e-01, VL: 4.85e-01
Epoch: 450 - TL: 3.66e-01, VL: 3.51e-01
Epoch: 460 - TL: 3.56e-01, VL: 2.35e-01
Epoch: 470 - TL: 3.54e-01, VL: 4.04e-01
Epoch: 480 - TL: 3.63e-01, VL: 3.85e-01
Epoch: 490 - TL: 3.54e-01, VL: 3.52e-01
Epoch: 500 - TL: 4.04e-01, VL: 2.23e-01
In [32]:
save_model = True
if save_model:
    # final weights, without the epoch suffix used for periodic checkpoints
    for prefix, net in (('embedding_model', embedding_model), ('decoder_model', decoder_model)):
        torch.save(net.state_dict(), f'{prefix}_{version}.pth')

Analyze training

In [33]:
def _first_example(batch):
    """Return the first element of a batch (tensor or array) as a numpy array with singleton dims removed."""
    arr = batch.detach().cpu().numpy() if torch.is_tensor(batch) else np.asarray(batch)
    return arr[0].squeeze()

# Inspect the first example of the most recent batch from the training loop (normally
# the last validation batch). Converting into new names keeps this cell re-runnable
# without mutating the loop's tensors. Indexing the batch before squeezing keeps the
# layout correct even when that batch holds a single example.
ctr_ex, qry_ex = _first_example(ctr), _first_example(qry)
out_ex, goal_ex = _first_example(out), _first_example(goal)
gm = goal_ex.max()  # shared display range so OUT, HM and DIFF are comparable
fig,((ax1,ax2,ax3),(ax4,ax5,ax6)) = plt.subplots(2,3,figsize=(16,9))
imshow(ctr_ex[1], ax1, 'CTR', 0)  # middle slice of the stacked patch
imshow(qry_ex[1], ax2, 'QRY', 0)
ax3.axis('off')
imshow(out_ex, ax4, 'OUT', 0, vmin=0, vmax=gm)
imshow(goal_ex, ax5, 'HM', 0, vmin=0, vmax=gm)
imshow(np.abs(out_ex - goal_ex), ax6, 'DIFF', 0, vmin=0, vmax=gm)
In [34]:
def tidy_losses(train, valid):
    """Reshape per-epoch loss lists into a long-format ("tidy") DataFrame.

    Args:
        train: sequence of per-epoch lists of training losses.
        valid: sequence of per-epoch lists of validation losses.

    Returns:
        DataFrame with columns ``epoch`` (1-based), ``type`` (always ``'loss'``),
        ``value`` and ``phase`` (``'train'`` or ``'valid'``). Within each epoch the
        training rows come first. Epochs beyond the shorter input are dropped (zip semantics).
    """
    records = []
    for epoch, (train_epoch, valid_epoch) in enumerate(zip(train, valid), start=1):
        for phase, epoch_losses in (('train', train_epoch), ('valid', valid_epoch)):
            records.extend((epoch, 'loss', value, phase) for value in epoch_losses)
    return pd.DataFrame(records, columns=['epoch', 'type', 'value', 'phase'])
In [35]:
# long-format table of every per-batch train/valid loss, indexed by 1-based epoch
losses = tidy_losses(train_losses, valid_losses)
In [36]:
f, ax1 = plt.subplots(1, 1, figsize=(12, 8), sharey=True)
# mean loss per epoch with a +/- 1 SD band across batches
# NOTE(review): `ci='sd'` is deprecated in seaborn >= 0.12 (use `errorbar='sd'` there)
sns.lineplot(data=losses, x='epoch', y='value', hue='phase', ci='sd', lw=3, ax=ax1);
ax1.set(yscale='log', title='Losses');
In [37]:
save_losses = False  # True -> write the loss figure (PDF) and the tidy loss table (CSV)
if save_losses:
    out_stem = f'losses_{version}'
    f.savefig(out_stem + '.pdf')
    losses.to_csv(out_stem + '.csv')