%reload_ext autoreload
%autoreload 2
#export
from nb_005b import *
(See the final section of this notebook for the one-time data-processing steps that create the PNG masks and the 128x128 copies.)
# Carvana data locations (relative to this notebook)
PATH = Path('../../data/carvana')
PATH_PNG = PATH/'train_masks_png'     # masks converted to PNG (written by `convert_img`)
PATH_X_FULL = PATH/'train'            # full-resolution training images
PATH_X_128 = PATH/'train-128'         # 128x128 image copies (written by `resize_img`)
PATH_Y_FULL = PATH_PNG                # full-resolution PNG masks
PATH_Y_128 = PATH/'train_masks-128'   # 128x128 mask copies (written by `resize_img`)
# start with the 128x128 images
PATH_X = PATH_X_128
PATH_Y = PATH_Y_128
# peek at one image; which file comes first depends on filesystem `iterdir` order
img_f = next(PATH_X.iterdir())
open_image(img_f).show()
#export
class ImageMask(Image):
    "Segmentation target: an `Image` whose pixel values are class labels rather than colors"
    @property
    def data(self)->TensorImage:
        "The pixels as a `long` tensor, so the mask can serve as an integer class-index target"
        return self.px.long()

    def refresh(self):
        "Switch sampling to 'nearest' before delegating to `Image.refresh`"
        # presumably `sample_kwargs` feeds the resampling step; nearest keeps labels un-blended
        kwargs = self.sample_kwargs
        kwargs['mode'] = 'nearest'
        return super().refresh()

    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'Image':
        "Lighting transforms are not applied to a mask: return it unchanged"
        return self

    def clone(self)->'ImageBase':
        "Return a new mask of the same class wrapping a copy of the pixel tensor"
        return type(self)(self.px.clone())
def open_mask(fn:PathOrStr) -> ImageMask:
    "Load the mask stored in file `fn` and wrap it, as a float tensor, in an `ImageMask`"
    pil_img = PIL.Image.open(fn)
    px = pil2tensor(pil_img).float()
    return ImageMask(px)
def get_y_fn(x_fn, path_y=None):
    """Return the mask path for image file `x_fn`: `<path_y>/<stem of x_fn>_mask.png`.

    `path_y` defaults to the global `PATH_Y`, looked up at call time so that reassigning
    `PATH_Y` still takes effect. `x_fn` may be a `Path` or a str.
    """
    if path_y is None: path_y = PATH_Y
    # `.stem` drops whatever extension is present; the old `name[:-4]` assumed a 3-char one
    return Path(path_y)/f'{Path(x_fn).stem}_mask.png'
# load and display the mask that belongs to `img_f`
img_y_f = get_y_fn(img_f)
y = open_mask(img_y_f)
y.show()
#export
# The plain single-image `show_image`, renamed with a `_` prefix so that `show_image`
# can be redefined below with mask-overlay support.
def _show_image(img:Image, ax:plt.Axes=None, figsize:tuple=(3,3), hide_axis:bool=True, cmap:str='binary',
                alpha:float=None) -> plt.Axes:
    "Draw `img` on `ax` (on a new `figsize` figure when `ax` is None) and return the axes used."
    axes = ax if ax is not None else plt.subplots(figsize=figsize)[1]
    axes.imshow(image2np(img), cmap=cmap, alpha=alpha)
    if hide_axis: axes.axis('off')
    return axes
def show_image(x:Image, y:Image=None, ax:plt.Axes=None, figsize:tuple=(3,3), alpha:float=0.5,
               hide_axis:bool=True, cmap:str='viridis') -> plt.Axes:
    """Display image `x`, optionally overlaying target `y` (e.g. a mask) with transparency `alpha`.

    A new figure of size `figsize` is created when `ax` is None. Returns the axes drawn on.
    """
    # forward `figsize`: it used to be dropped, so a new figure always had the default size
    ax = _show_image(x, ax=ax, figsize=figsize, hide_axis=hide_axis, cmap=cmap)
    # `_show_image` already hides the axis when `hide_axis` is set
    if y is not None: _show_image(y, ax=ax, alpha=alpha, hide_axis=hide_axis, cmap=cmap)
    return ax
def _show(self:Image, ax:plt.Axes=None, y:Image=None, **kwargs):
    "Replacement `Image.show`: display this image's data, overlaying target `y`'s data if given"
    y_px = y.data if y is not None else None
    return show_image(self.data, ax=ax, y=y_px, **kwargs)

