%load_ext autoreload
%autoreload 2
# Necessary imports
import os
import sys
import warnings
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
sys.path.append('../..')
from seismiqb.batchflow import Pipeline, FilesIndex
from seismiqb.batchflow import B, V, C, D, P, R, W, L
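# B, V, C, D, P, R, W, L are batchflow "named expressions": lazy placeholders that are
# resolved only when the pipeline runs (e.g. B — batch component, V — pipeline variable,
# C — config entry, D — dataset attribute, P/R — per-item and random parameters).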
from seismiqb.batchflow.models.torch.layers import *
from seismiqb.batchflow.models.torch import *
from seismiqb import SeismicCubeset, SeismicGeometry
from seismiqb import plot_image, plot_loss
from utils import show_slide_dataset, show_slide
# Set GPU
%env CUDA_VISIBLE_DEVICES=0
# Global parameters
CROP_SHAPE = (1, 512, 832) # shape of sampled 3D crops
DYNAMIC_FACTOR = 1. # scaling factor for crop shape during training
ITERS = 800 # number of training iterations
BATCH_SIZE = 512 # number of crops in one batch
CLASS_LABELS = [
'Basement/other',
'Slope Mudstone A',
'Mass Transport\n Deposit',
'Slope Mudstone B',
'Slope Valley',
'Submarine Canyon\n System'
]
cube_path = '/data/hackathon/train_amplitudes.hdf5'
label_path = '/data/hackathon/train_labels.hdf5'
dsi = FilesIndex(path=[cube_path], no_ext=True)
dataset = SeismicCubeset(dsi)
dataset.load_geometries()
dataset.labels[dataset.indices[0]] = SeismicGeometry(label_path)
print(dataset.geometries[0])
dataset.create_sampler(mode='default')
dataset.modify_sampler('train_sampler', finish=True)
_ = dataset.show_slices(src_sampler='train_sampler',
normalize=False, shape=CROP_SHAPE, side_view=True,
cmap='Reds', interpolation='bilinear',
figsize=(8, 6))
show_slide_dataset(100, dataset=dataset, figsize=(15, 15))
import torch
def dice_loss(pred, target, eps=1e-7):
    """ Soft multiclass Dice loss: 1 minus the mean Dice coefficient over classes. """
    num_classes = pred.shape[1]
    target = target.long()
    # Class probabilities and one-hot encoded ground truth, both (N, C, H, W);
    # the identity matrix is created on the target's device so indexing works on GPU
    probas = torch.nn.functional.softmax(pred, dim=1)
    true_1_hot = torch.eye(num_classes, device=target.device)[target.squeeze(1)]
    true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
    true_1_hot = true_1_hot.to(pred.device).type(pred.type())
    # Sum over batch and spatial dimensions, keeping one value per class
    dims = (0,) + tuple(range(2, target.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    loss = (2. * intersection / (cardinality + eps)).mean()
    return 1 - loss
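# Optional sanity check of dice_loss on random tensors (illustrative only; the shapes
# mimic a batch of two 6-class crops, not actual seismic data).
_logits = torch.randn(2, 6, 32, 64)            # raw model outputs, (N, C, H, W)
_labels = torch.randint(0, 6, (2, 1, 32, 64))  # 0-based class indices, (N, 1, H, W)
print(dice_loss(_logits, _labels))             # soft Dice loss lies in [0, 1]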
MODEL_CONFIG = {
'inputs/masks/classes': 6,
# Model layout
'initial_block': {
'base_block': ResBlock,
'filters': 16,
'kernel_size': 5,
'downsample': False,
'attention': 'scse'
},
'body/encoder': {
'num_stages': 4,
'order': 'sbd',
'blocks': {
'base': ResBlock,
'n_reps': 1,
'filters': [32, 64, 128, 256],
'attention': 'scse',
},
},
'body/embedding': {
'base': ResBlock,
'n_reps': 1,
'filters': 256,
'attention': 'scse',
},
'body/decoder': {
'num_stages': 4,
'upsample': {
'layout': 'tna',
'kernel_size': 2,
},
'blocks': {
'base': ResBlock,
'filters': [128, 64, 32, 16],
'attention': 'scse',
},
},
'head': {
'base_block': ResBlock,
'filters': [16, 8],
'attention': 'scse'
},
'output': 'sigmoid',
# Train configuration
'loss': dice_loss,
'optimizer': {'name': 'Adam', 'lr': 0.01,},
'decay': {'name': 'exp', 'gamma': 0.5, 'frequency': 100},
'microbatch': 2,
'common/activation': 'relu6',
}
def generate_shape(batch, shape, dynamic_factor=1, dynamic_low=None, dynamic_high=None):
    """ Randomly scale the spatial dimensions of a crop shape.
    With `dynamic_factor=1` the shape is returned unchanged; larger factors draw
    xline and depth sizes from [size // factor, size * factor].
    """
    dynamic_low = dynamic_low or dynamic_factor
    dynamic_high = dynamic_high or dynamic_factor
    i, x, h = shape
    x_ = np.random.randint(x // dynamic_low, x * dynamic_high + 1)
    h_ = np.random.randint(h // dynamic_low, h * dynamic_high + 1)
    return (i, x_, h_)
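# Illustration only: with DYNAMIC_FACTOR == 1 the crop shape stays fixed at CROP_SHAPE;
# a factor of 2 would sample spatial sizes from [size // 2, 2 * size], e.g.:
for _ in range(3):
    print(generate_shape(None, CROP_SHAPE, dynamic_factor=2))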
def adjust_masks(batch):
    """ Shift 1-based labels from the cube to 0-based class indices for the loss. """
    batch.masks -= 1
    batch.masks = batch.masks.astype(np.float32)
train_template = (
Pipeline()
# Initialize pipeline variables and model
.init_variable('loss_history', [])
.init_model('dynamic', EncoderDecoder, 'model', MODEL_CONFIG)
# Dynamically generate shape
.init_variable('shape', None)
.call(generate_shape, shape=C('crop_shape'),
dynamic_factor=DYNAMIC_FACTOR, save_to=V('shape'))
.crop(points=D('train_sampler')(BATCH_SIZE),
shape=V('shape'))
# Load data/masks
.load_cubes(dst='images')
.load_cubes(dst='masks', src_geometry='labels')
.adaptive_reshape(src=['images', 'masks'], shape=V('shape'))
.scale(mode='q', src='images')
.call(adjust_masks)
# Augmentations
.transpose(src=['images', 'masks'], order=(1, 2, 0))
.additive_noise(scale=0.005, src='images', dst='images', p=0.3)
.flip(axis=1, src=['images', 'masks'],
seed=P(R('uniform', 0, 1)), p=0.3)
.rotate(angle=P(R('uniform', -15, 15)),
src=['images', 'masks'], p=0.3)
.scale_2d(scale=P(R('uniform', 0.85, 1.15)),
src=['images', 'masks'], p=0.3)
.elastic_transform(alpha=P(R('uniform', 35, 45)),
sigma=P(R('uniform', 4, 4.5)),
src=['images', 'masks'], p=0.2)
.transpose(src=['images', 'masks'], order=(2, 0, 1))
# Training
.train_model('model',
fetches='loss',
images=B('images'),
masks=B('masks'),
save_to=V('loss_history', mode='a'))
)
ppl_config = {'crop_shape': CROP_SHAPE}
train_pipeline = (train_template << ppl_config) << dataset
%%time
batch = train_pipeline.next_batch(1)
%%time
train_pipeline.run(D('size'), n_iters=ITERS,
bar={'bar': 'n', 'monitors': 'loss_history'})
plot_loss(train_pipeline.v('loss_history'))
train_pipeline.reset('variables')
torch.cuda.empty_cache()
# Validation pipeline: no augmentations
val_template = (
Pipeline()
# Import model
.import_model('model', train_pipeline)
# Load data/masks
.crop(points=D('train_sampler')(16), shape=CROP_SHAPE)
.load_cubes(dst='images')
.load_cubes(dst='masks', src_geometry='labels')
.scale(mode='q', src='images')
.call(adjust_masks)
# Predict with model
.predict_model('model',
B('images'),
fetches='predictions',
save_to=B('predictions'))
.transpose(src=['images', 'masks', 'predictions'],
order=(1, 2, 0))
)
val_pipeline = val_template << dataset
%%time
batch = val_pipeline.next_batch(1)
from sklearn.metrics import confusion_matrix, classification_report
cm = confusion_matrix(batch.masks.flatten(),
np.argmax(batch.predictions, axis=-1).flatten())
report = classification_report(batch.masks.flatten(),
np.argmax(batch.predictions, axis=-1).flatten())
print(report)
fig, ax = plt.subplots(figsize=(8, 8))
ax.matshow(cm)
plt.title('Confusion matrix', y=1.21)
# Fix tick positions explicitly so labels stay aligned with matrix cells
ax.set_xticks(range(len(CLASS_LABELS)))
ax.set_yticks(range(len(CLASS_LABELS)))
ax.set_xticklabels(CLASS_LABELS, rotation=20)
ax.set_yticklabels(CLASS_LABELS)
plt.xlabel('Predicted', fontdict={'fontsize': 16})
plt.ylabel('True', fontdict={'fontsize': 16})
plt.show()
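# Optional extra (not part of the original evaluation): per-class IoU derived from the
# same confusion matrix, as a segmentation-oriented complement to the report above.
tp = np.diag(cm)
iou = tp / (cm.sum(axis=0) + cm.sum(axis=1) - tp + 1e-7)
for name, value in zip(CLASS_LABELS, iou):
    print(f"{name.replace(chr(10), ''):<30} IoU: {value:.3f}")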
idx = 4
img, mask, pred = batch.images[idx], batch.masks[idx], batch.predictions[idx]
show_slide((img, mask), opacity=0.15)
show_slide((img, np.argmax(pred, axis=-1)), opacity=0.15)