from fastai import *
from fastai.vision import *
# Download (if needed) and extract Imagenette-160: a 10-class ImageNet subset, 160px shortest side.
DATA = untar_data(URLs.IMAGENETTE_160)
# Take a random 30% sample (fixed seed for reproducibility), split train/valid by
# folder name, label from folder names, and use only horizontal flips at size 160.
src = (ImageList.from_folder(DATA).filter_by_rand(0.3, seed=42)
.split_by_folder(valid='val')
.label_from_folder()
.transform(([flip_lr(p=0.5)], []), size=160))
# Batch up and normalize with ImageNet channel statistics.
data = (src.databunch(bs=64, num_workers=6)
.normalize(imagenet_stats))
data
ImageDataBunch; Train: LabelList (3798 items) x: ImageList Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160) y: CategoryList n03028079,n03028079,n03028079,n03028079,n03028079 Path: /home/user/.fastai/data/imagenette-160; Valid: LabelList (161 items) x: ImageList Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160),Image (3, 160, 160) y: CategoryList n03028079,n03028079,n03028079,n03028079,n03028079 Path: /home/user/.fastai/data/imagenette-160; Test: None
from fastai import layers
def conv_layer(ni:int, nf:int, ks:int=3, stride:int=1, padding:Optional[int]=None, bias:Optional[bool]=None,
               is_1d:bool=False, norm_type:Optional[NormType]=NormType.Batch, use_activ:bool=True,
               activ_fn:Optional[Callable]=None, leaky:Optional[float]=None,
               transpose:bool=False, init:Callable=nn.init.kaiming_normal_, self_attention:bool=False):
    "Create a sequence of convolutional (`ni` to `nf`), activation (if `use_activ`) and norm (per `norm_type`) layers."
    # `activ_fn` must be a zero-argument callable returning an nn.Module; the default
    # is fastai's (optionally leaky) ReLU. This is the hook used below to swap in Mish.
    activ_fn = ifnone(activ_fn, partial(relu, inplace=True, leaky=leaky))
    if padding is None: padding = (ks-1)//2 if not transpose else 0
    bn = norm_type in (NormType.Batch, NormType.BatchZero)
    # A conv bias is redundant when a batchnorm (with its own shift) follows.
    if bias is None: bias = not bn
    conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
    conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), init)
    if norm_type==NormType.Weight:     conv = weight_norm(conv)
    elif norm_type==NormType.Spectral: conv = spectral_norm(conv)
    # Renamed from `layers` to avoid shadowing the `fastai.layers` module imported above.
    seq = [conv]
    if use_activ: seq.append(activ_fn())
    if bn: seq.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
    if self_attention: seq.append(SelfAttention(nf))
    return nn.Sequential(*seq)
def simple_cnn(data, actns:Collection[int], kernel_szs:Optional[Collection[int]]=None,
               strides:Optional[Collection[int]]=None, bn=False, activ_fn:Optional[Callable]=None,
               lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,
               concat_pool:bool=True, bn_final:bool=False) -> nn.Sequential:
    "CNN of `conv_layer`s defined by `actns`, `kernel_szs` and `strides` (batchnorm if `bn`), topped by `create_head`."
    nl = len(actns)-1  # number of conv layers
    kernel_szs = ifnone(kernel_szs, [3]*nl)
    strides    = ifnone(strides   , [2]*nl)
    assert len(kernel_szs)==nl and len(strides)==nl, f"need {nl} kernel sizes and strides to match `actns`"
    # No batchnorm after the last conv layer: the head begins with its own norm.
    layers = [conv_layer(actns[i], actns[i+1], kernel_szs[i], stride=strides[i],
              norm_type=(NormType.Batch if bn and i<(len(strides)-1) else None), activ_fn=activ_fn)
              for i in range_of(strides)]
    nf_head = actns[-1] * (2 if concat_pool else 1)  # concat pooling doubles the feature count
    head = create_head(nf_head, data.c, lin_ftrs=lin_ftrs, ps=ps, concat_pool=concat_pool, bn_final=bn_final)
    return nn.Sequential(*layers, head)