Image.show = _show
# overlay the mask on its image using the patched `Image.show`
x = open_image(img_f)
x.show(y=y)
# image and mask shapes (shown as notebook cell output)
x.shape
y.shape
#export
class DatasetTfm(Dataset):
    "`Dataset` that applies a list of transforms to every item drawn"
    def __init__(self, ds:Dataset, tfms:TfmList=None, tfm_y:bool=False, **kwargs:Any):
        "this dataset will apply `tfms` to `ds` (and to the targets too when `tfm_y` is True)"
        self.ds,self.tfms,self.kwargs,self.tfm_y = ds,tfms,kwargs,tfm_y
        # targets use the same kwargs plus `do_resolve=False` -- presumably so the random
        # parameters already drawn for `x` are reused for `y` (TODO confirm in `apply_tfms`)
        self.y_kwargs = {**self.kwargs, 'do_resolve':False}

    def __len__(self)->int: return len(self.ds)

    def __getitem__(self,idx:int)->Tuple[Image,Any]:
        "Return `(tfms(x), y)`; `y` is also transformed when `tfm_y` is True."
        x,y = self.ds[idx]
        x = apply_tfms(self.tfms, x, **self.kwargs)
        if self.tfm_y: y = apply_tfms(self.tfms, y, **self.y_kwargs)
        return x, y

    def __getattr__(self,k):
        "passthrough access to wrapped dataset attributes"
        # `__getattr__` only runs after normal lookup fails. If `ds` itself is missing (an
        # instance created without `__init__`, as pickle/copy do), `self.ds` would re-enter
        # `__getattr__` endlessly -> RecursionError. Raise a proper AttributeError instead.
        if 'ds' not in self.__dict__: raise AttributeError(k)
        return getattr(self.ds, k)
# Replace `DatasetTfm` inside nb_002b with the version defined above -- presumably so nb_002b
# helpers that build transformed datasets (e.g. `transform_datasets`) pick up `tfm_y` support.
# TODO confirm which nb_002b functions reference `DatasetTfm`.
import nb_002b
nb_002b.DatasetTfm = DatasetTfm
#export
class SegmentationDataset(DatasetBase):
    "Dataset pairing image files `x` with segmentation-mask files `y`, matched by position"
    def __init__(self, x:Collection[PathOrStr], y:Collection[PathOrStr]):
        # each image needs exactly one mask
        assert len(x) == len(y)
        self.x = np.array(x)
        self.y = np.array(y)

    def __getitem__(self, i:int) -> Tuple[Image,ImageMask]:
        "Load the `i`-th image together with its mask."
        img_fn, mask_fn = self.x[i], self.y[i]
        return open_image(img_fn), open_mask(mask_fn)
def get_datasets(path, n_valid:int=1008):
    """Build two `SegmentationDataset`s (the caller unpacks them as train, valid) from files in `path`.

    Masks are located with `get_y_fn`. Files are sorted by name so the split is deterministic.
    `iterdir()` order is filesystem-dependent, so the old index-based split was arbitrary.
    Presumably file names start with a per-car id, so sorting also keeps all views of one car
    on the same side of the split (TODO confirm).
    """
    x_fns = sorted(o for o in path.iterdir() if o.is_file())
    y_fns = [get_y_fn(o) for o in x_fns]
    # files at sorted index < n_valid are flagged False, the rest True. Assumes `arrays_split`
    # returns (mask-True group, mask-False group) -> (train, valid); verify against its definition.
    mask = [i >= n_valid for i in range(len(x_fns))]
    arrs = arrays_split(mask, x_fns, y_fns)
    return [SegmentationDataset(*o) for o in arrs]
# raw (untransformed) 128px image/mask datasets
train_ds,valid_ds = get_datasets(PATH_X_128)
train_ds,valid_ds
# sanity-check one item: shapes and types of image and mask
x,y = next(iter(train_ds))
x.shape, y.shape, type(x), type(y)
size=128
def get_tfm_datasets(size):
    """Return the transformed datasets at image `size` (with `tfm_y=True`, so masks are transformed too).

    Reads the pre-resized 128px images when `size <= 128`, otherwise the full-resolution ones.
    """
    # BUGFIX: the datasets built here used to be discarded and the global 128px
    # `train_ds`/`valid_ds` transformed instead, so e.g. size=512 only upsampled 128px images.
    # NOTE(review): masks are still resolved via `get_y_fn`, which reads the global `PATH_Y`
    # (the 128px mask dir at setup), so at size>128 the masks may be low-res -- confirm.
    train_ds, valid_ds = get_datasets(PATH_X_128 if size<=128 else PATH_X_FULL)
    tfms = get_transforms(do_flip=True, max_rotate=4, max_lighting=0.2)
    return transform_datasets(train_ds, valid_ds, tfms=tfms, tfm_y=True, size=size, padding_mode='border')
