Fine-tuning a pretrained model on the pets dataset.
from local.test import *
from local.basics import *
from local.callback.all import *
from local.vision.all import *
We use the data block API to get our data in a DataBunch. Here our inputs are images and our targets are categories. The images are all in one folder, so we use get_image_files to collect them all, a RandomSplitter to split between training and validation, then we get the labels from the filenames with a regex labeller.
pets = DataBlock(blocks=(ImageBlock, CategoryBlock),
                 get_items=get_image_files,
                 splitter=RandomSplitter(),
                 get_y=RegexLabeller(pat=r'/([^/]+)_\d+.jpg$'))
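As a quick illustration of what the labeller extracts, here is the same pattern applied with the standard re module to a made-up filename (not a call to the library): the class name is whatever sits between the last / and the trailing _&lt;number&gt;.jpg.
import re
# hypothetical filename following the pets naming scheme <class>_<index>.jpg
fname = "/home/user/.fastai/data/oxford-iiit-pet/images/great_pyrenees_173.jpg"
pat = re.compile(r'/([^/]+)_\d+.jpg$')   # same pattern as the RegexLabeller above
print(pat.search(fname).group(1))        # -> 'great_pyrenees'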
The pets object by itself is empty: it only contains the functions that will help us gather the data. We have to call its datasource or databunch method to get a DataSource or a DataBunch. The first thing we need to pass to either of those methods is the source, here the folder where all the images are. Then we specify some dataset transforms (a random resized crop to 460 with a minimum scale of 0.75) and some dataloader transforms (basic data augmentation on the GPU with a final size of 299 and no warping, plus normalization using the imagenet statistics).
dbunch = pets.databunch(untar_data(URLs.PETS)/"images", item_tfms=RandomResizedCrop(460, min_scale=0.75), bs=32,
                        batch_tfms=[*aug_transforms(size=299, max_warp=0), Normalize(*imagenet_stats)])
Then we can look at some of our pictures with dbunch.show_batch().
len(dbunch.train_ds.items)
5912
dbunch.show_batch(max_n=9)
First let's import resnet34 and resnet50 from torchvision.
from torchvision.models import resnet34,resnet50
#from local.vision.models.xresnet import xresnet50
We will use the AdamW optimizer (Adam with true weight decay).
opt_func = partial(Adam, lr=slice(3e-3), wd=0.01, eps=1e-8)
#Or use Ranger
#def opt_func(p, lr=slice(3e-3)): return Lookahead(RAdam(p, lr=lr, mom=0.95, wd=0.01))
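As a rough sketch of what "true" (decoupled) weight decay means: instead of adding wd*p to the gradients (classic L2 regularization), the weights are shrunk directly by lr*wd at each step, independently of the Adam statistics. This is only an illustration in plain PyTorch, not the library's implementation.
import torch

def decoupled_weight_decay_step(params, lr=3e-3, wd=0.01):
    "Sketch of the weight decay part of an AdamW-style update (illustration only)."
    with torch.no_grad():
        for p in params:
            # classic L2 would add wd*p to p.grad before the Adam step;
            # decoupled decay shrinks the weights directly instead
            p.mul_(1 - lr*wd)

# usage sketch: apply the decay, then the plain Adam update on the gradients
layer = torch.nn.Linear(4, 2)
decoupled_weight_decay_step(layer.parameters())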
Then we can call cnn_learner to build a Learner from our DataBunch. Since we are using a pretrained model, it comes automatically frozen, which means only the head is going to be trained.
learn = cnn_learner(dbunch, resnet50, opt_func=opt_func, metrics=error_rate, config=cnn_config(ps=0.33)).to_fp16()
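"Frozen" means the parameters of the pretrained body don't receive updates, so only the randomly initialized head learns at first. A minimal sketch of the idea in plain PyTorch (the library manages this for you through parameter groups; the tiny body and head below are made up):
import torch.nn as nn

body = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())   # stand-in for the pretrained body
head = nn.Linear(16, 37)                                # new head, 37 pet breeds

for p in body.parameters(): p.requires_grad_(False)     # frozen: no gradients, no updates
for p in head.parameters(): p.requires_grad_(True)      # only the head trains while frozen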
We can train the head a little bit using the 1cycle policy.
learn.fit_one_cycle(1)
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 0.455705 | 0.239800 | 0.077808 | 00:36 |
#learn.fit_one_cycle(8, slice(3e-3))
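The 1cycle policy ramps the learning rate up to a maximum and then anneals it down to a very small value over the course of training (with momentum moving the opposite way). The sketch below assumes a cosine ramp in both phases and plausible divisors; it illustrates the shape, not necessarily the exact schedule the library uses.
import math

def one_cycle_lr(pct, lr_max=3e-3, div_start=25., div_final=1e5, pct_warmup=0.25):
    "Sketch of a 1cycle learning-rate schedule; pct is the fraction of training done (0..1)."
    if pct < pct_warmup:                               # warm up from lr_max/div_start to lr_max
        p, start, end = pct/pct_warmup, lr_max/div_start, lr_max
    else:                                              # anneal from lr_max to lr_max/div_final
        p, start, end = (pct-pct_warmup)/(1-pct_warmup), lr_max, lr_max/div_final
    return end + (start-end)/2 * (1 + math.cos(math.pi*p))   # cosine interpolation

print([f'{one_cycle_lr(p):.1e}' for p in (0., 0.25, 1.)])    # low -> peak -> tiny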
Then we can unfreeze the model and use discriminative learning rates.
learn.unfreeze()
# learn.fit_one_cycle(4, slice(1e-5, 1e-3))
learn.fit_one_cycle(1, slice(1e-5, 1e-3))
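With discriminative learning rates, slice(1e-5, 1e-3) means the earliest layer group trains at 1e-5, the head at 1e-3, and the groups in between get intermediate values, typically spread on a log scale. A small sketch of that spreading (illustration only; the number of groups is made up):
import numpy as np

def spread_lrs(lr_slice, n_groups=3):
    "Sketch: spread a (start, stop) learning-rate range log-uniformly over n parameter groups."
    return np.geomspace(lr_slice.start, lr_slice.stop, n_groups)

print(spread_lrs(slice(1e-5, 1e-3)))   # lowest lr for the first layers, highest for the head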
learn.predict(dbunch.train_ds.items[0])
('miniature_pinscher', tensor(26), tensor([1.3895e-03, 1.1339e-04, 4.7167e-06, 1.8621e-04, 3.6235e-06, 2.1440e-05, 1.2743e-07, 8.3745e-07, 1.4442e-06, 8.6767e-05, 5.1617e-04, 4.7932e-03, 2.0890e-06, 3.9454e-05, 1.0844e-05, 1.0174e-04, 3.2597e-05, 4.6821e-03, 5.1652e-06, 2.3441e-06, 8.5756e-05, 5.3335e-07, 1.6687e-06, 5.7565e-06, 1.2709e-05, 5.5136e-07, 9.8764e-01, 9.1082e-07, 1.4529e-05, 6.8028e-06, 1.8983e-06, 4.2363e-06, 4.4036e-05, 5.3980e-05, 3.2933e-05, 1.0753e-06, 1.0164e-04]))
learn.show_results(max_n=9)
interp = Interpretation.from_learner(learn)
interp.plot_top_losses(9, figsize=(15,10))
Next, a multi-label classification problem on a small sample of the Planet dataset: each satellite image can carry several tags, listed in a CSV file, so we use a MultiCategoryBlock and get the list of labels by splitting the tag string on spaces.
planet_source = untar_data(URLs.PLANET_TINY)
df = pd.read_csv(planet_source/"labels.csv")
planet = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
                   get_x=lambda x: planet_source/"train"/f'{x[0]}.jpg',
                   splitter=RandomSplitter(),
                   get_y=lambda x: x[1].split(' '))
dbunch = planet.databunch(df.values,
batch_tfms=aug_transforms(flip_vert=True, max_lighting=0.1, max_zoom=1.05, max_warp=0.))
dbunch.show_batch(max_n=9, figsize=(12,9))
learn = cnn_learner(dbunch, resnet34, opt_func=opt_func, metrics=accuracy_multi)
learn.fit_one_cycle(1)
epoch | train_loss | valid_loss | accuracy_multi | time |
---|---|---|---|---|
0 | 0.984292 | 0.989365 | 0.417857 | 00:04 |
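For multi-label problems the accuracy is computed per label: each output goes through a sigmoid, is compared to a threshold, and is checked against the target independently for every class. A minimal sketch of that metric, assuming a 0.5 threshold (accuracy_multi works along these lines, but this is only an illustration):
import torch

def accuracy_multi_sketch(preds, targs, thresh=0.5):
    "Sketch of multi-label accuracy: average per-label correctness after thresholding the sigmoid."
    return ((preds.sigmoid() > thresh) == targs.bool()).float().mean()

preds = torch.tensor([[ 2.0, -1.0, 0.3], [-0.5, 1.5, -2.0]])   # raw outputs, one column per tag
targs = torch.tensor([[ 1.,   0.,  1. ], [ 0.,  1.,   0. ]])
print(accuracy_multi_sketch(preds, targs))                     # -> tensor(1.)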
Since we have a weird first transform (that is there to preprocess the names in the dataframe and make proper filenames), we remove it when we want to do inference in Learner.predict or test_dl with the rm_type_tfms argument.
learn.predict(planet_source/f'train/train_10030.jpg', rm_type_tfms=1)
((#5) [agriculture,artisinal_mine,cultivation,habitation,haze], tensor([ True, True, False, False, False, False, True, True, True, False, False, False, False, False]), tensor([0.6245, 0.8218, 0.3397, 0.3111, 0.2421, 0.3718, 0.6801, 0.8657, 0.6233, 0.4426, 0.2959, 0.3173, 0.2264, 0.2996]))
learn.show_results(max_n=9)
interp = Interpretation.from_learner(learn)
interp.plot_top_losses(9)
  | target | predicted | probabilities | loss |
---|---|---|---|---|
0 | clear;primary;water | agriculture;artisinal_mine;bare_ground;cloudy;cultivation;habitation;haze;partly_cloudy;road;selective_logging | tensor([0.9854, 0.9713, 0.6691, 0.3597, 0.0635, 0.6860, 0.9405, 0.7924, 0.9206, 0.9019, 0.2825, 0.5693, 0.8724, 0.1841]) | 2.0251517295837402 |
1 | clear;primary;water | agriculture;artisinal_mine;bare_ground;cloudy;cultivation;haze;partly_cloudy;primary;selective_logging;water | tensor([0.8294, 0.9305, 0.6096, 0.4756, 0.0079, 0.8483, 0.8683, 0.1757, 0.9623, 0.8643, 0.7121, 0.4005, 0.6676, 0.6034]) | 1.621275544166565 |
2 | agriculture;clear;primary;road;water | agriculture;artisinal_mine;bare_ground;blooming;cloudy;cultivation;habitation;haze;partly_cloudy;road;selective_logging | tensor([0.9785, 0.9503, 0.5710, 0.5760, 0.1992, 0.7882, 0.9902, 0.9602, 0.7873, 0.8915, 0.4699, 0.9225, 0.6990, 0.3487]) | 1.6141308546066284 |
3 | artisinal_mine;clear;primary;road;water | agriculture;artisinal_mine;cultivation;habitation;haze;partly_cloudy;primary;road;selective_logging;water | tensor([0.9320, 0.9605, 0.2803, 0.1407, 0.3144, 0.1803, 0.9808, 0.9928, 0.9133, 0.8834, 0.5903, 0.5928, 0.8103, 0.7938]) | 1.4995286464691162 |
4 | artisinal_mine;clear;primary;water | agriculture;artisinal_mine;blooming;cultivation;habitation;haze;partly_cloudy;primary;road | tensor([0.9789, 0.8111, 0.4456, 0.7128, 0.3238, 0.0783, 0.5463, 0.9456, 0.7621, 0.9920, 0.9421, 0.8675, 0.1426, 0.2553]) | 1.4772450923919678 |
5 | agriculture;clear;primary;road | agriculture;artisinal_mine;bare_ground;cloudy;cultivation;habitation;haze;partly_cloudy;primary | tensor([0.7373, 0.8035, 0.6012, 0.3029, 0.0942, 0.7587, 0.8844, 0.7664, 0.8060, 0.7086, 0.5070, 0.2195, 0.2832, 0.2694]) | 1.165945053100586 |
6 | agriculture;clear;primary;road | agriculture;artisinal_mine;bare_ground;habitation;haze;partly_cloudy;road;selective_logging;water | tensor([0.9951, 0.8910, 0.5872, 0.4731, 0.1032, 0.2087, 0.4710, 0.8880, 0.5711, 0.8639, 0.1938, 0.8785, 0.6007, 0.7388]) | 1.139198660850525 |
7 | agriculture;partly_cloudy;primary | agriculture;artisinal_mine;bare_ground;cloudy;cultivation;habitation;haze;road;selective_logging | tensor([0.8243, 0.8928, 0.8016, 0.2904, 0.1221, 0.7102, 0.8035, 0.7743, 0.6381, 0.3634, 0.0965, 0.5089, 0.5524, 0.3672]) | 1.0865401029586792 |
8 | clear;primary | agriculture;artisinal_mine;bare_ground;cloudy;cultivation;habitation;haze;partly_cloudy;selective_logging | tensor([0.5802, 0.8437, 0.6628, 0.2134, 0.2427, 0.6364, 0.5345, 0.7982, 0.6738, 0.5906, 0.2330, 0.2999, 0.8398, 0.3854]) | 1.0705785751342773 |
Now let's look at a segmentation task on a small sample of CamVid: the targets are masks, so we use an ImageBlock with the PILMask class, and the mask corresponding to each image is found in the labels folder from the image filename.
camvid = DataBlock(blocks=(ImageBlock, ImageBlock(cls=PILMask)),
                   get_items=get_image_files,
                   splitter=RandomSplitter(),
                   get_y=lambda o: untar_data(URLs.CAMVID_TINY)/'labels'/f'{o.stem}_P{o.suffix}')
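The get_y function maps every image filename to the corresponding mask in the labels folder, which follows a &lt;name&gt;_P&lt;extension&gt; convention. Just to illustrate the name manipulation (with a hypothetical filename):
from pathlib import Path

o = Path('/home/user/.fastai/data/camvid_tiny/images/0016E5_00390.png')   # made-up image file
mask = o.parent.parent/'labels'/f'{o.stem}_P{o.suffix}'                   # same naming rule as get_y above
print(mask.name)                                                          # -> '0016E5_00390_P.png'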
dbunch = camvid.databunch(untar_data(URLs.CAMVID_TINY)/"images", batch_tfms=aug_transforms())
dbunch.show_batch(max_n=9, vmin=1, vmax=30)
#TODO: Find a way to pass the classes properly
dbunch.vocab = np.loadtxt(untar_data(URLs.CAMVID_TINY)/'codes.txt', dtype=str)
learn = unet_learner(dbunch, resnet34, opt_func=opt_func, config=unet_config())
learn.fit_one_cycle(1, 1e-3)
# Use the below to get somewhat reasonable results - but takes a bit longer
# learn.fit_one_cycle(8, 3e-3)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 3.413007 | 4.757127 | 00:05 |
learn.predict(dbunch.train_ds.items[0]);
learn.show_results(max_n=4, figsize=(15,5))
Finally, a point regression task on a sample of the BIWI head pose dataset: for each image we predict the location of the center of the head, so the targets use a PointBlock and come from a dictionary mapping filenames to center coordinates.
path = untar_data(URLs.BIWI_SAMPLE)
fn2ctr = (path/'centers.pkl').load()
biwi = DataBlock(blocks=(ImageBlock, PointBlock),
                 get_items=get_image_files,
                 get_y=lambda o: fn2ctr[o.name].flip(0),
                 splitter=RandomSplitter())
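The pickle gives us a dictionary from filename to the head center of that image; flip(0) simply reverses the order of the two coordinates so they arrive in the order expected downstream. A tiny illustration (the values are made up):
import torch

ctr = torch.tensor([263.9, 428.5])   # hypothetical center point for one image
print(ctr.flip(0))                   # -> tensor([428.5000, 263.9000]), coordinates reversed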
dbunch = biwi.databunch(path, batch_tfms=[*aug_transforms(size=(120,160)), Normalize(*imagenet_stats)])
dbunch.show_batch(max_n=9)
#TODO: look for attrs in after_item
dbunch.c = dbunch.after_item.c
dbunch.train_ds.loss_func = MSELossFlat()
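MSELossFlat is just a mean squared error that flattens predictions and targets before comparing them, which is handy when they don't come in exactly the same shape. Roughly (a sketch of the idea, not the library code):
import torch
import torch.nn.functional as F

def mse_loss_flat(preds, targs):
    "Sketch: MSE after flattening, so shapes like (batch, 1, 2) and (batch, 2) still line up."
    return F.mse_loss(preds.view(-1), targs.view(-1))

print(mse_loss_flat(torch.zeros(4, 1, 2), torch.ones(4, 2)))   # -> tensor(1.)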
learn = cnn_learner(dbunch, resnet34, opt_func=opt_func)
learn.fit_one_cycle(3, 1e-3)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 3.999626 | 1.520517 | 00:03 |
1 | 3.387719 | 1.699020 | 00:03 |
2 | 3.338441 | 1.403246 | 00:03 |
learn.predict(dbunch.train_ds.items[0])
(tensor([56.6983, -2.8033]), tensor([-0.2913, -1.0467]), tensor([-0.2913, -1.0467]))
learn.show_results(max_n=4)