# Channel progression: stride-1 conv into each width, then stride-2 to downsample.
actns = [3,64,64,128,128,256,256,512,512]
strides = [1,2]*(len(actns)//2)  # alternating 1,2 -> 8 entries, one per conv layer
mdl_relu = simple_cnn(data, actns=actns, strides=strides)
mdl_relu
Sequential( (0): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) ) (1): Sequential( (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): ReLU(inplace=True) ) (2): Sequential( (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) ) (3): Sequential( (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): ReLU(inplace=True) ) (4): Sequential( (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) ) (5): Sequential( (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): ReLU(inplace=True) ) (6): Sequential( (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) ) (7): Sequential( (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): ReLU(inplace=True) ) (8): Sequential( (0): AdaptiveConcatPool2d( (ap): AdaptiveAvgPool2d(output_size=1) (mp): AdaptiveMaxPool2d(output_size=1) ) (1): Flatten() (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): Dropout(p=0.25, inplace=False) (4): Linear(in_features=1024, out_features=512, bias=True) (5): ReLU(inplace=True) (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (7): Dropout(p=0.5, inplace=False) (8): Linear(in_features=512, out_features=10, bias=True) ) )
# Baseline: 5 epochs of one-cycle training with the ReLU model.
lrn = Learner(data, mdl_relu, metrics=[accuracy,top_k_accuracy])
lrn.fit_one_cycle(5, 1e-3)
epoch | train_loss | valid_loss | accuracy | top_k_accuracy | time |
---|---|---|---|---|---|
0 | 2.142710 | 2.980025 | 0.223602 | 0.739130 | 00:13 |
1 | 1.928795 | 1.946811 | 0.347826 | 0.813665 | 00:13 |
2 | 1.725688 | 1.865101 | 0.391304 | 0.826087 | 00:13 |
3 | 1.494881 | 1.939941 | 0.416149 | 0.795031 | 00:13 |
4 | 1.306941 | 1.135048 | 0.633540 | 0.950311 | 00:13 |
class Mish(nn.Module):
    """Mish activation (Misra, 2019): ``x * tanh(softplus(x))``."""
    def forward(self, x):
        softplus_x = F.softplus(x)
        return x.mul(torch.tanh(softplus_x))
# Same architecture, with Mish replacing ReLU in every conv block.
mdl_mish = simple_cnn(data, actns=actns, strides=strides, activ_fn=Mish)
mdl_mish
Sequential( (0): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): Mish() ) (1): Sequential( (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): Mish() ) (2): Sequential( (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): Mish() ) (3): Sequential( (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): Mish() ) (4): Sequential( (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): Mish() ) (5): Sequential( (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): Mish() ) (6): Sequential( (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): Mish() ) (7): Sequential( (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): Mish() ) (8): Sequential( (0): AdaptiveConcatPool2d( (ap): AdaptiveAvgPool2d(output_size=1) (mp): AdaptiveMaxPool2d(output_size=1) ) (1): Flatten() (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): Dropout(p=0.25, inplace=False) (4): Linear(in_features=1024, out_features=512, bias=True) (5): ReLU(inplace=True) (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (7): Dropout(p=0.5, inplace=False) (8): Linear(in_features=512, out_features=10, bias=True) ) )
# Train the Mish model with an identical schedule for a fair comparison.
lrn = Learner(data, mdl_mish, metrics=[accuracy,top_k_accuracy])
lrn.fit_one_cycle(5, 1e-3)
epoch | train_loss | valid_loss | accuracy | top_k_accuracy | time |
---|---|---|---|---|---|
0 | 2.076530 | 2.108528 | 0.304348 | 0.763975 | 00:16 |
1 | 1.843536 | 2.559659 | 0.267081 | 0.763975 | 00:15 |
2 | 1.580714 | 1.727570 | 0.360248 | 0.875776 | 00:15 |
3 | 1.308276 | 0.981897 | 0.658385 | 0.968944 | 00:15 |
4 | 1.075476 | 0.898451 | 0.708075 | 0.962733 | 00:15 |
Not sure this is the right way to create a JIT module
# NOTE(review): first scripting attempt — superseded by the torch.jit.script(Mish())
# factory below. Decorating `forward` with @torch.jit.script turns it into a free
# function (hence no `self`); this is not the recommended pattern.
class MishJit(torch.jit.ScriptModule):
    # Note: No self for forward or you get an error
    @torch.jit.script
    def forward(x):
        return x * torch.tanh(F.softplus(x))
