Combining the power of PyTorch and NiftyNet


Introduction

NiftyNet is "an open source convolutional neural networks platform for medical image analysis and image-guided therapy" built on top of TensorFlow. Due to its available implementations of successful architectures, patch-based sampling and straightforward configuration, it has become a popular choice to get started with deep learning in medical imaging.

PyTorch is "an open source deep learning platform that provides a seamless path from research prototyping to production deployment". It is low-level enough to offer a lot of control over what is going on under the hood during training, and its dynamic computational graph allows for easy debugging. Being a generic deep learning framework, it is not tailored to the needs of the medical imaging field, although its popularity in this field is increasing rapidly.

One can extend a NiftyNet application, but it is not straightforward without being familiar with the framework and fluent in TensorFlow 1.X. Therefore it can be convenient to implement applications in PyTorch using NiftyNet models and functionalities. In particular, combining both frameworks allows for fast architecture experimentation and transfer learning.

So why not use both? In this tutorial we will port the parameters of a model trained on NiftyNet to a PyTorch model and compare the results of running an inference using both frameworks.

Image segmentation

Although NiftyNet supports different applications, it is mostly used for medical image segmentation.

The words "image", "segmentation", "using", "deep" and "learning" were the five most common words in all full-paper titles from both MICCAI 2018 and MIDL 2019, which shows the interest of the medical imaging community in this topic.

HighRes3DNet


HighRes3DNet is a residual convolutional neural network architecture designed to have a large receptive field and preserve a high resolution using a relatively small number of parameters. It was presented in 2017 by Li et al. at IPMI: On the Compactness, Efficiency, and Representation of 3D Convolutional Networks: Brain Parcellation as a Pretext Task.


The authors used NiftyNet to implement and train a model based on this architecture to perform brain parcellation using $T_1$-weighted MR images from the ADNI dataset. They achieved competitive segmentation performance compared with state-of-the-art architectures such as DeepMedic or U-Net.

This figure from the paper shows a parcellation produced by HighRes3DNet:

Input MRI (left) and output parcellation (right)

The code of the architecture is in the NiftyNet GitHub repository. The authors have uploaded the parameters and the configuration file to the Model Zoo.

After reading the paper and the code, it is relatively straightforward to implement the same architecture using PyTorch.
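
To give an idea of the main building block involved, here is a minimal sketch (not the exact highresnet implementation; channel counts and dilations follow the description in the paper) of a HighRes3DNet-style residual block: two batch norm → ReLU → dilated 3×3×3 convolution units with an identity skip connection.

import torch
import torch.nn as nn

class DilatedResidualBlock(nn.Module):
    """Two (batch norm -> ReLU -> dilated 3x3x3 conv) units plus an identity skip."""
    def __init__(self, channels, dilation):
        super().__init__()
        self.layers = nn.Sequential(
            nn.BatchNorm3d(channels),
            nn.ReLU(),
            nn.Conv3d(channels, channels, kernel_size=3,
                      padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm3d(channels),
            nn.ReLU(),
            nn.Conv3d(channels, channels, kernel_size=3,
                      padding=dilation, dilation=dilation, bias=False),
        )

    def forward(self, x):
        return x + self.layers(x)  # padding keeps the spatial size, so shapes match

x = torch.rand(1, 16, 32, 32, 32)  # (batch, channels, depth, height, width)
print(DilatedResidualBlock(16, dilation=2)(x).shape)  # torch.Size([1, 16, 32, 32, 32])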

Running the notebook

All the code is hosted in a GitHub repository: fepegar/miccai-educational-challenge-2019.

The latest release can also be found on the Zenodo repository under this DOI: 10.5281/zenodo.3352316.

Online

If you have a Google account, the best way to run this notebook seamlessly is using Google Colab. You will need to click on "Open in playground", at the top left:

Playground mode screenshot

You will also get a couple of warnings that you can safely ignore. The first one warns that this notebook was not authored by Google and the second one asks for confirmation to reset all runtimes. These are valid points, but they will not affect this tutorial.

Open In Colab


Please report any issues on GitHub and I will fix them. You can also drop me an email if you have any questions or comments.

Locally

To write this notebook I used Ubuntu 18.04 installed on an Alienware 13 R3 laptop, which includes a 6-GB GeForce GTX 1060 NVIDIA GPU. I am using CUDA 9.0.

Inference using PyTorch took 5725 MB of GPU memory; TensorFlow usually allocates as much GPU memory as possible beforehand.
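
If you want to check this yourself, PyTorch exposes peak memory statistics (a quick sketch, assuming a CUDA device is available and that inference has already been run):

import torch
print(f'{torch.cuda.max_memory_allocated() / 1024**2:.0f} MB allocated at peak')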

To run this notebook locally, I recommend downloading the repository and creating a conda environment:

git clone https://github.com/fepegar/miccai-educational-challenge-2019.git
cd miccai-educational-challenge-2019
conda create -n mec2019 python=3.6 -y
conda activate mec2019 && conda install jupyterlab -y && jupyter lab

nbviewer

An already executed version of the notebook can be rendered using nbviewer.

Interactive volume plots

If you run the notebook, you can use interactive plots to navigate through the volume slices by setting this variable to True. You might need to run the volume visualization cells individually after running the whole notebook. This feature is experimental and therefore disabled by default.

In [0]:
interactive_plots = False

Setup

Install and import libraries

Clone NiftyNet and download some custom Python modules used by this notebook. This might take one or two minutes.

In [0]:
%%capture --no-stderr
# This might take about 30 seconds
!rm -rf NiftyNet && git clone https://github.com/NifTK/NiftyNet --depth 1
!cd NiftyNet && git checkout df0f86733357fdc92bbc191c8fec0dcf49aa5499 && cd ..
!pip install -r NiftyNet/requirements-gpu.txt
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/requirements.txt
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/tf2pt.py
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/utils.py
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/visualization.py
!curl -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/highresnet_mapping.py
!curl -O https://raw.githubusercontent.com/fepegar/highresnet/master/GIFNiftyNet.ctbl
!pip install -r requirements.txt
!pip install --upgrade numpy
!pip install ipywidgets
import sys
sys.path.insert(0, 'NiftyNet')
In [3]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

import os
import tempfile
from pathlib import Path
from configparser import ConfigParser

import numpy as np
import pandas as pd

try:
    # Fancy rendering of Pandas tables
    import google.colab.data_table
    %load_ext google.colab.data_table
    print("We are on Google Colab")
except ModuleNotFoundError:
    print("We are not on Google Colab")
    pd.set_option('display.max_colwidth', -1)  # do not truncate strings when displaying data frames
    pd.set_option('display.max_rows', None)    # show all rows

import torch

from highresnet import HighRes3DNet
We are on Google Colab
In [0]:
%%capture
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.python.util import deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False

import tf2pt
import utils
import visualization
import highresnet_mapping

if interactive_plots:  # for Colab or Jupyter
    plot_volume = visualization.plot_volume_interactive
else:  # for HTML, GitHub or nbviewer
    plot_volume = visualization.plot_volume

from niftynet.io.image_reader import ImageReader
from niftynet.engine.sampler_grid_v2 import GridSampler
from niftynet.engine.windows_aggregator_grid import GridSamplesAggregator
from niftynet.layer.pad import PadLayer
from niftynet.layer.binary_masking import BinaryMaskingLayer
from niftynet.layer.histogram_normalisation import HistogramNormalisationLayer
from niftynet.layer.mean_variance_normalisation import MeanVarNormalisationLayer

Download NiftyNet model and test data

We can use NiftyNet's net_download to get all we need from the Model Zoo entry corresponding to brain parcellation using HighRes3DNet:

In [0]:
%%capture
%run NiftyNet/net_download.py highres3dnet_brain_parcellation_model_zoo
In [6]:
niftynet_dir = Path('~/niftynet').expanduser()
utils.list_files(niftynet_dir)
niftynet/
    data/
        OASIS/
            license
            OAS1_0145_MR2_mpr_n4_anon_sbj_111.nii.gz
    models/
        highres3dnet_brain_parcellation/
            inference_niftynet_log
            databrain_std_hist_models_otsu.txt
            settings_inference.txt
            Modality0.csv
            parcellation_output/
                window_seg_OAS1_0145_MR2_mpr_n4_anon_sbj_111__niftynet_out.nii.gz
                inferred.csv
            logs/
            models/
                model.ckpt-33000.meta
                model.ckpt-33000.data-00000-of-00001
                model.ckpt-33000.index
    extensions/
        __init__.py
        highres3dnet_brain_parcellation/
            __init__.py
            highres3dnet_config_eval.ini
        network/
            __init__.py

There are three directories under ~/niftynet:

  1. extensions is a Python package that contains the [configuration file](https://niftynet.readthedocs.io/en/dev/config_spec.html).
  2. models contains the landmarks for histogram standardization (an MRI preprocessing step) and the network parameters.
  3. data contains an OASIS MRI sample that can be used to test the model.

Here are the paths to the downloaded files and to the files that will be generated by the notebook. I use nn for NiftyNet, tf for TensorFlow and pt for PyTorch.

In [0]:
models_dir = niftynet_dir / 'models'
zoo_entry = 'highres3dnet_brain_parcellation'
input_checkpoint_tf_name = 'model.ckpt-33000'
input_checkpoint_tf_path = models_dir / zoo_entry / 'models' / input_checkpoint_tf_name
data_dir = niftynet_dir / 'data' / 'OASIS'
config_path = niftynet_dir / 'extensions' / zoo_entry / 'highres3dnet_config_eval.ini'
histogram_landmarks_path = models_dir / zoo_entry / 'databrain_std_hist_models_otsu.txt'
tempdir = Path(tempfile.gettempdir()) / 'miccai_niftynet_pytorch'
tempdir.mkdir(exist_ok=True)
output_csv_tf_path = tempdir / 'variables_tf.csv'
output_state_dict_tf_path = tempdir / 'state_dict_tf.pth'
output_state_dict_pt_path = tempdir / 'state_dict_pt.pth'
prediction_pt_dir = tempdir / 'prediction'
prediction_pt_dir.mkdir(exist_ok=True)
color_table_path = 'GIFNiftyNet.ctbl'

Note that the path to the checkpoint is not a path to an existing filename, but the basename of the three checkpoint files.
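
As a quick sanity check (a small sketch using the paths defined above), the prefix itself does not exist on disk, but the suffixed checkpoint files do:

print(input_checkpoint_tf_path.exists())  # False: it is only a prefix
print(sorted(p.name for p in input_checkpoint_tf_path.parent.glob(input_checkpoint_tf_name + '.*')))
# ['model.ckpt-33000.data-00000-of-00001', 'model.ckpt-33000.index', 'model.ckpt-33000.meta']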

Transferring parameters from NiftyNet to PyTorch

Variables in TensorFlow world


There are two modules that are relevant for this section in the repository. tf2pt contains generic functions that can be used to transform any TensorFlow model to PyTorch. highresnet_mapping contains custom functions that are specific for HighRes3DNet.

Let's see what variables are stored in the checkpoint.

The following kinds of variables are filtered out by highresnet_mapping.is_not_valid() for clarity:

  • Variables used by the Adam optimizer during training
  • Variables with no shape. They won't help much.
  • Variables whose names contain biased or ExponentialMovingAverage. I have experimented with them, and the results turned out to be different from the ones generated by NiftyNet.

We will store the variable names in a data frame to list them in this notebook, and the values in a Python dictionary to retrieve them later. I figured out the code in tf2pt.checkpoint_tf_to_state_dict_tf() by reading the corresponding TensorFlow docs and Stack Overflow answers.
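
The core of that function can be sketched as follows (a minimal sketch; the details of the actual tf2pt implementation may differ): TensorFlow can list the variables stored in a checkpoint and load each one directly.

import tensorflow as tf

def read_checkpoint_variables(checkpoint_path):
    """Return a {name: numpy array} dictionary for all variables in a checkpoint."""
    variables = {}
    for name, shape in tf.train.list_variables(str(checkpoint_path)):
        variables[name] = tf.train.load_variable(str(checkpoint_path), name)
    return variables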

In [8]:
# %%capture
tf2pt.checkpoint_tf_to_state_dict_tf(
    input_checkpoint_tf_path=input_checkpoint_tf_path,
    output_csv_tf_path=output_csv_tf_path,
    output_state_dict_tf_path=output_state_dict_tf_path,
    filter_out_function=highresnet_mapping.is_not_valid,
    replace_string='HighRes3DNet/',
)
data_frame_tf = pd.read_csv(output_csv_tf_path)
state_dict_tf = torch.load(output_state_dict_tf_path)
W0824 12:06:24.860505 140162299176832 deprecation_wrapper.py:119] From /content/tf2pt.py:106: The name tf.reset_default_graph is deprecated. Please use tf.compat.v1.reset_default_graph instead.

W0824 12:06:24.872821 140162299176832 deprecation_wrapper.py:119] From /content/tf2pt.py:114: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.

W0824 12:06:25.661258 140162299176832 deprecation_wrapper.py:119] From /content/tf2pt.py:122: The name tf.train.Saver is deprecated. Please use tf.compat.v1.train.Saver instead.

W0824 12:06:25.749863 140162299176832 deprecation_wrapper.py:119] From /content/tf2pt.py:124: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.

I0824 12:06:29.631742 140162299176832 saver.py:1280] Restoring parameters from /root/niftynet/models/highres3dnet_brain_parcellation/models/model.ckpt-33000
In [9]:
data_frame_tf
Out[9]:
Unnamed: 0 name shape
0 0 conv_0_bn_relu/bn_/beta 16
1 1 conv_0_bn_relu/bn_/gamma 16
2 2 conv_0_bn_relu/bn_/moving_mean 16
3 3 conv_0_bn_relu/bn_/moving_variance 16
4 4 conv_0_bn_relu/conv_/w 3, 3, 3, 1, 16
5 5 conv_1_bn_relu/bn_/beta 80
6 6 conv_1_bn_relu/bn_/gamma 80
7 7 conv_1_bn_relu/bn_/moving_mean 80
8 8 conv_1_bn_relu/bn_/moving_variance 80
9 9 conv_1_bn_relu/conv_/w 1, 1, 1, 64, 80
10 10 conv_2_bn/bn_/beta 160
11 11 conv_2_bn/bn_/gamma 160
12 12 conv_2_bn/bn_/moving_mean 160
13 13 conv_2_bn/bn_/moving_variance 160
14 14 conv_2_bn/conv_/w 1, 1, 1, 80, 160
15 15 res_1_0/bn_0/beta 16
16 16 res_1_0/bn_0/gamma 16
17 17 res_1_0/bn_0/moving_mean 16
18 18 res_1_0/bn_0/moving_variance 16
19 19 res_1_0/bn_1/beta 16
20 20 res_1_0/bn_1/gamma 16
21 21 res_1_0/bn_1/moving_mean 16
22 22 res_1_0/bn_1/moving_variance 16
23 23 res_1_0/conv_0/w 3, 3, 3, 16, 16
24 24 res_1_0/conv_1/w 3, 3, 3, 16, 16
25 25 res_1_1/bn_0/beta 16
26 26 res_1_1/bn_0/gamma 16
27 27 res_1_1/bn_0/moving_mean 16
28 28 res_1_1/bn_0/moving_variance 16
29 29 res_1_1/bn_1/beta 16
... ... ... ...
75 75 res_3_0/bn_0/beta 32
76 76 res_3_0/bn_0/gamma 32
77 77 res_3_0/bn_0/moving_mean 32
78 78 res_3_0/bn_0/moving_variance 32
79 79 res_3_0/bn_1/beta 64
80 80 res_3_0/bn_1/gamma 64
81 81 res_3_0/bn_1/moving_mean 64
82 82 res_3_0/bn_1/moving_variance 64
83 83 res_3_0/conv_0/w 3, 3, 3, 32, 64
84 84 res_3_0/conv_1/w 3, 3, 3, 64, 64
85 85 res_3_1/bn_0/beta 64
86 86 res_3_1/bn_0/gamma 64
87 87 res_3_1/bn_0/moving_mean 64
88 88 res_3_1/bn_0/moving_variance 64
89 89 res_3_1/bn_1/beta 64
90 90 res_3_1/bn_1/gamma 64
91 91 res_3_1/bn_1/moving_mean 64
92 92 res_3_1/bn_1/moving_variance 64
93 93 res_3_1/conv_0/w 3, 3, 3, 64, 64
94 94 res_3_1/conv_1/w 3, 3, 3, 64, 64
95 95 res_3_2/bn_0/beta 64
96 96 res_3_2/bn_0/gamma 64
97 97 res_3_2/bn_0/moving_mean 64
98 98 res_3_2/bn_0/moving_variance 64
99 99 res_3_2/bn_1/beta 64
100 100 res_3_2/bn_1/gamma 64
101 101 res_3_2/bn_1/moving_mean 64
102 102 res_3_2/bn_1/moving_variance 64
103 103 res_3_2/conv_0/w 3, 3, 3, 64, 64
104 104 res_3_2/conv_1/w 3, 3, 3, 64, 64

105 rows × 3 columns

The weight parameters associated with each convolutional layer, denoted with conv_/w, are stored with shape representing the three spatial dimensions, the input channels and the output channels: $(Depth, Height, Width, Channels_{in}, Channels_{out})$. Calling the spatial dimensions depth, height and width does not make a lot of sense when dealing with 3D medical images, but we will keep this computer vision terminology as it is consistent with the documentation of both PyTorch and TensorFlow.

The layer names and parameter shapes are overall consistent with the figure in the HighRes3DNet paper, but there is an additional $1 \times 1 \times 1$ convolutional layer with 80 output channels, which is also in the code. This seems to correspond to the model with dropout from the paper, which achieved the highest performance, so our implementation of the architecture should include this layer as well.

There are three blocks with increasing kernel dilation, each composed of three residual blocks that contain two convolutional layers. That's $3 \times 3 \times 2 = 18$ convolutional layers. The other three convolutional layers are the initial convolution before the first residual block, a convolution before dropout and a convolution to expand the channels to the number of output classes.

Apparently, all the convolutional layers have an associated batch normalization layer, which differs from the figure in the paper. That makes 21 convolutional layers and 21 batch normalization layers whose parameters must be transferred.

Each batch normalization layer contains four parameter groups: the moving mean $\mathrm{E}[x]$, the moving variance $\mathrm{Var}[x]$, and the affine transformation parameters $\gamma$ (scale or weight) and $\beta$ (shift or bias):

\begin{align} y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta \end{align}

Therefore the total number of parameter groups is $21 + 21 \times 4 = 105$. The convolutional layers don't use a bias parameter, as it is not necessary when using batch norm.
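
We can verify this count from the data frame loaded above (a small sanity check; it assumes that convolution kernels are exactly the entries whose names end with /w):

num_conv = data_frame_tf['name'].str.endswith('/w').sum()
num_batch_norm_groups = len(data_frame_tf) - num_conv
print(num_conv, num_batch_norm_groups)  # expected: 21 and 84 (i.e. 21 x 4)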

Variables in PyTorch world


To match the model in the paper, we set the number of classes to 160 and enable the flag to add the dropout layer.

In [0]:
num_input_modalities = 1
num_classes = 160
model = HighRes3DNet(num_input_modalities, num_classes, add_dropout_layer=True)

Let's see what the variable names created by PyTorch are:

In [11]:
state_dict_pt = model.state_dict()
rows = []
for name, parameters in state_dict_pt.items():
    if 'num_batches_tracked' in name:  # filter out for clarity
        continue
    shape = ', '.join(str(n) for n in parameters.shape)
    row = {'name': name, 'shape': shape}
    rows.append(row)
df_pt = pd.DataFrame.from_dict(rows)
df_pt
Out[11]:
name shape
0 block.0.convolutional_block.1.weight 16, 1, 3, 3, 3
1 block.0.convolutional_block.2.weight 16
2 block.0.convolutional_block.2.bias 16
3 block.0.convolutional_block.2.running_mean 16
4 block.0.convolutional_block.2.running_var 16
5 block.1.dilation_block.0.residual_block.0.conv... 16
6 block.1.dilation_block.0.residual_block.0.conv... 16
7 block.1.dilation_block.0.residual_block.0.conv... 16
8 block.1.dilation_block.0.residual_block.0.conv... 16
9 block.1.dilation_block.0.residual_block.0.conv... 16, 16, 3, 3, 3
10 block.1.dilation_block.0.residual_block.1.conv... 16
11 block.1.dilation_block.0.residual_block.1.conv... 16
12 block.1.dilation_block.0.residual_block.1.conv... 16
13 block.1.dilation_block.0.residual_block.1.conv... 16
14 block.1.dilation_block.0.residual_block.1.conv... 16, 16, 3, 3, 3
15 block.1.dilation_block.1.residual_block.0.conv... 16
16 block.1.dilation_block.1.residual_block.0.conv... 16
17 block.1.dilation_block.1.residual_block.0.conv... 16
18 block.1.dilation_block.1.residual_block.0.conv... 16
19 block.1.dilation_block.1.residual_block.0.conv... 16, 16, 3, 3, 3
20 block.1.dilation_block.1.residual_block.1.conv... 16
21 block.1.dilation_block.1.residual_block.1.conv... 16
22 block.1.dilation_block.1.residual_block.1.conv... 16
23 block.1.dilation_block.1.residual_block.1.conv... 16
24 block.1.dilation_block.1.residual_block.1.conv... 16, 16, 3, 3, 3
25 block.1.dilation_block.2.residual_block.0.conv... 16
26 block.1.dilation_block.2.residual_block.0.conv... 16
27 block.1.dilation_block.2.residual_block.0.conv... 16
28 block.1.dilation_block.2.residual_block.0.conv... 16
29 block.1.dilation_block.2.residual_block.0.conv... 16, 16, 3, 3, 3
... ... ...
75 block.3.dilation_block.1.residual_block.0.conv... 64
76 block.3.dilation_block.1.residual_block.0.conv... 64
77 block.3.dilation_block.1.residual_block.0.conv... 64
78 block.3.dilation_block.1.residual_block.0.conv... 64
79 block.3.dilation_block.1.residual_block.0.conv... 64, 64, 3, 3, 3
80 block.3.dilation_block.1.residual_block.1.conv... 64
81 block.3.dilation_block.1.residual_block.1.conv... 64
82 block.3.dilation_block.1.residual_block.1.conv... 64
83 block.3.dilation_block.1.residual_block.1.conv... 64
84 block.3.dilation_block.1.residual_block.1.conv... 64, 64, 3, 3, 3
85 block.3.dilation_block.2.residual_block.0.conv... 64
86 block.3.dilation_block.2.residual_block.0.conv... 64
87 block.3.dilation_block.2.residual_block.0.conv... 64
88 block.3.dilation_block.2.residual_block.0.conv... 64
89 block.3.dilation_block.2.residual_block.0.conv... 64, 64, 3, 3, 3
90 block.3.dilation_block.2.residual_block.1.conv... 64
91 block.3.dilation_block.2.residual_block.1.conv... 64
92 block.3.dilation_block.2.residual_block.1.conv... 64
93 block.3.dilation_block.2.residual_block.1.conv... 64
94 block.3.dilation_block.2.residual_block.1.conv... 64, 64, 3, 3, 3
95 block.4.convolutional_block.0.weight 80, 64, 1, 1, 1
96 block.4.convolutional_block.1.weight 80
97 block.4.convolutional_block.1.bias 80
98 block.4.convolutional_block.1.running_mean 80
99 block.4.convolutional_block.1.running_var 80
100 block.6.convolutional_block.0.weight 160, 80, 1, 1, 1
101 block.6.convolutional_block.1.weight 160
102 block.6.convolutional_block.1.bias 160
103 block.6.convolutional_block.1.running_mean 160
104 block.6.convolutional_block.1.running_var 160

105 rows × 2 columns

We can see that moving_mean and moving_variance are called running_mean and running_var in PyTorch. Also, $\gamma$ and $\beta$ are called weight and bias.

The convolutional kernels have a different arrangement: $(Channels_{out}, Channels_{in}, Depth, Height, Width)$.
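
Converting between the two layouts is a single permutation of the axes. This is a minimal sketch of the idea (the actual conversion happens inside tf2pt.tf2pt(); torch.as_tensor is used here in case the saved values are NumPy arrays):

kernel_tf = state_dict_tf['conv_0_bn_relu/conv_/w']                 # (3, 3, 3, 1, 16)
kernel_pt = torch.as_tensor(kernel_tf).permute(4, 3, 0, 1, 2).contiguous()
print(tuple(kernel_pt.shape))                                       # (16, 1, 3, 3, 3)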

The names and shapes look consistent between both implementations, and there are 105 rows in both lists, so we should be able to create a mapping between the TensorFlow and PyTorch variables. The function tf2pt.tf2pt() receives a TensorFlow-like variable and returns the corresponding PyTorch-like variable.

In [12]:
for name_tf, tensor_tf in state_dict_tf.items():
    shape_tf = tuple(tensor_tf.shape)
    print(f'{str(shape_tf):18}', name_tf) 
    
    # Convert TensorFlow name to PyTorch name
    mapping_function = highresnet_mapping.tf2pt_name
    name_pt, tensor_pt = tf2pt.tf2pt(name_tf, tensor_tf, mapping_function)
    
    shape_pt = tuple(state_dict_pt[name_pt].shape)
    print(f'{str(shape_pt):18}', name_pt)
    
    # Sanity check
    if sum(shape_tf) != sum(shape_pt):
        raise ValueError
       
    # Add weights to PyTorch state dictionary
    state_dict_pt[name_pt] = tensor_pt
    print()
torch.save(state_dict_pt, output_state_dict_pt_path)
print('State dictionary saved to', output_state_dict_pt_path)
(16,)              conv_0_bn_relu/bn_/beta
(16,)              block.0.convolutional_block.2.bias

(16,)              conv_0_bn_relu/bn_/gamma
(16,)              block.0.convolutional_block.2.weight

(16,)              conv_0_bn_relu/bn_/moving_mean
(16,)              block.0.convolutional_block.2.running_mean

(16,)              conv_0_bn_relu/bn_/moving_variance
(16,)              block.0.convolutional_block.2.running_var

(3, 3, 3, 1, 16)   conv_0_bn_relu/conv_/w
(16, 1, 3, 3, 3)   block.0.convolutional_block.1.weight

(80,)              conv_1_bn_relu/bn_/beta
(80,)              block.4.convolutional_block.1.bias

(80,)              conv_1_bn_relu/bn_/gamma
(80,)              block.4.convolutional_block.1.weight

(80,)              conv_1_bn_relu/bn_/moving_mean
(80,)              block.4.convolutional_block.1.running_mean

(80,)              conv_1_bn_relu/bn_/moving_variance
(80,)              block.4.convolutional_block.1.running_var

(1, 1, 1, 64, 80)  conv_1_bn_relu/conv_/w
(80, 64, 1, 1, 1)  block.4.convolutional_block.0.weight

(160,)             conv_2_bn/bn_/beta
(160,)             block.6.convolutional_block.1.bias

(160,)             conv_2_bn/bn_/gamma
(160,)             block.6.convolutional_block.1.weight

(160,)             conv_2_bn/bn_/moving_mean
(160,)             block.6.convolutional_block.1.running_mean

(160,)             conv_2_bn/bn_/moving_variance
(160,)             block.6.convolutional_block.1.running_var

(1, 1, 1, 80, 160) conv_2_bn/conv_/w
(160, 80, 1, 1, 1) block.6.convolutional_block.0.weight

(16,)              res_1_0/bn_0/beta
(16,)              block.1.dilation_block.0.residual_block.0.convolutional_block.0.bias

(16,)              res_1_0/bn_0/gamma
(16,)              block.1.dilation_block.0.residual_block.0.convolutional_block.0.weight

(16,)              res_1_0/bn_0/moving_mean
(16,)              block.1.dilation_block.0.residual_block.0.convolutional_block.0.running_mean

(16,)              res_1_0/bn_0/moving_variance
(16,)              block.1.dilation_block.0.residual_block.0.convolutional_block.0.running_var

(16,)              res_1_0/bn_1/beta
(16,)              block.1.dilation_block.0.residual_block.1.convolutional_block.0.bias

(16,)              res_1_0/bn_1/gamma
(16,)              block.1.dilation_block.0.residual_block.1.convolutional_block.0.weight

(16,)              res_1_0/bn_1/moving_mean
(16,)              block.1.dilation_block.0.residual_block.1.convolutional_block.0.running_mean

(16,)              res_1_0/bn_1/moving_variance
(16,)              block.1.dilation_block.0.residual_block.1.convolutional_block.0.running_var

(3, 3, 3, 16, 16)  res_1_0/conv_0/w
(16, 16, 3, 3, 3)  block.1.dilation_block.0.residual_block.0.convolutional_block.3.weight

(3, 3, 3, 16, 16)  res_1_0/conv_1/w
(16, 16, 3, 3, 3)  block.1.dilation_block.0.residual_block.1.convolutional_block.3.weight

(16,)              res_1_1/bn_0/beta
(16,)              block.1.dilation_block.1.residual_block.0.convolutional_block.0.bias

(16,)              res_1_1/bn_0/gamma
(16,)              block.1.dilation_block.1.residual_block.0.convolutional_block.0.weight

(16,)              res_1_1/bn_0/moving_mean
(16,)              block.1.dilation_block.1.residual_block.0.convolutional_block.0.running_mean

(16,)              res_1_1/bn_0/moving_variance
(16,)              block.1.dilation_block.1.residual_block.0.convolutional_block.0.running_var

(16,)              res_1_1/bn_1/beta
(16,)              block.1.dilation_block.1.residual_block.1.convolutional_block.0.bias

(16,)              res_1_1/bn_1/gamma
(16,)              block.1.dilation_block.1.residual_block.1.convolutional_block.0.weight

(16,)              res_1_1/bn_1/moving_mean
(16,)              block.1.dilation_block.1.residual_block.1.convolutional_block.0.running_mean

(16,)              res_1_1/bn_1/moving_variance
(16,)              block.1.dilation_block.1.residual_block.1.convolutional_block.0.running_var

(3, 3, 3, 16, 16)  res_1_1/conv_0/w
(16, 16, 3, 3, 3)  block.1.dilation_block.1.residual_block.0.convolutional_block.3.weight

(3, 3, 3, 16, 16)  res_1_1/conv_1/w
(16, 16, 3, 3, 3)  block.1.dilation_block.1.residual_block.1.convolutional_block.3.weight

(16,)              res_1_2/bn_0/beta
(16,)              block.1.dilation_block.2.residual_block.0.convolutional_block.0.bias

(16,)              res_1_2/bn_0/gamma
(16,)              block.1.dilation_block.2.residual_block.0.convolutional_block.0.weight

(16,)              res_1_2/bn_0/moving_mean
(16,)              block.1.dilation_block.2.residual_block.0.convolutional_block.0.running_mean

(16,)              res_1_2/bn_0/moving_variance
(16,)              block.1.dilation_block.2.residual_block.0.convolutional_block.0.running_var

(16,)              res_1_2/bn_1/beta
(16,)              block.1.dilation_block.2.residual_block.1.convolutional_block.0.bias

(16,)              res_1_2/bn_1/gamma
(16,)              block.1.dilation_block.2.residual_block.1.convolutional_block.0.weight

(16,)              res_1_2/bn_1/moving_mean
(16,)              block.1.dilation_block.2.residual_block.1.convolutional_block.0.running_mean

(16,)              res_1_2/bn_1/moving_variance
(16,)              block.1.dilation_block.2.residual_block.1.convolutional_block.0.running_var

(3, 3, 3, 16, 16)  res_1_2/conv_0/w
(16, 16, 3, 3, 3)  block.1.dilation_block.2.residual_block.0.convolutional_block.3.weight

(3, 3, 3, 16, 16)  res_1_2/conv_1/w
(16, 16, 3, 3, 3)  block.1.dilation_block.2.residual_block.1.convolutional_block.3.weight

(16,)              res_2_0/bn_0/beta
(16,)              block.2.dilation_block.0.residual_block.0.convolutional_block.0.bias

(16,)              res_2_0/bn_0/gamma
(16,)              block.2.dilation_block.0.residual_block.0.convolutional_block.0.weight

(16,)              res_2_0/bn_0/moving_mean
(16,)              block.2.dilation_block.0.residual_block.0.convolutional_block.0.running_mean

(16,)              res_2_0/bn_0/moving_variance
(16,)              block.2.dilation_block.0.residual_block.0.convolutional_block.0.running_var

(32,)              res_2_0/bn_1/beta
(32,)              block.2.dilation_block.0.residual_block.1.convolutional_block.0.bias

(32,)              res_2_0/bn_1/gamma
(32,)              block.2.dilation_block.0.residual_block.1.convolutional_block.0.weight

(32,)              res_2_0/bn_1/moving_mean
(32,)              block.2.dilation_block.0.residual_block.1.convolutional_block.0.running_mean

(32,)              res_2_0/bn_1/moving_variance
(32,)              block.2.dilation_block.0.residual_block.1.convolutional_block.0.running_var

(3, 3, 3, 16, 32)  res_2_0/conv_0/w
(32, 16, 3, 3, 3)  block.2.dilation_block.0.residual_block.0.convolutional_block.3.weight

(3, 3, 3, 32, 32)  res_2_0/conv_1/w
(32, 32, 3, 3, 3)  block.2.dilation_block.0.residual_block.1.convolutional_block.3.weight

(32,)              res_2_1/bn_0/beta
(32,)              block.2.dilation_block.1.residual_block.0.convolutional_block.0.bias

(32,)              res_2_1/bn_0/gamma
(32,)              block.2.dilation_block.1.residual_block.0.convolutional_block.0.weight

(32,)              res_2_1/bn_0/moving_mean
(32,)              block.2.dilation_block.1.residual_block.0.convolutional_block.0.running_mean

(32,)              res_2_1/bn_0/moving_variance
(32,)              block.2.dilation_block.1.residual_block.0.convolutional_block.0.running_var

(32,)              res_2_1/bn_1/beta
(32,)              block.2.dilation_block.1.residual_block.1.convolutional_block.0.bias

(32,)              res_2_1/bn_1/gamma
(32,)              block.2.dilation_block.1.residual_block.1.convolutional_block.0.weight

(32,)              res_2_1/bn_1/moving_mean
(32,)              block.2.dilation_block.1.residual_block.1.convolutional_block.0.running_mean

(32,)              res_2_1/bn_1/moving_variance
(32,)              block.2.dilation_block.1.residual_block.1.convolutional_block.0.running_var

(3, 3, 3, 32, 32)  res_2_1/conv_0/w
(32, 32, 3, 3, 3)  block.2.dilation_block.1.residual_block.0.convolutional_block.3.weight

(3, 3, 3, 32, 32)  res_2_1/conv_1/w
(32, 32, 3, 3, 3)  block.2.dilation_block.1.residual_block.1.convolutional_block.3.weight

(32,)              res_2_2/bn_0/beta
(32,)              block.2.dilation_block.2.residual_block.0.convolutional_block.0.bias

(32,)              res_2_2/bn_0/gamma
(32,)              block.2.dilation_block.2.residual_block.0.convolutional_block.0.weight

(32,)              res_2_2/bn_0/moving_mean
(32,)              block.2.dilation_block.2.residual_block.0.convolutional_block.0.running_mean

(32,)              res_2_2/bn_0/moving_variance
(32,)              block.2.dilation_block.2.residual_block.0.convolutional_block.0.running_var

(32,)              res_2_2/bn_1/beta
(32,)              block.2.dilation_block.2.residual_block.1.convolutional_block.0.bias

(32,)              res_2_2/bn_1/gamma
(32,)              block.2.dilation_block.2.residual_block.1.convolutional_block.0.weight

(32,)              res_2_2/bn_1/moving_mean
(32,)              block.2.dilation_block.2.residual_block.1.convolutional_block.0.running_mean

(32,)              res_2_2/bn_1/moving_variance
(32,)              block.2.dilation_block.2.residual_block.1.convolutional_block.0.running_var

(3, 3, 3, 32, 32)  res_2_2/conv_0/w
(32, 32, 3, 3, 3)  block.2.dilation_block.2.residual_block.0.convolutional_block.3.weight

(3, 3, 3, 32, 32)  res_2_2/conv_1/w
(32, 32, 3, 3, 3)  block.2.dilation_block.2.residual_block.1.convolutional_block.3.weight

(32,)              res_3_0/bn_0/beta
(32,)              block.3.dilation_block.0.residual_block.0.convolutional_block.0.bias

(32,)              res_3_0/bn_0/gamma
(32,)              block.3.dilation_block.0.residual_block.0.convolutional_block.0.weight

(32,)              res_3_0/bn_0/moving_mean
(32,)              block.3.dilation_block.0.residual_block.0.convolutional_block.0.running_mean

(32,)              res_3_0/bn_0/moving_variance
(32,)              block.3.dilation_block.0.residual_block.0.convolutional_block.0.running_var

(64,)              res_3_0/bn_1/beta
(64,)              block.3.dilation_block.0.residual_block.1.convolutional_block.0.bias

(64,)              res_3_0/bn_1/gamma
(64,)              block.3.dilation_block.0.residual_block.1.convolutional_block.0.weight

(64,)              res_3_0/bn_1/moving_mean
(64,)              block.3.dilation_block.0.residual_block.1.convolutional_block.0.running_mean

(64,)              res_3_0/bn_1/moving_variance
(64,)              block.3.dilation_block.0.residual_block.1.convolutional_block.0.running_var

(3, 3, 3, 32, 64)  res_3_0/conv_0/w
(64, 32, 3, 3, 3)  block.3.dilation_block.0.residual_block.0.convolutional_block.3.weight

(3, 3, 3, 64, 64)  res_3_0/conv_1/w
(64, 64, 3, 3, 3)  block.3.dilation_block.0.residual_block.1.convolutional_block.3.weight

(64,)              res_3_1/bn_0/beta
(64,)              block.3.dilation_block.1.residual_block.0.convolutional_block.0.bias

(64,)              res_3_1/bn_0/gamma
(64,)              block.3.dilation_block.1.residual_block.0.convolutional_block.0.weight

(64,)              res_3_1/bn_0/moving_mean
(64,)              block.3.dilation_block.1.residual_block.0.convolutional_block.0.running_mean

(64,)              res_3_1/bn_0/moving_variance
(64,)              block.3.dilation_block.1.residual_block.0.convolutional_block.0.running_var

(64,)              res_3_1/bn_1/beta
(64,)              block.3.dilation_block.1.residual_block.1.convolutional_block.0.bias

(64,)              res_3_1/bn_1/gamma
(64,)              block.3.dilation_block.1.residual_block.1.convolutional_block.0.weight

(64,)              res_3_1/bn_1/moving_mean
(64,)              block.3.dilation_block.1.residual_block.1.convolutional_block.0.running_mean

(64,)              res_3_1/bn_1/moving_variance
(64,)              block.3.dilation_block.1.residual_block.1.convolutional_block.0.running_var

(3, 3, 3, 64, 64)  res_3_1/conv_0/w
(64, 64, 3, 3, 3)  block.3.dilation_block.1.residual_block.0.convolutional_block.3.weight

(3, 3, 3, 64, 64)  res_3_1/conv_1/w
(64, 64, 3, 3, 3)  block.3.dilation_block.1.residual_block.1.convolutional_block.3.weight

(64,)              res_3_2/bn_0/beta
(64,)              block.3.dilation_block.2.residual_block.0.convolutional_block.0.bias

(64,)              res_3_2/bn_0/gamma
(64,)              block.3.dilation_block.2.residual_block.0.convolutional_block.0.weight

(64,)              res_3_2/bn_0/moving_mean
(64,)              block.3.dilation_block.2.residual_block.0.convolutional_block.0.running_mean

(64,)              res_3_2/bn_0/moving_variance
(64,)              block.3.dilation_block.2.residual_block.0.convolutional_block.0.running_var

(64,)              res_3_2/bn_1/beta
(64,)              block.3.dilation_block.2.residual_block.1.convolutional_block.0.bias

(64,)              res_3_2/bn_1/gamma
(64,)              block.3.dilation_block.2.residual_block.1.convolutional_block.0.weight

(64,)              res_3_2/bn_1/moving_mean
(64,)              block.3.dilation_block.2.residual_block.1.convolutional_block.0.running_mean

(64,)              res_3_2/bn_1/moving_variance
(64,)              block.3.dilation_block.2.residual_block.1.convolutional_block.0.running_var

(3, 3, 3, 64, 64)  res_3_2/conv_0/w
(64, 64, 3, 3, 3)  block.3.dilation_block.2.residual_block.0.convolutional_block.3.weight

(3, 3, 3, 64, 64)  res_3_2/conv_1/w
(64, 64, 3, 3, 3)  block.3.dilation_block.2.residual_block.1.convolutional_block.3.weight

State dictionary saved to /tmp/miccai_niftynet_pytorch/state_dict_pt.pth

If PyTorch is happy when loading our state dict into the model, we should be on the right track 🤞...

In [13]:
model.load_state_dict(state_dict_pt)
Out[13]:
IncompatibleKeys(missing_keys=[], unexpected_keys=[])

No incompatible keys. Yay! 🎉
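
One detail to keep in mind for the inference step later on: the model must be put in evaluation mode so that batch normalization uses the transferred running statistics instead of batch statistics (a minimal reminder, not specific to this model):

model.eval()
print(model.training)  # False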

Plotting weights with PyTorch

Something great about PyTorch is that the model parameters are easily accessible. Let's plot some of them before and after training:

In [0]:
model_initial = HighRes3DNet(num_input_modalities, num_classes, add_dropout_layer=True)
model_pretrained = model

By default, convolutional layers in PyTorch are initialized using He uniform variance scaling. These are the probability density functions (PDFs) of the kernel parameters of each convolutional layer. Note how the domain of each function changes with the input size of the corresponding layer.
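
As a rough idea of what visualization.plot_all_parameters() works with (the actual plotting code may differ), the convolutional kernels can be gathered directly from named_parameters():

conv_kernels = {
    name: parameters.detach().cpu().numpy().ravel()
    for name, parameters in model_initial.named_parameters()
    if parameters.dim() == 5  # 3D convolution kernels
}
print(len(conv_kernels))  # one entry per convolutional layer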

In [15]:
visualization.plot_all_parameters(model_initial)

This is what the PDFs look like after training:

In [16]:
visualization.plot_all_parameters(model_pretrained)

Testing the model

The last step is to test the PyTorch model. We will preprocess the image according to the configuration file, initialize the reader, sampler and aggregator, run the inference, and verify that results are consistent between NiftyNet and PyTorch.

Configuration file

[Modality0]
path_to_search = data/OASIS/
filename_contains = nii
pixdim = (1.0, 1.0, 1.0)
axcodes = (R, A, S)

[NETWORK]
name = highres3dnet
volume_padding_size = 10
whitening = True
normalisation = True
normalise_foreground_only=True
foreground_type = mean_plus
histogram_ref_file = databrain_std_hist_models_otsu.txt
cutoff = (0.001, 0.999)

[INFERENCE]
border = 2
spatial_window_size = (128, 128, 128)

We need to match the configuration used during training in order to obtain consistent results. The relevant contents of the downloaded configuration file are shown above; we parse them with ConfigParser:

In [0]:
config = ConfigParser()
config.read(config_path);
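
Values can now be looked up by section and key, for example (values taken from the listing above):

print(config['NETWORK']['foreground_type'])        # mean_plus
print(config['NETWORK']['volume_padding_size'])    # 10
print(config['INFERENCE']['spatial_window_size'])  # (128, 128, 128)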

Reader

The necessary preprocessing is described in the paper, code and configuration file.

NiftyNet offers some powerful I/O tools. We will use its readers, samplers and aggregators to read, preprocess and write all the files. There are multiple demos in the NiftyNet repository that show the usage of these modules.

In [0]:
%%capture
input_dict = dict(
    path_to_search=str(data_dir),
    filename_contains='nii',
    axcodes=('R', 'A', 'S'),
    pixdim=(1, 1, 1),
)
data_parameters = {
    'image': input_dict,
}
reader = ImageReader().initialise(data_parameters)
In [19]:
_, image_data_dict, _ = reader()
original_image = image_data_dict['image']
original_image.shape
Out[19]:
(160, 256, 256, 1, 1)

Looking at the shape of our image and knowing that the reader reoriented it into RAS+ orientation, we can see that it represents $160$ sagittal slices of $256 \times 256$ pixels, with $1$ channel (monomodal) and $1$ time point. Let's see what it looks like:
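
In code, the five dimensions unpack as follows (a small illustration using the shape above):

num_sagittal, num_coronal, num_axial, num_channels, num_time_points = original_image.shape
print(num_sagittal, num_channels, num_time_points)  # 160 1 1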

In [20]:
plot_volume(original_image, title='Original volume')
In [21]:
visualization.plot_histogram(original_image, kde=False, add_labels=True, ylim=(0, 1e6))

Preprocessing

We pad the input volume and crop the output volume to reduce the border effect introduced by the padded convolutions:

In [0]:
volume_padding_layer = PadLayer(
    image_name=['image'],  # https://github.com/NifTK/NiftyNet/blob/61f2a8bbac1348591412c00f55d1c19b91c0367f/niftynet/layer/pad.py#L52
    border=(10, 10, 10),
)

We use a masking function so that only the foreground voxels are used for normalization:

In [23]:
binary_masking_func = BinaryMaskingLayer(type_str=config['NETWORK']['foreground_type'])
mask = binary_masking_func(original_image)
plot_volume(mask, enhance=False, title='Binary mask for preprocessing')

We apply MRI histogram standardization to our test image, using the landmarks learned from the training dataset. The mean intensity of the volume is used as the threshold for the mask, as the authors of the method claim that this usually gives good results.

In [0]:
hist_norm = HistogramNormalisationLayer(
    image_name='image',
    modalities=['Modality0'],
    model_filename=str(histogram_landmarks_path),
    binary_masking_func=binary_masking_func,
    cutoff=(0.001, 0.999),
    name='hist_norm_layer',
)

Finally, we force our image foreground to have zero mean and unit variance:

In [0]:
whitening = MeanVarNormalisationLayer(
    image_name='image', binary_masking_func=binary_masking_func)

Here is our preprocessed image:

In [26]:
%%capture --no-display
preprocessing_layers = [
    volume_padding_layer,
    hist_norm,
    whitening,
]
reader = ImageReader().initialise(data_parameters)
reader.add_preprocessing_layers(preprocessing_layers)
_, image_data_dict, _ = reader()
preprocessed_image = image_data_dict['image']
plot_volume(preprocessed_image, title='Preprocessed image')

Note the small difference in intensities due to histogram standardization, as well as the 10-voxel zero-padding.
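
As a quick check of the padding (the expected shape follows from the 10-voxel border added on each side of every spatial dimension):

print(original_image.shape)      # (160, 256, 256, 1, 1)
print(preprocessed_image.shape)  # expected: (180, 276, 276, 1, 1)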

We can clearly see the effect of the whitening layer on the histogram:

In [27]:
visualization.plot_histogram(preprocessed_image, kde=False, add_labels=True, ylim=(0, 1e6))