transform_datasets  # display the helper (notebook cell output)
train_tds,*_ = get_tfm_datasets(size)
# show the first four transformed training items with their masks overlaid
_,axes = plt.subplots(1,4, figsize=(12,6))
for i, ax in enumerate(axes.flat):
    imgx,imgy = train_tds[i]
    imgx.show(ax, y=imgy)
# normalize / de-normalize functions built from `imagenet_stats`
default_norm,default_denorm = normalize_funcs(*imagenet_stats)
bs = 64
def get_data(size, bs):
    "Build a `DataBunch` from the transformed datasets at `size`, batch size `bs`, normalized with `default_norm`"
    train_and_valid = get_tfm_datasets(size)
    return DataBunch.create(*train_and_valid, bs=bs, tfms=default_norm)
# 128px DataBunch used for the first training stages
data = get_data(size, bs)
#export
def show_xy_images(x:Tensor,y:Tensor,rows:int,figsize:tuple=(9,9)):
    "Shows a `rows` x `rows` grid of images from batch `x` with targets `y` overlaid."
    # squeeze=False keeps `axs` a 2-D array even when rows == 1 (else a bare Axes with no `.flatten`)
    _, axs = plt.subplots(rows, rows, figsize=figsize, squeeze=False)
    for i, ax in enumerate(axs.flatten()):
        # a batch smaller than rows*rows used to raise IndexError; blank the spare cells instead
        if i < len(x): show_image(x[i], y=y[i], ax=ax)
        else: ax.axis('off')
    plt.tight_layout()
# grab one training batch, move it to the CPU and undo normalization for display
x,y = next(iter(data.train_dl))
x,y = x.cpu(),y.cpu()
x = default_denorm(x)
show_xy_images(x,y,4, figsize=(9,9))
x.shape, y.shape
#export
class Debugger(nn.Module):
    "Identity module that calls `set_trace()` every time it is run, for debugging inside a model"
    def forward(self, x:Tensor) -> Tensor:
        # pause here to inspect the activations reaching this point of the model
        set_trace()
        return x
class StdUpsample(nn.Module):
    "Upsample block: `conv2d_trans` (presumably a transposed conv) -> ReLU -> BatchNorm"
    def __init__(self, n_in:int, n_out:int):
        super().__init__()
        # registration order (conv, then bn) is kept so parameter/state_dict order is unchanged
        self.conv = conv2d_trans(n_in, n_out)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, x:Tensor) -> Tensor:
        out = self.conv(x)
        out = F.relu(out)
        return self.bn(out)
def std_upsample_head(c, *nfs:int) -> Model:
    """Create ReLU -> `StdUpsample` stages -> final `conv2d_trans` to `c` output channels.

    `nfs` lists channel counts: the input channels first, then the output of each upsample
    stage. One `StdUpsample` is created per consecutive pair, so any length >= 1 works.
    """
    return nn.Sequential(
        nn.ReLU(),
        # was a hard-coded `range(4)`: for len(nfs) != 5 the stage count and the final
        # conv's input channels (`nfs[-1]`) disagreed, or indexing went out of range
        *(StdUpsample(n_in, n_out) for n_in, n_out in zip(nfs[:-1], nfs[1:])),
        conv2d_trans(nfs[-1], c)
    )
# 2-channel output head: 512 input channels (presumably the resnet34 body's output -- confirm),
# then four upsample stages of 256 channels each
head = std_upsample_head(2, 512,256,256,256,256)
head
#export
def dice(input:Tensor, targs:Tensor) -> Rank0Tensor:
    """Dice coefficient metric for binary target, pooled over the whole batch.

    `input` holds per-class scores along dim 1 (predictions are its argmax); `targs` holds 0/1
    labels. Returns 2*|pred & targ| / (|pred| + |targ|), or 1 when both are empty.
    """
    n = targs.shape[0]
    input = input.argmax(dim=1).view(n,-1)
    targs = targs.view(n,-1)
    intersect = (input*targs).sum().float()
    union = (input+targs).sum().float()
    # empty prediction and empty target agree perfectly; avoid returning 0/0 = NaN
    if union == 0: return union.new_tensor(1.)
    return 2. * intersect / union
def accuracy(input:Tensor, targs:Tensor) -> Rank0Tensor:
    "Fraction of pixels whose argmax class (over dim 1 of `input`) equals the target label"
    bs = targs.shape[0]
    preds = input.argmax(dim=1).view(bs, -1)
    correct = preds == targs.view(bs, -1)
    return correct.float().mean()
class CrossEntropyFlat(nn.CrossEntropyLoss):
    "`nn.CrossEntropyLoss` over flattened spatial dims: input -> (n, c, -1), target -> (n, -1)"
    def forward(self, input:Tensor, target:Tensor) -> Rank0Tensor:
        bs, n_classes, *_ = input.shape
        flat_input = input.view(bs, n_classes, -1)
        flat_target = target.view(bs, -1)
        return super().forward(flat_input, flat_target)