This seems to be the recommended way:
MishJit = lambda: torch.jit.script(Mish())
# Same architecture again, now using the scripted Mish.
mdl_mishjit = simple_cnn(data, actns=actns, strides=strides, activ_fn=MishJit)
mdl_mishjit
Sequential( (0): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): MishJit() ) (1): Sequential( (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): MishJit() ) (2): Sequential( (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): MishJit() ) (3): Sequential( (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): MishJit() ) (4): Sequential( (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): MishJit() ) (5): Sequential( (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): MishJit() ) (6): Sequential( (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): MishJit() ) (7): Sequential( (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): MishJit() ) (8): Sequential( (0): AdaptiveConcatPool2d( (ap): AdaptiveAvgPool2d(output_size=1) (mp): AdaptiveMaxPool2d(output_size=1) ) (1): Flatten() (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): Dropout(p=0.25, inplace=False) (4): Linear(in_features=1024, out_features=512, bias=True) (5): ReLU(inplace=True) (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (7): Dropout(p=0.5, inplace=False) (8): Linear(in_features=512, out_features=10, bias=True) ) )
# Train the JIT-scripted model; epoch timings compare against the eager Mish run.
lrn = Learner(data, mdl_mishjit, metrics=[accuracy,top_k_accuracy])
lrn.fit_one_cycle(5, 1e-3)
epoch | train_loss | valid_loss | accuracy | top_k_accuracy | time |
---|---|---|---|---|---|
0 | 2.008492 | 2.380090 | 0.304348 | 0.726708 | 00:16 |
1 | 1.768086 | 1.570313 | 0.472050 | 0.888199 | 00:15 |
2 | 1.452909 | 1.287506 | 0.534162 | 0.944099 | 00:15 |
3 | 1.180452 | 0.982612 | 0.664596 | 0.956522 | 00:15 |
4 | 0.939465 | 0.796275 | 0.732919 | 0.968944 | 00:16 |
Doesn't look like any performance gain from JIT.
But you can see some stuff that's happening under the hood (code taken from https://github.com/pytorch/pytorch/blob/master/test/test_jit.py#L233):
# Inspect the TorchScript IR generated for the scripted Mish forward.
mj = MishJit()
mj.graph
graph(%x.1 : Tensor): %6 : int = prim::Constant[value=20]() %5 : int = prim::Constant[value=1]() %7 : Tensor = aten::softplus(%x.1, %5, %6) # <ipython-input-14-d5e1efb10fcc>:5:31 %8 : Tensor = aten::tanh(%7) # <ipython-input-14-d5e1efb10fcc>:5:20 %9 : Tensor = aten::mul(%x.1, %8) # <ipython-input-14-d5e1efb10fcc>:5:16 return (%9)
# Dig out the autograd (backward) graph the JIT built for the forward execution plan.
ds = mj.forward.get_debug_state()
fwd_plan = list(ds.execution_plans.values())[0]
ges = list(fwd_plan.code.grad_executor_states())
assert len(ges)==1  # expect exactly one grad executor for this simple graph
bwd_plan = ges[0]
bwd_plan.graph
graph(%0 : Tensor, %1 : Tensor, %2 : Tensor, %3 : Tensor, %4 : int[]?, %5 : int[]?): %6 : int = prim::Constant[value=1]() # <string>:154:39 %grad_self.1 : Tensor, %grad_other.1 : Tensor = prim::GradOf[name="aten::mul"](%0) block0(): %9 : Tensor = aten::mul(%0, %3) # <string>:11:30 %grad_self.2 : Tensor = aten::_grad_sum_to_size(%9, %4) # <string>:11:30 %11 : Tensor = aten::mul(%0, %2) # <string>:12:31 %grad_other.2 : Tensor = aten::_grad_sum_to_size(%11, %5) # <string>:12:31 -> (%grad_self.2, %grad_other.2) %13 : Tensor = prim::AutogradAdd(%1, %grad_other.1) %14 : Tensor = prim::GradOf[name="aten::tanh"](%13) block0(): %15 : Tensor = aten::mul(%3, %3) # <string>:154:43 %16 : Tensor = aten::neg(%15) # <string>:18:10 %17 : Tensor = aten::add(%16, %6, %6) # <string>:18:10 %18 : Tensor = aten::mul(%13, %17) # <string>:154:24 -> (%18) return (%grad_self.1, %14)
# Profiler doesn't like multiple workers
data_prof = (src.databunch(bs=64, num_workers=0)
.normalize(imagenet_stats))
# Profile 3 one-cycle epochs of the (already partially trained) Mish model, including CUDA kernels.
lrn = Learner(data_prof, mdl_mish, metrics=[accuracy,top_k_accuracy])
with torch.autograd.profiler.profile(use_cuda=True) as prof_mish:
    lrn.fit_one_cycle(3, 1e-3)
