using HDF5, Statistics, Images, ImageView
using Knet

# load training data
file_path = "/mnt/jk489/Kathi/Julia/cropped_data_0_3000.h5" # not provided
train_data = h5read(file_path, "data");
X_train = convert(KnetArray, reshape(train_data[:, :, 1:100, 2], 256, 256, 1, 100));
Y_train = convert(KnetArray, reshape(train_data[:, :, 1:100, 1], 256, 256, 1, 100));

# load testing data
file_path = "./cropped_data_3000_6000.h5" # not provided
test_data = h5read(file_path, "data");
X_test = convert(KnetArray, reshape(test_data[:, :, 1:10, 2], 256, 256, 1, 10));
Y_test = convert(KnetArray, reshape(test_data[:, :, 1:10, 1], 256, 256, 1, 10));

# notebook figures: U-Net architecture diagram and an example input/mask pair
architecture = load("./images_for_notebook/UNet_architecture.png")
input = load("./images_for_notebook/example_input.png")
mask = load("./images_for_notebook/example_mask.png")

# reshape an N-dimensional array into a vector
function flatten(array)
    size_array = size(array)
    length_flattened = 1
    for dim = 1:ndims(array)
        length_flattened *= size_array[dim]
    end
    return reshape(array, length_flattened)
end

# concatenate two 4-D arrays along the channel dimension (dim 3),
# implemented via cat1d so it also works on KnetArrays
function cat3d(arr1, arr2)
    R = size(arr1)[1]
    C = size(arr1)[2]
    Ch = size(arr1)[3] + size(arr2)[3]
    S = size(arr1)[4]
    arr1 = flatten(permutedims(arr1, [1, 2, 4, 3]))
    arr2 = flatten(permutedims(arr2, [1, 2, 4, 3]))
    out = reshape(cat1d(arr1, arr2), R, C, S, Ch)
    out = permutedims(out, [1, 2, 4, 3])
    return out
end

# encoder block: two conv + batchnorm + relu stages followed by 2x2 max pooling;
# returns the pre-pool activation (for the skip connection) and the pooled output
function conv_layer_encoder(w, bn_moments, x)
    out = conv4(w[1], x, padding=1) .+ w[2]
    out = batchnorm(out, bn_moments[1], w[3])
    conv_1 = relu.(out)
    out = conv4(w[4], conv_1, padding=1) .+ w[5]
    out = batchnorm(out, bn_moments[2], w[6])
    out = relu.(out)
    pooled = pool(out)
    return conv_1, pooled
end

# decoder block: bilinear upsampling, concatenation with the skip connection,
# then two conv + batchnorm + relu stages
function conv_layer_decoder(w, bn_moments, x_cat, x)
    number_of_channels = size(x)[3]
    upsampling_kernel = convert(KnetArray, bilinear(Float64, 2, 2, number_of_channels, number_of_channels))
    out = deconv4(upsampling_kernel, x, padding=1, stride=2)
    out = cat3d(out, x_cat)
    out = conv4(w[1], out, padding=1) .+ w[2]
    out = batchnorm(out, bn_moments[1], w[3])
    out = relu.(out)
    out = conv4(w[4], out, padding=1) .+ w[5]
    out = batchnorm(out, bn_moments[2], w[6])
    out = relu.(out)
    return out
end

# bottleneck: two conv + batchnorm + relu stages without pooling
function bottleneck(w, bn_moments, x)
    out = conv4(w[1], x, padding=1) .+ w[2]
    out = batchnorm(out, bn_moments[1], w[3])
    out = relu.(out)
    out = conv4(w[4], out, padding=1) .+ w[5]
    out = batchnorm(out, bn_moments[2], w[6])
    out = relu.(out)
    return out
end

# final convolution with sigmoid activation
function output_layer(w, x)
    out = sigm.(conv4(w[1], x, padding=1) .+ w[2])
    return out
end

function predict(x_in, w, bn_moments)
    # layer 1
    conv_1, out = conv_layer_encoder(w[1:6], bn_moments[1:2], x_in)
    # layer 2
    conv_2, out = conv_layer_encoder(w[7:12], bn_moments[3:4], out)
    # layer 3
    conv_3, out = conv_layer_encoder(w[13:18], bn_moments[5:6], out)
    # layer 4 = bottleneck
    out = bottleneck(w[19:24], bn_moments[7:8], out)
    # layer 5
    out = conv_layer_decoder(w[25:30], bn_moments[9:10], conv_3, out)
    # layer 6
    out = conv_layer_decoder(w[31:36], bn_moments[11:12], conv_2, out)
    # layer 7
    out = conv_layer_decoder(w[37:42], bn_moments[13:14], conv_1, out)
    out = output_layer(w[43:44], out)
    return out
end

function init_model(n_input_channels=1, n_output_channels=1, depth=4, max_channels=512, kernel_size=3)
    w = Any[]
    bn_moments = Any[bnmoments() for i = 1:14]

    # 1st layer
    # convolution
    push!(w, xavier(Float64, kernel_size, kernel_size, n_input_channels, Int64(max_channels/(2^(depth-1)))))
    # bias
    push!(w, zeros(Float64, 1, 1, Int64(max_channels/(2^(depth-1))), 1))
    # batchnorm params
    push!(w, bnparams(Float64, Int64(max_channels/(2^(depth-1)))))
    # convolution
    push!(w, xavier(Float64, kernel_size, kernel_size, Int64(max_channels/(2^(depth-1))), Int64(max_channels/(2^(depth-1)))))
    # bias
    push!(w, zeros(Float64, 1, 1, Int64(max_channels/(2^(depth-1))), 1))
    # batchnorm params
    push!(w, bnparams(Float64, Int64(max_channels/(2^(depth-1)))))

    # encoding arm: 2nd layer up to and including the bottleneck
    for layer = 2:depth
        push!(w, xavier(Float64, kernel_size, kernel_size, Int64(max_channels/(2^(depth-layer+1))), Int64(max_channels/(2^(depth-layer)))))
        push!(w, zeros(Float64, 1, 1, Int64(max_channels/(2^(depth-layer))), 1))
        push!(w, bnparams(Float64, Int64(max_channels/(2^(depth-layer)))))
        push!(w, xavier(Float64, kernel_size, kernel_size, Int64(max_channels/(2^(depth-layer))), Int64(max_channels/(2^(depth-layer)))))
        push!(w, zeros(Float64, 1, 1, Int64(max_channels/(2^(depth-layer))), 1))
        push!(w, bnparams(Float64, Int64(max_channels/(2^(depth-layer)))))
    end

    # decoding arm (the output convolution of the last layer is added separately below)
    for layer = (depth+1):(2*depth-1)
        # the first convolution sees 1.5x the channels after concatenation with the skip connection
        push!(w, xavier(Float64, kernel_size, kernel_size, Int64(1.5*max_channels/(2^(layer-depth-1))), Int64(max_channels/(2^(layer-depth)))))
        push!(w, zeros(Float64, 1, 1, Int64(max_channels/(2^(layer-depth))), 1))
        push!(w, bnparams(Float64, Int64(max_channels/(2^(layer-depth)))))
        push!(w, xavier(Float64, kernel_size, kernel_size, Int64(max_channels/(2^(layer-depth))), Int64(max_channels/(2^(layer-depth)))))
        push!(w, zeros(Float64, 1, 1, Int64(max_channels/(2^(layer-depth))), 1))
        push!(w, bnparams(Float64, Int64(max_channels/(2^(layer-depth)))))
    end

    # output convolution
    push!(w, xavier(Float64, kernel_size, kernel_size, Int64(max_channels/(2^(depth-1))), n_output_channels))
    push!(w, zeros(Float64, 1, 1, n_output_channels, 1))

    w = map(KnetArray, w)
    return w, bn_moments
end

# dice coefficient to track training progress
function dice_coeff(X, Y, predict, w, bn_moments)
    lambda = 1  # smoothing term to avoid division by zero
    Y_pred = predict(X, w, bn_moments)
    Y_pred_flatten = flatten(Y_pred)
    Y_flatten = flatten(Y)
    intersection = sum(Y_pred_flatten .* Y_flatten)
    union_area = sum(Y_pred_flatten) + sum(Y_flatten)
    return (2*intersection)/(union_area + lambda)
end

# initialize weights and batchnorm moments, with one Adam optimizer per weight array
w, bn_moments = init_model();
lr = 0.001
optim = optimizers(w, Adam; lr=lr);

# binary cross entropy loss; grad differentiates it with respect to its first argument (the weights w)
loss(w, bn_moments, x_in, y_true, predict) = bce(predict(x_in, w, bn_moments), y_true)
lossgradient = grad(loss)

# one training epoch over shuffled minibatches; prints the mean dice coefficient
function epoch!(w, bn_moments, optim, X_train, Y_train; batch_size=2)
    data = minibatch(X_train, Y_train, batch_size; shuffle=true, xtype=KnetArray)
    epoch_dice = []
    for (X, Y) in data
        gradient = lossgradient(w, bn_moments, X, Y, predict)
        update!(w, gradient, optim)
        dice = dice_coeff(X, Y, predict, w, bn_moments)
        push!(epoch_dice, dice)
        Knet.gc()  # release unused GPU memory between minibatches
    end
    println("dice: ", string(mean(epoch_dice)))
end

function train(w, bn_moments, optim, X_train, Y_train, predict; batch_size=2, epochs=10)
    for epoch = 1:epochs
        epoch!(w, bn_moments, optim, X_train, Y_train; batch_size=batch_size)
    end
end

train(w, bn_moments, optim, X_train, Y_train, predict, batch_size=2, epochs=17)

test = load("./images_for_notebook/model_test.png")
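# The held-out test split (X_test, Y_test) loaded above is never scored in this
# notebook. A minimal sketch of how it could be evaluated with the same
# dice_coeff used during training, assuming the train(...) call above has run:
test_dice = dice_coeff(X_test, Y_test, predict, w, bn_moments)
println("test dice: ", test_dice)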