metrics=[accuracy, dice]
# resnet34 body + the custom upsampling `head`; `2` matches the head's 2 output channels;
# pixel-wise cross-entropy loss
learn = ConvLearner(data, tvm.resnet34, 2, custom_head=head,
                    metrics=metrics, loss_fn=CrossEntropyFlat())
# learning-rate finder; inspect the plot to choose `lr`
lr_find(learn)
learn.recorder.plot()
# stage 1 at 128px (presumably the body is frozen until `unfreeze` below -- confirm ConvLearner default)
lr = 1e-1
learn.fit_one_cycle(10, slice(lr))
learn.unfreeze()
learn.save('0')
learn.load('0')
# stage 2: whole network, presumably with learning rates spread from lr/100 to lr across layer groups
lr = 2e-2
learn.fit_one_cycle(10, slice(lr/100,lr))
# predict one batch and turn per-class outputs into a label map (argmax over dim 1) for display
x,y,py = learn.pred_batch()
py = py.argmax(dim=1).unsqueeze(1)
for i, ax in enumerate(plt.subplots(4,4,figsize=(10,10))[1].flat):
    show_image(default_denorm(x[i].cpu()), py[i], ax=ax)
learn.save('1')
# progressive resizing: continue from checkpoint '1' at 512px with a smaller batch
size=512
bs = 8
data = get_data(size, bs)
learn.data = data
learn.load('1')
# stage 3: frozen body (presumably only the head trains)
learn.freeze()
lr = 2e-2
learn.fit_one_cycle(5, slice(lr))
learn.save('2')
learn.load('2')
# stage 4: whole network
lr = 2e-2
learn.unfreeze()
learn.fit_one_cycle(8, slice(lr/100,lr))
learn.save('3')
# NOTE(review): unpacks 2 values here but 3 (`x,y,py`) from the same `pred_batch()` earlier --
# one of the two is wrong; check its return signature.
x,py = learn.pred_batch()
# NOTE(review): `py[i]>0` thresholds the raw outputs (presumably one channel per class) instead of
# taking the argmax over dim 1 as done earlier -- likely should match; confirm.
for i, ax in enumerate(plt.subplots(4,4,figsize=(10,10))[1].flat):
    show_image(default_denorm(x[i].cpu()), py[i]>0, ax=ax)
def convert_img(fn):
    "Save image file `fn` as `<stem>.png` in `PATH_PNG`."
    # `Image` in this notebook is the fastai class (subclassed by `ImageMask`, patched with
    # `Image.show`), not PIL's module, so `Image.open` was wrong; use PIL as `open_mask` does.
    # The `with` block closes the file handle once the PNG is written.
    with PIL.Image.open(fn) as img: img.save(PATH_PNG/f'{fn.stem}.png')
def resize_img(fn, dirname, size=(128,128)):
    "Resize image file `fn` to `size` and save it, same file name, in `fn.parent.parent/dirname`."
    # PIL explicitly: in this notebook the bare name `Image` is the fastai class, not PIL
    with PIL.Image.open(fn) as img:
        img.resize(size).save(fn.parent.parent/dirname/fn.name)
def do_conversion():
    """One-time preprocessing of the Carvana training data.

    1. Convert every file in `PATH/'train_masks'` to PNG in `PATH_PNG`.
    2. Write 128x128 copies of those PNG masks to `PATH_Y_128`.
    3. Write 128x128 copies of the training images (`PATH_X_FULL`) to `PATH_X_128`.
    """
    def _run(func, files):
        "Apply `func` to each file on 8 threads, re-raising any worker exception here."
        # `Executor.map` only re-raises a worker's exception when that result is retrieved;
        # the old code never consumed the iterator, so every failure was silently discarded.
        with ThreadPoolExecutor(8) as e: list(e.map(func, files))

    PATH_PNG.mkdir(exist_ok=True)
    # create the dirs `resize_img` actually writes to (PATH/<dirname>); the old code used the
    # switchable `PATH_X`/`PATH_Y`, which only coincide with them at the 128px setting
    PATH_X_128.mkdir(exist_ok=True)
    PATH_Y_128.mkdir(exist_ok=True)
    _run(convert_img, list((PATH/'train_masks').iterdir()))
    # `_run` blocks until all conversions finish, so the PNG masks exist before resizing
    _run(partial(resize_img, dirname=PATH_Y_128.name), list(PATH_PNG.iterdir()))
    _run(partial(resize_img, dirname=PATH_X_128.name), list(PATH_X_FULL.iterdir()))