epoch | train_loss | valid_loss | accuracy | top_k_accuracy | time |
---|---|---|---|---|---|
0 | 2.090991 | 2.282425 | 0.242236 | 0.776398 | 00:19 |
1 | 1.843356 | 1.696035 | 0.397516 | 0.844720 | 00:18 |
2 | 1.575372 | 1.370669 | 0.515528 | 0.913043 | 00:19 |
print(prof_mish.key_averages().table(sort_by="cuda_time_total", row_limit=20))
----------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg CUDA total % CUDA total CUDA time avg Number of Calls ----------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- CudnnConvolutionBackward 0.07% 32.445ms 1.33% 593.029ms 418.806us 17.95% 18.293s 12.919ms 1416 cudnn_convolution_backward 1.26% 560.585ms 1.26% 560.585ms 395.893us 17.94% 18.289s 12.916ms 1416 conv2d 0.02% 8.777ms 0.52% 230.733ms 155.063us 10.68% 10.883s 7.314ms 1488 convolution 0.02% 9.372ms 0.50% 221.956ms 149.164us 10.67% 10.878s 7.311ms 1488 _convolution 0.04% 18.914ms 0.48% 212.584ms 142.865us 10.67% 10.874s 7.308ms 1488 cudnn_convolution 0.42% 188.122ms 0.42% 188.122ms 126.426us 10.66% 10.865s 7.302ms 1488 mul 0.26% 117.852ms 0.26% 117.852ms 25.123us 3.96% 4.035s 860.178us 4691 MulBackward0 0.09% 40.653ms 0.29% 130.582ms 92.219us 2.60% 2.653s 1.874ms 1416 add 0.12% 54.580ms 0.12% 54.580ms 30.732us 1.32% 1.343s 756.336us 1776 TanhBackward 0.06% 24.925ms 0.17% 77.403ms 54.663us 1.30% 1.329s 938.273us 1416 SoftplusBackward 0.06% 26.375ms 0.16% 72.244ms 51.020us 1.30% 1.328s 937.749us 1416 tanh_backward 0.12% 52.478ms 0.12% 52.478ms 37.060us 1.30% 1.324s 935.328us 1416 softplus_backward 0.10% 45.869ms 0.10% 45.869ms 32.393us 1.30% 1.324s 934.842us 1416 div_ 2.90% 1.296s 2.90% 1.296s 109.706us 1.26% 1.282s 108.573us 11811 softplus 0.06% 26.970ms 0.06% 26.970ms 18.125us 1.03% 1.046s 702.681us 1488 stack 2.32% 1.035s 2.32% 1.035s 5.474ms 1.01% 1.030s 5.449ms 189 tanh 0.05% 24.128ms 0.05% 24.128ms 16.215us 0.95% 967.134ms 649.956us 1488 clone 2.05% 915.029ms 2.05% 915.029ms 76.329us 0.90% 918.066ms 76.582us 11988 contiguous 1.48% 660.752ms 1.83% 814.909ms 53.521us 0.80% 818.778ms 53.775us 
15226 pin_memory 1.07% 478.388ms 1.09% 485.908ms 1.306ms 0.48% 487.357ms 1.310ms 372 ----------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- Self CPU time total: 44.610s CUDA time total: 101.941s
# Repeat the profiling run with the ReLU model for comparison.
lrn = Learner(data_prof, mdl_relu, metrics=[accuracy,top_k_accuracy])
with torch.autograd.profiler.profile(use_cuda=True) as prof_relu:
    lrn.fit_one_cycle(3, 1e-3)
