#Run once per session
!pip install fastai -q --upgrade
from fastai.vision.all import *
Today's dataset will be CAMVID, a segmentation problem: given images from cameras mounted on cars, we segment the various areas of the road
path = untar_data(URLs.CAMVID)
Our validation set is listed in a text file called valid.txt, with one filename per line. Let's read it in:
valid_fnames = (path/'valid.txt').read().split('\n')
valid_fnames[:5]
['0016E5_07959.png', '0016E5_07961.png', '0016E5_07963.png', '0016E5_07965.png', '0016E5_07967.png']
Let's look at an image and see how everything lines up
path_im = path/'images'
path_lbl = path/'labels'
First we need our filenames
fnames = get_image_files(path_im)
lbl_names = get_image_files(path_lbl)
And now let's work with one of them
img_fn = fnames[10]
img = PILImage.create(img_fn)
img.show(figsize=(5,5))
Now let's grab our y's. They live in the labels folder and are denoted by a _P at the end of the file name
get_msk = lambda o: path/'labels'/f'{o.stem}_P{o.suffix}'
Here stem is the file name without its extension, and suffix is the extension, including the leading period.
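To make the naming rule concrete, here is a small sketch using only pathlib. The helper name mask_path_for and the relative folder names are hypothetical and just for illustration; the file name is the first one from valid.txt above.

```python
from pathlib import Path

def mask_path_for(image_path, labels_dir):
    """Build the label path for an image by appending `_P` to its stem."""
    image_path = Path(image_path)
    # stem   -> file name without the final extension, e.g. '0016E5_07959'
    # suffix -> the final extension including the dot, e.g. '.png'
    return Path(labels_dir) / f"{image_path.stem}_P{image_path.suffix}"

example = Path("images/0016E5_07959.png")
print(example.stem)                            # 0016E5_07959
print(example.suffix)                          # .png
print(mask_path_for(example, "labels").name)   # 0016E5_07959_P.png
```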
Our masks are of type PILMask, and we will set the opacity (alpha) to 1 as we are not overlaying the mask on anything yet
msk = PILMask.create(get_msk(img_fn))
msk.show(figsize=(5,5), alpha=1)
Now if we look at what our mask actually is, we can see it's a large array with one integer per pixel:
tensor(msk)
tensor([[ 4, 4, 4, ..., 26, 26, 26], [ 4, 4, 4, ..., 26, 26, 26], [ 4, 4, 4, ..., 26, 26, 26], ..., [19, 19, 19, ..., 17, 17, 17], [19, 19, 19, ..., 17, 17, 17], [19, 19, 19, ..., 17, 17, 17]], dtype=torch.uint8)
Each value is the index of a class that we can find in codes.txt. Let's make a vocabulary with it
codes = np.loadtxt(path/'codes.txt', dtype=str); codes
array(['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car', 'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv', 'LaneMkgsNonDriv', 'Misc_Text', 'MotorcycleScooter', 'OtherMoving', 'ParkingBlock', 'Pedestrian', 'Road', 'RoadShoulder', 'Sidewalk', 'SignSymbol', 'Sky', 'SUVPickupTruck', 'TrafficCone', 'TrafficLight', 'Train', 'Tree', 'Truck_Bus', 'Tunnel', 'VegetationMisc', 'Void', 'Wall'], dtype='<U17')
We need a split function that will split from our list of valid filenames we grabbed earlier. Let's try making our own.
def FileSplitter(fname):
    "Split `items` depending on whether each file name is listed in `fname`."
    valid = Path(fname).read().split('\n')
    def _func(x): return x.name in valid
    def _inner(o, **kwargs): return FuncSplitter(_func)(o)
    return _inner
This reads the validation filenames from the file, then checks each of our items against that list: items whose name appears in it go to the validation set, and everything else goes to the training set
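The same idea can be sketched without fastai. This is only an illustration under a few assumptions: split_by_valid_names is a hypothetical helper, items are plain strings rather than Path objects, and the toy file names are made up. It also drops the empty string that split('\n') produces when the text ends with a newline.

```python
def split_by_valid_names(items, valid_names):
    """Return (train_idxs, valid_idxs) for `items`, given validation file names."""
    # Drop empty entries, e.g. the one left over from a trailing newline
    valid_names = {n for n in valid_names if n}
    train_idxs, valid_idxs = [], []
    for i, item in enumerate(items):
        name = item.rsplit("/", 1)[-1]  # keep only the file name part
        (valid_idxs if name in valid_names else train_idxs).append(i)
    return train_idxs, valid_idxs

items = ["images/a.png", "images/b.png", "images/c.png"]  # toy file names
valid_text = "b.png\n"                                     # toy valid.txt content
train_idxs, valid_idxs = split_by_valid_names(items, valid_text.split("\n"))
print(train_idxs, valid_idxs)
```

Returning two lists of indices mirrors what a fastai splitter hands back to the DataBlock.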
Jeremy Howard popularized the idea of progressive resizing: train on smaller images first, then move on to larger ones.
This first round we will train at half the image size
sz = msk.shape; sz
(720, 960)
half = tuple(int(x/2) for x in sz); half
(360, 480)
camvid = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                   get_items=get_image_files,
                   splitter=FileSplitter(path/'valid.txt'),
                   get_y=get_msk,
                   batch_tfms=[*aug_transforms(size=half), Normalize.from_stats(*imagenet_stats)])
dls = camvid.dataloaders(path/'images', bs=8)
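Normalize.from_stats(*imagenet_stats) in the batch transforms above standardizes each colour channel with the ImageNet mean and standard deviation. Here is a minimal plain-Python sketch of that arithmetic for a single pixel. normalize_pixel and denormalize_pixel are hypothetical helper names for illustration; fastai itself applies this to whole batches of tensors.

```python
# ImageNet per-channel statistics (R, G, B) for pixel values scaled to [0, 1]
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Subtract the channel mean and divide by the channel std."""
    return tuple((c - m) / s for c, m, s in zip(rgb, mean, std))

def denormalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Undo normalize_pixel, e.g. before displaying an image."""
    return tuple(c * s + m for c, m, s in zip(rgb, mean, std))

pixel = (0.2, 0.5, 0.9)
normed = normalize_pixel(pixel)
restored = denormalize_pixel(normed)
print(normed)
print(restored)
```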
Let's look at a batch, and look at all the classes between codes 1 and 30 (ignoring Animal and Wall)
dls.show_batch(max_n=4, vmin=1, vmax=30, figsize=(14,10))
Lastly let's make our vocabulary a part of our DataLoaders, as our loss function needs to deal with the Void label
dls.vocab = codes
Now we need a way to look up the code for a particular label. Let's build a dictionary that maps each class name to its index
name2id = {v:k for k,v in enumerate(codes)}
name2id
{'Animal': 0, 'Archway': 1, 'Bicyclist': 2, 'Bridge': 3, 'Building': 4, 'Car': 5, 'CartLuggagePram': 6, 'Child': 7, 'Column_Pole': 8, 'Fence': 9, 'LaneMkgsDriv': 10, 'LaneMkgsNonDriv': 11, 'Misc_Text': 12, 'MotorcycleScooter': 13, 'OtherMoving': 14, 'ParkingBlock': 15, 'Pedestrian': 16, 'Road': 17, 'RoadShoulder': 18, 'SUVPickupTruck': 22, 'Sidewalk': 19, 'SignSymbol': 20, 'Sky': 21, 'TrafficCone': 23, 'TrafficLight': 24, 'Train': 25, 'Tree': 26, 'Truck_Bus': 27, 'Tunnel': 28, 'VegetationMisc': 29, 'Void': 30, 'Wall': 31}
Awesome! Let's make an accuracy function
void_code = name2id['Void']
For segmentation, we squeeze the channel dimension out of the target so that it is a matrix of class ids for each image. From there, we compare the argmax of the model's output with the target at every pixel that is not Void, and take the average
def acc_camvid(inp, targ):
    targ = targ.squeeze(1)
    mask = targ != void_code
    return (inp.argmax(dim=1)[mask]==targ[mask]).float().mean()
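To see what this metric computes, here is a tiny worked example in plain Python with nested lists instead of tensors. It is only a sketch: masked_pixel_accuracy is a hypothetical name, the 2x2 "image" has three made-up classes, and class 2 plays the role of Void.

```python
def masked_pixel_accuracy(scores, target, void_code):
    """scores: [class][row][col] per-class scores; target: [row][col] class ids."""
    n_classes = len(scores)
    correct = counted = 0
    for r, row in enumerate(target):
        for c, true_id in enumerate(row):
            if true_id == void_code:
                continue  # Void pixels are ignored entirely
            # argmax over the class dimension for this pixel
            pred_id = max(range(n_classes), key=lambda k: scores[k][r][c])
            correct += pred_id == true_id
            counted += 1
    return correct / counted

scores = [
    [[0.9, 0.2], [0.1, 0.6]],  # class 0
    [[0.1, 0.7], [0.8, 0.3]],  # class 1
    [[0.0, 0.1], [0.1, 0.1]],  # class 2 (treated as Void here)
]
target = [[0, 1], [0, 2]]
acc = masked_pixel_accuracy(scores, target, void_code=2)
print(acc)  # 2 of the 3 non-Void pixels are right
```

Note that the sketch, like the tensor version, has nothing to average if every pixel is Void.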
U-Net allows us to look at pixel-wise representations of our images by sizing them down and then blowing them back up into a high resolution image. The first part we call an "encoder" and the second a "decoder"
In the architecture diagram from the U-Net paper, the authors describe the arrows as denoting the different operations
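As a rough illustration of the "sizing down" half, the sketch below traces how a 360x480 input shrinks through five stride-2 steps. It assumes each step rounds the size up (as a stride-2 layer with suitable padding does), which agrees with the shapes that learn.summary() reports in this notebook; encoder_sizes is a hypothetical helper, not part of fastai.

```python
import math

def encoder_sizes(height, width, n_halvings=5):
    """List the spatial size after each stride-2 step, rounding up."""
    sizes = [(height, width)]
    for _ in range(n_halvings):
        height, width = math.ceil(height / 2), math.ceil(width / 2)
        sizes.append((height, width))
    return sizes

print(encoder_sizes(360, 480))
# [(360, 480), (180, 240), (90, 120), (45, 60), (23, 30), (12, 15)]
```

The decoder then walks back up the same ladder, doubling the size at each step until it reaches 360x480 again.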
We have a special unet_learner. Something new is that we can pass in a configuration (unet_config) where we can declare a few things!
Let's make a unet_learner that uses some of the new state of the art techniques. Specifically:
self_attention = True
act_cls = Mish
config = unet_config(self_attention=True, act_cls=Mish)
Along with this we will use Ranger as the optimizer function.
opt = ranger
learn = unet_learner(dls, resnet34, metrics=acc_camvid, config=config,
opt_func=opt)
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/checkpoints/resnet34-333f7ec4.pth 100%|██████████| 83.3M/83.3M [00:02<00:00, 40.2MB/s]
learn.summary()
DynamicUnet (Input shape: ['8 x 3 x 360 x 480']) ================================================================ Layer (type) Output Shape Param # Trainable ================================================================ Conv2d 8 x 64 x 180 x 240 9,408 False ________________________________________________________________ BatchNorm2d 8 x 64 x 180 x 240 128 True ________________________________________________________________ ReLU 8 x 64 x 180 x 240 0 False ________________________________________________________________ MaxPool2d 8 x 64 x 90 x 120 0 False ________________________________________________________________ Conv2d 8 x 64 x 90 x 120 36,864 False ________________________________________________________________ BatchNorm2d 8 x 64 x 90 x 120 128 True ________________________________________________________________ ReLU 8 x 64 x 90 x 120 0 False ________________________________________________________________ Conv2d 8 x 64 x 90 x 120 36,864 False ________________________________________________________________ BatchNorm2d 8 x 64 x 90 x 120 128 True ________________________________________________________________ Conv2d 8 x 64 x 90 x 120 36,864 False ________________________________________________________________ BatchNorm2d 8 x 64 x 90 x 120 128 True ________________________________________________________________ ReLU 8 x 64 x 90 x 120 0 False ________________________________________________________________ Conv2d 8 x 64 x 90 x 120 36,864 False ________________________________________________________________ BatchNorm2d 8 x 64 x 90 x 120 128 True ________________________________________________________________ Conv2d 8 x 64 x 90 x 120 36,864 False ________________________________________________________________ BatchNorm2d 8 x 64 x 90 x 120 128 True ________________________________________________________________ ReLU 8 x 64 x 90 x 120 0 False ________________________________________________________________ Conv2d 8 x 64 x 90 x 120 36,864 False 
________________________________________________________________ BatchNorm2d 8 x 64 x 90 x 120 128 True ________________________________________________________________ Conv2d 8 x 128 x 45 x 60 73,728 False ________________________________________________________________ BatchNorm2d 8 x 128 x 45 x 60 256 True ________________________________________________________________ ReLU 8 x 128 x 45 x 60 0 False ________________________________________________________________ Conv2d 8 x 128 x 45 x 60 147,456 False ________________________________________________________________ BatchNorm2d 8 x 128 x 45 x 60 256 True ________________________________________________________________ Conv2d 8 x 128 x 45 x 60 8,192 False ________________________________________________________________ BatchNorm2d 8 x 128 x 45 x 60 256 True ________________________________________________________________ Conv2d 8 x 128 x 45 x 60 147,456 False ________________________________________________________________ BatchNorm2d 8 x 128 x 45 x 60 256 True ________________________________________________________________ ReLU 8 x 128 x 45 x 60 0 False ________________________________________________________________ Conv2d 8 x 128 x 45 x 60 147,456 False ________________________________________________________________ BatchNorm2d 8 x 128 x 45 x 60 256 True ________________________________________________________________ Conv2d 8 x 128 x 45 x 60 147,456 False ________________________________________________________________ BatchNorm2d 8 x 128 x 45 x 60 256 True ________________________________________________________________ ReLU 8 x 128 x 45 x 60 0 False ________________________________________________________________ Conv2d 8 x 128 x 45 x 60 147,456 False ________________________________________________________________ BatchNorm2d 8 x 128 x 45 x 60 256 True ________________________________________________________________ Conv2d 8 x 128 x 45 x 60 147,456 False 
________________________________________________________________ BatchNorm2d 8 x 128 x 45 x 60 256 True ________________________________________________________________ ReLU 8 x 128 x 45 x 60 0 False ________________________________________________________________ Conv2d 8 x 128 x 45 x 60 147,456 False ________________________________________________________________ BatchNorm2d 8 x 128 x 45 x 60 256 True ________________________________________________________________ Conv2d 8 x 256 x 23 x 30 294,912 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ ReLU 8 x 256 x 23 x 30 0 False ________________________________________________________________ Conv2d 8 x 256 x 23 x 30 589,824 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ Conv2d 8 x 256 x 23 x 30 32,768 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ Conv2d 8 x 256 x 23 x 30 589,824 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ ReLU 8 x 256 x 23 x 30 0 False ________________________________________________________________ Conv2d 8 x 256 x 23 x 30 589,824 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ Conv2d 8 x 256 x 23 x 30 589,824 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ ReLU 8 x 256 x 23 x 30 0 False 
________________________________________________________________ Conv2d 8 x 256 x 23 x 30 589,824 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ Conv2d 8 x 256 x 23 x 30 589,824 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ ReLU 8 x 256 x 23 x 30 0 False ________________________________________________________________ Conv2d 8 x 256 x 23 x 30 589,824 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ Conv2d 8 x 256 x 23 x 30 589,824 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ ReLU 8 x 256 x 23 x 30 0 False ________________________________________________________________ Conv2d 8 x 256 x 23 x 30 589,824 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ Conv2d 8 x 256 x 23 x 30 589,824 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ ReLU 8 x 256 x 23 x 30 0 False ________________________________________________________________ Conv2d 8 x 256 x 23 x 30 589,824 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ Conv2d 8 x 512 x 12 x 15 1,179,648 False ________________________________________________________________ BatchNorm2d 8 x 512 x 12 x 15 1,024 True 
________________________________________________________________ ReLU 8 x 512 x 12 x 15 0 False ________________________________________________________________ Conv2d 8 x 512 x 12 x 15 2,359,296 False ________________________________________________________________ BatchNorm2d 8 x 512 x 12 x 15 1,024 True ________________________________________________________________ Conv2d 8 x 512 x 12 x 15 131,072 False ________________________________________________________________ BatchNorm2d 8 x 512 x 12 x 15 1,024 True ________________________________________________________________ Conv2d 8 x 512 x 12 x 15 2,359,296 False ________________________________________________________________ BatchNorm2d 8 x 512 x 12 x 15 1,024 True ________________________________________________________________ ReLU 8 x 512 x 12 x 15 0 False ________________________________________________________________ Conv2d 8 x 512 x 12 x 15 2,359,296 False ________________________________________________________________ BatchNorm2d 8 x 512 x 12 x 15 1,024 True ________________________________________________________________ Conv2d 8 x 512 x 12 x 15 2,359,296 False ________________________________________________________________ BatchNorm2d 8 x 512 x 12 x 15 1,024 True ________________________________________________________________ ReLU 8 x 512 x 12 x 15 0 False ________________________________________________________________ Conv2d 8 x 512 x 12 x 15 2,359,296 False ________________________________________________________________ BatchNorm2d 8 x 512 x 12 x 15 1,024 True ________________________________________________________________ BatchNorm2d 8 x 512 x 12 x 15 1,024 True ________________________________________________________________ ReLU 8 x 512 x 12 x 15 0 False ________________________________________________________________ Conv2d 8 x 1024 x 12 x 15 4,718,592 True ________________________________________________________________ BatchNorm2d 8 x 1024 x 12 x 15 2,048 True 
________________________________________________________________ Mish 8 x 1024 x 12 x 15 0 False ________________________________________________________________ Conv2d 8 x 512 x 12 x 15 4,718,592 True ________________________________________________________________ BatchNorm2d 8 x 512 x 12 x 15 1,024 True ________________________________________________________________ Mish 8 x 512 x 12 x 15 0 False ________________________________________________________________ Conv2d 8 x 1024 x 12 x 15 524,288 True ________________________________________________________________ BatchNorm2d 8 x 1024 x 12 x 15 2,048 True ________________________________________________________________ Mish 8 x 1024 x 12 x 15 0 False ________________________________________________________________ PixelShuffle 8 x 256 x 24 x 30 0 False ________________________________________________________________ BatchNorm2d 8 x 256 x 23 x 30 512 True ________________________________________________________________ Conv2d 8 x 512 x 23 x 30 2,359,296 True ________________________________________________________________ BatchNorm2d 8 x 512 x 23 x 30 1,024 True ________________________________________________________________ Mish 8 x 512 x 23 x 30 0 False ________________________________________________________________ Conv2d 8 x 512 x 23 x 30 2,359,296 True ________________________________________________________________ BatchNorm2d 8 x 512 x 23 x 30 1,024 True ________________________________________________________________ Mish 8 x 512 x 23 x 30 0 False ________________________________________________________________ Mish 8 x 512 x 23 x 30 0 False ________________________________________________________________ Conv2d 8 x 1024 x 23 x 30 524,288 True ________________________________________________________________ BatchNorm2d 8 x 1024 x 23 x 30 2,048 True ________________________________________________________________ Mish 8 x 1024 x 23 x 30 0 False 
________________________________________________________________ PixelShuffle 8 x 256 x 46 x 60 0 False ________________________________________________________________ BatchNorm2d 8 x 128 x 45 x 60 256 True ________________________________________________________________ Conv2d 8 x 384 x 45 x 60 1,327,104 True ________________________________________________________________ BatchNorm2d 8 x 384 x 45 x 60 768 True ________________________________________________________________ Mish 8 x 384 x 45 x 60 0 False ________________________________________________________________ Conv2d 8 x 384 x 45 x 60 1,327,104 True ________________________________________________________________ BatchNorm2d 8 x 384 x 45 x 60 768 True ________________________________________________________________ Mish 8 x 384 x 45 x 60 0 False ________________________________________________________________ Conv1d 8 x 48 x 2700 18,432 True ________________________________________________________________ Conv1d 8 x 48 x 2700 18,432 True ________________________________________________________________ Conv1d 8 x 384 x 2700 147,456 True ________________________________________________________________ Mish 8 x 384 x 45 x 60 0 False ________________________________________________________________ Conv2d 8 x 768 x 45 x 60 294,912 True ________________________________________________________________ BatchNorm2d 8 x 768 x 45 x 60 1,536 True ________________________________________________________________ Mish 8 x 768 x 45 x 60 0 False ________________________________________________________________ PixelShuffle 8 x 192 x 90 x 120 0 False ________________________________________________________________ BatchNorm2d 8 x 64 x 90 x 120 128 True ________________________________________________________________ Conv2d 8 x 256 x 90 x 120 589,824 True ________________________________________________________________ BatchNorm2d 8 x 256 x 90 x 120 512 True ________________________________________________________________ 
Mish 8 x 256 x 90 x 120 0 False ________________________________________________________________ Conv2d 8 x 256 x 90 x 120 589,824 True ________________________________________________________________ BatchNorm2d 8 x 256 x 90 x 120 512 True ________________________________________________________________ Mish 8 x 256 x 90 x 120 0 False ________________________________________________________________ Mish 8 x 256 x 90 x 120 0 False ________________________________________________________________ Conv2d 8 x 512 x 90 x 120 131,072 True ________________________________________________________________ BatchNorm2d 8 x 512 x 90 x 120 1,024 True ________________________________________________________________ Mish 8 x 512 x 90 x 120 0 False ________________________________________________________________ PixelShuffle 8 x 128 x 180 x 240 0 False ________________________________________________________________ BatchNorm2d 8 x 64 x 180 x 240 128 True ________________________________________________________________ Conv2d 8 x 96 x 180 x 240 165,888 True ________________________________________________________________ BatchNorm2d 8 x 96 x 180 x 240 192 True ________________________________________________________________ Mish 8 x 96 x 180 x 240 0 False ________________________________________________________________ Conv2d 8 x 96 x 180 x 240 82,944 True ________________________________________________________________ BatchNorm2d 8 x 96 x 180 x 240 192 True ________________________________________________________________ Mish 8 x 96 x 180 x 240 0 False ________________________________________________________________ Mish 8 x 192 x 180 x 240 0 False ________________________________________________________________ Conv2d 8 x 384 x 180 x 240 36,864 True ________________________________________________________________ BatchNorm2d 8 x 384 x 180 x 240 768 True ________________________________________________________________ Mish 8 x 384 x 180 x 240 0 False 
________________________________________________________________ PixelShuffle 8 x 96 x 360 x 480 0 False ________________________________________________________________ ResizeToOrig 8 x 96 x 360 x 480 0 False ________________________________________________________________ MergeLayer 8 x 99 x 360 x 480 0 False ________________________________________________________________ Conv2d 8 x 99 x 360 x 480 88,209 True ________________________________________________________________ BatchNorm2d 8 x 99 x 360 x 480 198 True ________________________________________________________________ Mish 8 x 99 x 360 x 480 0 False ________________________________________________________________ Conv2d 8 x 99 x 360 x 480 88,209 True ________________________________________________________________ BatchNorm2d 8 x 99 x 360 x 480 198 True ________________________________________________________________ Sequential 8 x 99 x 360 x 480 0 False ________________________________________________________________ Mish 8 x 99 x 360 x 480 0 False ________________________________________________________________ Conv2d 8 x 32 x 360 x 480 3,168 True ________________________________________________________________ BatchNorm2d 8 x 32 x 360 x 480 64 True ________________________________________________________________
Total params: 41,416,462
Total trainable params: 20,148,814
Total non-trainable params: 21,267,648
Optimizer used: <function ranger at 0x7f27ede82a60>
Loss function: FlattenedLoss of CrossEntropyLoss()
Model frozen up to parameter group number 2
Callbacks:
- TrainEvalCallback
- Recorder
- ProgressCallback
If we do a learn.summary() we can see this blow-up trend, and see that our model came in frozen. Let's find a learning rate
learn.lr_find()
lr = 1e-3
With our new optimizer, we will also want to use a different fit function, called fit_flat_cos
learn.fit_flat_cos(10, slice(lr))
epoch | train_loss | valid_loss | acc_camvid | time |
---|---|---|---|---|
0 | 2.267863 | 1.742724 | 0.827059 | 01:12 |
1 | 1.821520 | 1.487319 | 0.858477 | 00:59 |
2 | 1.588776 | 1.349779 | 0.865436 | 00:59 |
3 | 1.430715 | 1.215538 | 0.886508 | 00:59 |
4 | 1.324019 | 1.124806 | 0.902292 | 00:59 |
5 | 1.232799 | 1.035299 | 0.903780 | 00:59 |
6 | 1.133886 | 0.990665 | 0.904250 | 00:59 |
7 | 1.059976 | 0.914800 | 0.911261 | 00:59 |
8 | 0.980094 | 0.883291 | 0.914787 | 00:59 |
9 | 0.920574 | 0.861305 | 0.915084 | 00:59 |
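For intuition about fit_flat_cos, the sketch below shows a flat-then-cosine learning rate schedule in plain Python. The values pct_start=0.75 and div_final=1e5 are assumptions about fastai's defaults, and flat_cos_lr is a hypothetical helper, so treat this as an approximation of the shape rather than the library's exact code.

```python
import math

def flat_cos_lr(progress, lr, pct_start=0.75, div_final=1e5):
    """Learning rate at `progress` in [0, 1]: flat first, then cosine annealed."""
    if progress < pct_start:
        return lr  # flat phase: keep the full learning rate
    lr_end = lr / div_final
    # t runs from 0 to 1 over the annealing phase
    t = (progress - pct_start) / (1 - pct_start)
    return lr_end + (lr - lr_end) * (1 + math.cos(math.pi * t)) / 2

for p in (0.0, 0.5, 0.75, 0.875, 1.0):
    print(p, flat_cos_lr(p, lr=1e-3))
```

With lr = 1e-3 the rate stays at 1e-3 for the first three quarters of training and then decays smoothly towards 1e-8.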
learn.save('stage-1')
learn.load('stage-1');
learn.show_results(max_n=4, figsize=(12,6))
Let's unfreeze the model and decrease our learning rate by a factor of 4 (rule of thumb)
lrs = slice(lr/400, lr/4)
lr, lrs
(0.001, slice(2.5e-06, 0.00025, None))
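Passing a slice gives each parameter group its own learning rate, with the earliest layers getting the smallest. The sketch below assumes the rates are spread in even multiplicative steps between the two ends of the slice, and that this model has three parameter groups (the summary above says it is frozen up to group number 2). spread_learning_rates is a hypothetical helper written for illustration.

```python
def spread_learning_rates(start, stop, n_groups):
    """Geometrically spaced learning rates from `start` to `stop`, one per group."""
    if n_groups == 1:
        return [stop]
    step = (stop / start) ** (1 / (n_groups - 1))  # constant multiplier between groups
    return [start * step ** i for i in range(n_groups)]

lr = 1e-3
group_lrs = spread_learning_rates(lr / 400, lr / 4, n_groups=3)
print(group_lrs)  # roughly 2.5e-06, 2.5e-05, 2.5e-04
```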
learn.unfreeze()
And train for a bit more
learn.fit_flat_cos(12, lrs)
epoch | train_loss | valid_loss | acc_camvid | time |
---|---|---|---|---|
0 | 0.904472 | 0.864132 | 0.914719 | 01:01 |
1 | 0.893953 | 0.861930 | 0.915803 | 01:01 |
2 | 0.870641 | 0.839532 | 0.915104 | 01:00 |
3 | 0.857703 | 0.844047 | 0.912143 | 01:00 |
4 | 0.835774 | 0.817610 | 0.914030 | 01:00 |
5 | 0.818883 | 0.815786 | 0.915836 | 01:01 |
6 | 0.790792 | 0.797801 | 0.914842 | 01:01 |
7 | 0.782354 | 0.782085 | 0.916641 | 01:01 |
8 | 0.764274 | 0.766298 | 0.917362 | 01:00 |
9 | 0.747887 | 0.769105 | 0.917114 | 01:01 |
10 | 0.721998 | 0.756591 | 0.916820 | 01:01 |
11 | 0.717462 | 0.758870 | 0.916470 | 01:01 |
Now let's save that model away
learn.save('model_1')
And look at a few results
learn.show_results(max_n=4, figsize=(18,8))
Let's take a look at how to do inference with test_dl
dl = learn.dls.test_dl(fnames[:5])
dl.show_batch()
Let's do the first five pictures
preds = learn.get_preds(dl=dl)
preds[0].shape
torch.Size([5, 32, 360, 480])
Alright so we have a 5x32x360x480
len(codes)
32
What does this mean? The first dimension is our five images, the second holds one score per class (32, matching len(codes)), and the last two are the height and width. Let's look at the first image
pred_1 = preds[0][0]
pred_1.shape
torch.Size([32, 360, 480])
Now let's take the argmax of our values
pred_arx = pred_1.argmax(dim=0)
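Taking the argmax over dimension 0 collapses the per-class scores into a single class id per pixel. A tiny plain-Python sketch of the same operation, with three made-up classes on a 2x2 image (argmax_over_classes and the toy scores are for illustration only):

```python
def argmax_over_classes(scores):
    """scores: [class][row][col] nested lists -> [row][col] of winning class ids."""
    n_classes = len(scores)
    n_rows, n_cols = len(scores[0]), len(scores[0][0])
    return [[max(range(n_classes), key=lambda k: scores[k][r][c])
             for c in range(n_cols)]
            for r in range(n_rows)]

names = ["Road", "Sky", "Tree"]  # toy vocabulary, standing in for `codes`
scores = [
    [[0.1, 0.7], [0.2, 0.3]],  # Road
    [[0.8, 0.2], [0.3, 0.3]],  # Sky
    [[0.1, 0.1], [0.5, 0.4]],  # Tree
]
pred_ids = argmax_over_classes(scores)
named = [[names[i] for i in row] for row in pred_ids]
print(pred_ids)  # [[1, 0], [2, 2]]
print(named)     # [['Sky', 'Road'], ['Tree', 'Tree']]
```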
And look at it
plt.imshow(pred_arx)
What do we do from here? We need to save it away. We can do this in one of two ways: as an image built from a numpy array, or as a raw tensor (say, to use later)
pred_arx = pred_arx.numpy()
rescaled = (255.0 / pred_arx.max() * (pred_arx - pred_arx.min())).astype(np.uint8)
im = Image.fromarray(rescaled)
im
im.save('test.png')
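It helps to see what the rescaling does to the class ids. The sketch below applies the same formula, 255 / max * (x - min), to three ids in plain Python, with int() standing in for the truncation done by astype(np.uint8). rescale_ids is a hypothetical helper, and the ids are Building, Road and Tree from codes.

```python
def rescale_ids(ids):
    """Apply 255 / max * (x - min) to each id and truncate to an integer."""
    lo, hi = min(ids), max(ids)
    return [int(255.0 / hi * (x - lo)) for x in ids]

rescaled = rescale_ids([4, 17, 26])  # Building, Road, Tree
print(rescaled)  # [0, 127, 215]
```

The saved grey levels are therefore good for viewing, but they are no longer the class ids themselves. If you need to map pixels back to codes later, one option is to keep the raw argmax values, which all fit in a uint8.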
Let's do the same in a loop for all of our predictions
for i, pred in enumerate(preds[0]):
    pred_arg = pred.argmax(dim=0).numpy()
    rescaled = (255.0 / pred_arg.max() * (pred_arg - pred_arg.min())).astype(np.uint8)
    im = Image.fromarray(rescaled)
    im.save(f'Image_{i}.png')
Now let's save away the raw:
torch.save(preds[0][0], 'Image_1.pt')
pred_1 = torch.load('Image_1.pt')
plt.imshow(pred_1.argmax(dim=0))
Now let's go full sized. Restart your instance to free up memory
from fastai.basics import *
from fastai.vision.all import *
from fastai.callback.all import *
path = untar_data(URLs.CAMVID)
valid_fnames = (path/'valid.txt').read().split('\n')
get_msk = lambda o: path/'labels'/f'{o.stem}_P{o.suffix}'
codes = np.loadtxt(path/'codes.txt', dtype=str); codes
def ListSplitter(items):
    def _inner(it):
        val_mask = tensor([o.name in items for o in it])
        return [~val_mask, val_mask]
    return _inner
name2id = {v:k for k,v in enumerate(codes)}
void_code = name2id['Void']
def acc_camvid(inp, targ):
    targ = targ.squeeze(1)
    mask = targ != void_code
    return (inp.argmax(dim=1)[mask]==targ[mask]).float().mean()
And re-make our dataloaders. But this time we want our size to be the full size
sz = (720, 960)  # full image size from msk.shape earlier; redefined because we restarted
camvid = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                   get_items=get_image_files,
                   splitter=ListSplitter(valid_fnames),
                   get_y=get_msk,
                   batch_tfms=[*aug_transforms(size=sz), Normalize.from_stats(*imagenet_stats)])
We'll also want to lower our batch size to not run out of memory
dls = camvid.dataloaders(path/"images", bs=1)
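A quick bit of arithmetic shows why the batch size has to come down. Full-size images (720x960) have four times as many pixels as the half-size ones (360x480), and as a rough assumption the activation memory of a convolutional model grows with the number of pixels. n_pixels is a hypothetical helper for illustration.

```python
full = (720, 960)
half = tuple(x // 2 for x in full)  # (360, 480)

def n_pixels(size):
    """Number of pixels in an image of the given (height, width)."""
    height, width = size
    return height * width

ratio = n_pixels(full) / n_pixels(half)
print(n_pixels(half), n_pixels(full), ratio)  # 172800 691200 4.0
```

So a batch of 8 half-size images covers about as many pixels as a batch of 2 full-size ones; bs=1 leaves extra headroom.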
Let's assign our vocab, make our learner, and load our weights
config = unet_config(self_attention=True, act_cls=Mish)
opt = ranger
dls.vocab = codes
learn = unet_learner(dls, resnet34, metrics=acc_camvid, config=config,
opt_func=opt)
learn.load('model_1');
And now let's find our learning rate and train!
learn.lr_find()
lr = 1e-3
learn.fit_flat_cos(10, slice(lr))
epoch | train_loss | valid_loss | acc_camvid | time |
---|---|---|---|---|
0 | 0.506793 | 0.342499 | 0.926143 | 05:15 |
1 | 0.430471 | 0.331063 | 0.924518 | 05:14 |
2 | 0.349916 | 0.330467 | 0.915646 | 05:14 |
3 | 0.330413 | 0.291114 | 0.926524 | 05:13 |
4 | 0.307470 | 0.273733 | 0.928941 | 05:13 |
5 | 0.281810 | 0.278765 | 0.926421 | 05:12 |
6 | 0.250463 | 0.288746 | 0.923950 | 05:12 |
7 | 0.248945 | 0.284137 | 0.923201 | 05:11 |
8 | 0.214759 | 0.256868 | 0.931487 | 05:10 |
9 | 0.188897 | 0.267987 | 0.928474 | 05:10 |
learn.save('full_1')
learn.unfreeze()
lrs = slice(1e-6,lr/10); lrs
slice(1e-06, 0.0001, None)
learn.fit_flat_cos(10, lrs)
epoch | train_loss | valid_loss | acc_camvid | time |
---|---|---|---|---|
0 | 0.187095 | 0.266967 | 0.926938 | 05:20 |
1 | 0.186410 | 0.262047 | 0.928341 | 05:20 |
2 | 0.190994 | 0.282032 | 0.922368 | 05:19 |
3 | 0.186702 | 0.262346 | 0.929233 | 05:19 |
4 | 0.179937 | 0.270361 | 0.925767 | 05:19 |
5 | 0.176367 | 0.270718 | 0.926208 | 05:19 |
6 | 0.176295 | 0.259349 | 0.928089 | 05:19 |
7 | 0.171654 | 0.260007 | 0.928189 | 05:19 |
8 | 0.170364 | 0.261180 | 0.928160 | 05:19 |
9 | 0.160006 | 0.253466 | 0.929910 | 05:19 |
learn.save('full_2')
learn.show_results(max_n=4, figsize=(18,8))
We can use weighted loss functions to help with class imbalance. We need to do this because simply oversampling won't quite work here! So, how do we do it? fastai's CrossEntropyLossFlat is just a wrapper around PyTorch's CrossEntropyLoss, so we can pass in a weight parameter (even if it doesn't show up in our autocompletion!)
class CrossEntropyLossFlat(BaseLoss):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    y_int = True
    def __init__(self, *args, axis=-1, **kwargs): super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)
    def decodes(self, x): return x.argmax(dim=self.axis)
    def activation(self, x): return F.softmax(x, dim=self.axis)
But what should this weight be? It needs to be a 1-D tensor of length n, where n is the number of classes in your dataset. We'll use a quick example, where all but the last class have a weight of 0.9 (90%) and the last class has a weight of 1.1 (110%)
Also, as we are training on the GPU, the tensor needs to be on the GPU as well:
weights = torch.tensor([0.9]*31 + [1.1]).cuda()
weights
tensor([0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 0.9000, 1.1000], device='cuda:0')
Now we can pass this into CrossEntropyLossFlat
learn.loss_func = CrossEntropyLossFlat(weight=weights, axis=1)
(or pass it into unet_learner)
loss_func = CrossEntropyLossFlat(weight=weights, axis=1)
learn = unet_learner(dls, resnet34, metrics=acc_camvid, loss_func=loss_func)
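To see what the class weights actually change, here is a plain-Python sketch of weighted cross-entropy with mean reduction, where each sample's loss is scaled by the weight of its true class and the total is divided by the sum of those weights. weighted_cross_entropy is a hypothetical helper with two made-up classes, not the fastai or PyTorch implementation.

```python
import math

def weighted_cross_entropy(logits, targets, weights):
    """Weighted mean cross-entropy.

    logits:  list of per-sample lists of raw class scores
    targets: list of true class ids
    weights: one weight per class
    """
    total_loss = total_weight = 0.0
    for scores, y in zip(logits, targets):
        log_sum_exp = math.log(sum(math.exp(s) for s in scores))
        nll = log_sum_exp - scores[y]  # equals -log(softmax(scores)[y])
        total_loss += weights[y] * nll
        total_weight += weights[y]
    return total_loss / total_weight

logits = [[2.0, 0.0], [2.0, 0.0]]  # the model favours class 0 for both samples
targets = [0, 1]                   # the second sample is actually class 1
unweighted = weighted_cross_entropy(logits, targets, [1.0, 1.0])
weighted = weighted_cross_entropy(logits, targets, [0.9, 1.1])
print(unweighted, weighted)
```

Up-weighting class 1 makes the mistake on the second sample count for more, so the weighted loss comes out higher than the unweighted one.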