epoch | train_loss | valid_loss | accuracy | top_k_accuracy | time |
---|---|---|---|---|---|
0 | 2.144278 | 3.227488 | 0.198758 | 0.503106 | 00:16 |
1 | 1.964869 | 1.908126 | 0.310559 | 0.795031 | 00:16 |
2 | 1.797392 | 1.615858 | 0.447205 | 0.857143 | 00:16 |
print(prof_relu.key_averages().table(sort_by="cuda_time_total", row_limit=20))
----------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg CUDA total % CUDA total CUDA time avg Number of Calls ----------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- CudnnConvolutionBackward 0.08% 30.636ms 1.59% 590.901ms 417.303us 20.24% 18.338s 12.951ms 1416 cudnn_convolution_backward 1.51% 560.265ms 1.51% 560.265ms 395.668us 20.24% 18.334s 12.948ms 1416 conv2d 0.02% 8.686ms 0.64% 237.165ms 159.385us 12.05% 10.915s 7.335ms 1488 convolution 0.03% 9.691ms 0.62% 228.479ms 153.548us 12.04% 10.910s 7.332ms 1488 _convolution 0.06% 21.295ms 0.59% 218.788ms 147.035us 12.04% 10.906s 7.329ms 1488 cudnn_convolution 0.52% 193.073ms 0.52% 193.073ms 129.753us 12.03% 10.897s 7.323ms 1488 ReluBackward1 0.08% 30.172ms 0.23% 86.354ms 54.208us 1.47% 1.335s 838.307us 1593 threshold_backward 0.15% 56.182ms 0.15% 56.182ms 35.268us 1.47% 1.328s 833.722us 1593 div_ 3.61% 1.340s 3.61% 1.340s 113.430us 1.46% 1.327s 112.313us 11811 stack 2.82% 1.045s 2.82% 1.045s 5.529ms 1.16% 1.046s 5.537ms 189 clone 2.56% 949.062ms 2.56% 949.062ms 79.168us 1.05% 954.423ms 79.615us 11988 relu_ 0.06% 24.079ms 0.06% 24.079ms 14.384us 1.05% 952.122ms 568.771us 1674 contiguous 1.78% 661.230ms 2.26% 838.856ms 55.090us 0.93% 843.215ms 55.376us 15227 pin_memory 1.30% 481.157ms 1.31% 487.426ms 1.310ms 0.54% 488.730ms 1.314ms 372 to 81.89% 30.382s 81.93% 30.400s 2.374ms 0.41% 371.878ms 29.037us 12807 empty_like 0.22% 80.294ms 0.48% 177.626ms 15.665us 0.19% 176.648ms 15.579us 11339 add_ 0.55% 204.223ms 0.55% 204.223ms 15.606us 0.16% 142.126ms 10.861us 13086 slice 0.34% 124.978ms 0.34% 124.978ms 5.426us 0.13% 119.899ms 5.206us 23032 empty 0.32% 117.106ms 0.32% 117.106ms 9.110us 0.13% 115.692ms 9.000us 
12854 torch::autograd::AccumulateGrad 0.21% 79.715ms 0.53% 197.215ms 46.425us 0.12% 108.346ms 25.505us 4248 ----------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- Self CPU time total: 37.103s CUDA time total: 90.599s
# Map op name -> averaged event for each run. (The earlier `keys = set(...)` and
# tuple-of-key_averages assignments were dead code, immediately overwritten below;
# removed.)
ka_mish,ka_relu = [{ev.key: ev for ev in prof.key_averages()} for prof in [prof_mish,prof_relu]]
keys = set(list(ka_mish.keys()) + list(ka_relu.keys()))
keys -= {'to','contiguous','pin_memory'} # Dataloader stuff
# Keep only ops that differ between runs: present in one profile only, different
# call counts, or average CUDA time differing by more than 100us.
ev_mish,ev_relu = [],[]
for key in keys:
    if ( key not in ka_mish or key not in ka_relu or
         ka_mish[key].count != ka_relu[key].count or
         np.abs(ka_mish[key].cuda_time - ka_relu[key].cuda_time) > 100): # cuda_time in us
        if key in ka_mish: ev_mish.append(ka_mish[key])
        if key in ka_relu: ev_relu.append(ka_relu[key])
print(torch.autograd.profiler.EventList(ev_mish).table(sort_by="cuda_time_total"))
---------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg CUDA total % CUDA total CUDA time avg Number of Calls ---------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- mul 12.55% 123.558ms 12.55% 123.558ms 26.339us 25.25% 4.034s 860.011us 4691 MulBackward0 4.14% 40.766ms 13.87% 136.489ms 96.391us 16.60% 2.653s 1.874ms 1416 add 5.54% 54.531ms 5.54% 54.531ms 30.704us 8.41% 1.343s 756.388us 1776 TanhBackward 2.60% 25.567ms 8.09% 79.661ms 56.258us 8.33% 1.331s 940.024us 1416 SoftplusBackward 2.82% 27.774ms 7.65% 75.281ms 53.164us 8.32% 1.329s 938.424us 1416 tanh_backward 5.50% 54.094ms 5.50% 54.094ms 38.202us 8.30% 1.327s 937.036us 1416 softplus_backward 4.83% 47.506ms 4.83% 47.506ms 33.550us 8.29% 1.324s 935.374us 1416 softplus 2.80% 27.553ms 2.80% 27.553ms 18.517us 6.56% 1.048s 704.625us 1488 tanh 2.23% 21.921ms 2.23% 21.921ms 14.732us 6.06% 967.746ms 650.367us 1488 empty_like 9.27% 91.280ms 21.58% 212.374ms 18.731us 1.33% 211.789ms 18.680us 11338 add_ 20.93% 205.969ms 20.93% 205.969ms 15.711us 0.92% 147.439ms 11.246us 13110 empty 13.35% 131.339ms 13.35% 131.339ms 10.219us 0.83% 132.016ms 10.271us 12853 slice 12.08% 118.912ms 12.08% 118.912ms 5.163us 0.71% 112.897ms 4.902us 23030 ReluBackward1 0.35% 3.435ms 1.01% 9.969ms 56.324us 0.06% 10.011ms 56.562us 177 threshold_backward 0.66% 6.534ms 0.66% 6.534ms 36.915us 0.04% 6.778ms 38.293us 177 relu_ 0.35% 3.427ms 0.35% 3.427ms 18.426us 0.00% 558.594us 3.003us 186 ---------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- Self CPU time total: 984.168ms CUDA time total: 15.979s
print(torch.autograd.profiler.EventList(ev_relu).table(sort_by="cuda_time_total"))
---------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg CUDA total % CUDA total CUDA time avg Number of Calls ---------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- ReluBackward1 5.15% 33.543ms 13.58% 88.343ms 55.457us 31.57% 1.335s 838.335us 1593 threshold_backward 8.42% 54.799ms 8.42% 54.799ms 34.400us 31.40% 1.328s 833.837us 1593 relu_ 3.73% 24.283ms 3.73% 24.283ms 14.506us 22.52% 952.441ms 568.961us 1674 empty_like 9.78% 63.670ms 29.12% 189.489ms 16.714us 4.46% 188.768ms 16.651us 11337 add_ 30.49% 198.388ms 30.49% 198.388ms 15.160us 3.34% 141.360ms 10.802us 13086 empty 20.89% 135.937ms 20.89% 135.937ms 10.577us 3.23% 136.804ms 10.645us 12852 slice 19.79% 128.792ms 19.79% 128.792ms 5.593us 2.92% 123.493ms 5.363us 23028 add 1.15% 7.458ms 1.15% 7.458ms 20.717us 0.45% 19.196ms 53.321us 360 mul 0.59% 3.861ms 0.59% 3.861ms 10.406us 0.09% 3.815ms 10.284us 371 ---------------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- --------------- Self CPU time total: 650.733ms CUDA time total: 4.230s
The multiple kernel launches for Mish add up (Softplus, Tanh and some of the muls — judging by the call counts, not all of the muls come from Mish).
You can't do anything inplace as this will cause errors in gradient calculation:
class MishInplace(nn.Module):
    """In-place Mish — deliberately broken: mutating a leaf tensor that requires grad raises at call time."""
    def forward(self, x):
        softplus_x = F.softplus(x)
        return x.mul_(softplus_x.tanh_())
# Demonstrate the failure: in-place ops on a leaf tensor that requires grad are illegal.
inp = torch.rand(5, requires_grad=True)
mdl = MishInplace()
out = torch.sum(mdl(inp))  # raises RuntimeError (traceback below)
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-40-1667fcc83abe> in <module> 1 inp = torch.rand(5, requires_grad=True) 2 mdl = MishInplace() ----> 3 out = torch.sum(mdl(inp)) ~/.conda/envs/fastai/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs) 545 result = self._slow_forward(*input, **kwargs) 546 else: --> 547 result = self.forward(*input, **kwargs) 548 for hook in self._forward_hooks.values(): 549 hook_result = hook(self, input, result) <ipython-input-39-5adf038e9e50> in forward(self, x) 1 class MishInplace(nn.Module): 2 def forward(self, x): ----> 3 return x.mul_(torch.tanh_(F.softplus(x))) RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
Autograd functions allow such in-place operations by saving tensors in `forward` for use in `backward`. But then you have to do the gradient calculations yourself. Being able to manually compute gradients would also be needed for a CUDA implementation, since you can't really re-use the existing kernels for Softplus/Tanh (you'd need separate kernel launches and would be back to where the straight Python is).
This seems to be the idea:
class MishFunc(torch.autograd.Function):
    """Custom autograd Function computing Mish (`x * tanh(softplus(x))`) with a hand-written backward.

    Fixes over the original sketch: `forward` referenced an undefined `x` (NameError,
    the parameter is `inp`) and, worse, mutated `inp` in place after saving it, so
    `backward` would have seen the output rather than the input; `backward` itself
    was a `...` placeholder. This version is out-of-place, implements the analytic
    gradient, and passes `torch.autograd.gradcheck`.
    """
    @staticmethod
    def forward(ctx, inp):
        ctx.save_for_backward(inp)
        return inp * torch.tanh(F.softplus(inp))

    @staticmethod
    def backward(ctx, grad_out):
        inp, = ctx.saved_tensors
        tsp = torch.tanh(F.softplus(inp))
        # d/dx [x * tanh(sp(x))] = tanh(sp(x)) + x * sigmoid(x) * (1 - tanh(sp(x))^2),
        # since d sp/dx = sigmoid(x).
        grad_inp = tsp + inp * torch.sigmoid(inp) * (1 - tsp * tsp)
        return grad_out * grad_inp
Playing around with some of the autograd stuff:
# Helper: gradient of `o` w.r.t. `i`, retaining the graph so it can be queried repeatedly.
grad = lambda o,i: torch.autograd.grad(o, i, retain_graph=True) # Need to use retain graph or you can only call once
x = torch.rand(3)
# `inp` is a leaf copy of x; `tsp` is the tanh(softplus) intermediate so its grad can be queried too.
inp = x.clone().requires_grad_(True)
tsp = torch.tanh(F.softplus(inp))
out = x.mul(tsp)
l = torch.sum(out)
grad(l, inp) # Gradient of loss w.r.t input
(tensor([0.0430, 0.0735, 0.1787]),)
grad(l, tsp) # Gradient of loss w.r.t intermediate
(tensor([0.1375, 0.2415, 0.8298]),)
# Last call, so no retain_graph; grad of a sum w.r.t. its input is all ones.
grad_out = torch.autograd.grad(l, out)
grad_out
(tensor([1., 1., 1.]),)
# Results reported in the Mish paper: activation name, mean test accuracy, standard
# deviation, and 95% CI of the difference vs Mish ("-" for Mish itself).
txt = """Mish,87.48%,0.3967,-
Swish-1,87.32%,0.414,-0.3975 to 0.0844
E-Swish (?=1.75),87.49%,0.411,-0.2261 to 0.2539
GELU,87.37%,0.472,-0.3682 to 0.1499
ReLU,86.66%,0.584,-1.1179 to -0.5247
ELU(?=1.0),86.41%,0.3371,-1.2931 to -0.8556
Leaky ReLU(?=0.3),86.85%,0.4569,-0.8860 to -0.3774
RReLU,86.87%,0.4478,-0.8623 to -0.3595
SELU,83.91%,0.5995,-3.8713 to -3.2670
SoftPlus(? = 1),83.00%,1.4015,-4.7778 to -4.1735
HardShrink(? = 0.5),75.03%,0.98345,-12.8948 to -12.0035
Hardtanh,82.78%,0.4491,-4.9522 to -4.4486"""
# Loop variable renamed from the ambiguous `l` (E741), which also clobbered the loss tensor above.
sp = [line.split(',') for line in txt.split('\n')]
d = {n:v for n,v in zip(['Name','Acc','SD','CI'], zip(*sp))}
df = pd.DataFrame.from_dict(d)
df
Name | Acc | SD | CI | |
---|---|---|---|---|
0 | Mish | 87.48% | 0.3967 | - |
1 | Swish-1 | 87.32% | 0.414 | -0.3975 to 0.0844 |
2 | E-Swish (?=1.75) | 87.49% | 0.411 | -0.2261 to 0.2539 |
3 | GELU | 87.37% | 0.472 | -0.3682 to 0.1499 |
4 | ReLU | 86.66% | 0.584 | -1.1179 to -0.5247 |
5 | ELU(?=1.0) | 86.41% | 0.3371 | -1.2931 to -0.8556 |
6 | Leaky ReLU(?=0.3) | 86.85% | 0.4569 | -0.8860 to -0.3774 |
7 | RReLU | 86.87% | 0.4478 | -0.8623 to -0.3595 |
8 | SELU | 83.91% | 0.5995 | -3.8713 to -3.2670 |
9 | SoftPlus(? = 1) | 83.00% | 1.4015 | -4.7778 to -4.1735 |
10 | HardShrink(? = 0.5) | 75.03% | 0.98345 | -12.8948 to -12.0035 |
11 | Hardtanh | 82.78% | 0.4491 | -4.9522 to -4.4486 |
# Clean up: strip parenthesized hyperparams from names and parse accuracy to float.
df.Name = df.Name.apply(lambda s: s.split('(')[0].strip())
df.Acc = df.Acc.str.slice(stop=5).astype(float)  # np.float was removed in NumPy 1.24; builtin float is equivalent
# Convert the CI-of-difference (vs Mish) into absolute bounds by adding Mish's accuracy (df.iloc[0,1]).
df['ci_lo'] = df[1:].CI.apply(lambda s: s.split(' to ')[0]).astype(float) + df.iloc[0,1]
df['ci_hi'] = df[1:].CI.apply(lambda s: s.split(' to ')[1]).astype(float) + df.iloc[0,1]
# Mish's own CI comes from the paper; use .loc to avoid chained-assignment (SettingWithCopy) issues.
df.loc[0, 'ci_lo'] = 87.3085
df.loc[0, 'ci_hi'] = 87.6515
df = df.drop(index=[10]) # Outlier
df[['Name','Acc','ci_lo','ci_hi']]
Name | Acc | ci_lo | ci_hi | |
---|---|---|---|---|
0 | Mish | 87.48 | 87.3085 | 87.6515 |
1 | Swish-1 | 87.32 | 87.0825 | 87.5644 |
2 | E-Swish (?=1.75) | 87.49 | 87.2539 | 87.7339 |
3 | GELU | 87.37 | 87.1118 | 87.6299 |
4 | ReLU | 86.66 | 86.3621 | 86.9553 |
5 | ELU(?=1.0) | 86.41 | 86.1869 | 86.6244 |
6 | Leaky ReLU(?=0.3) | 86.85 | 86.5940 | 87.1026 |
7 | RReLU | 86.87 | 86.6177 | 87.1205 |
8 | SELU | 83.91 | 83.6087 | 84.2130 |
9 | SoftPlus(? = 1) | 83.00 | 82.7022 | 83.3065 |
10 | HardShrink(? = 0.5) | 75.03 | 74.5852 | 75.4765 |
11 | Hardtanh | 82.78 | 82.5278 | 83.0314 |
# Asymmetric error-bar widths: |ci - acc| on each side, shaped (2, n) as errorbar expects.
errs = df[['ci_lo','ci_hi']].to_numpy() - df['Acc'].to_numpy()[:,None]
errs = np.abs(errs).transpose()
fig = plt.figure(figsize=(7,3))
# NOTE(review): ms='3' passes markersize as a string — a plain number would be safer across matplotlib versions.
plt.errorbar(df.Acc, range(len(df)), xerr=errs, ls='', marker='o', ms='3', capsize=1, capthick=0.5);
plt.yticks(range(len(df)), df.Name);
plt.gca().invert_yaxis()