Copyright 2020 by Thomas Viehmann
This is the code for my blog post Visualize PyTorch models. The code has been made available thanks to my single github sponsor at the time of writing. Thank you!
I license this code with the CC-BY-SA 4.0 license. Please link to my blog post or the original github source (linked from the blog post) with the attribution notice.
Did you ever wish to get a concise picture of your PyTorch model's structure and found that too hard to get?
Recently, I did some work that involved looking at model structure in some detail. For my write-up, I wanted to get a diagram of some model structures. Even though it is a relatively common model, searching for a diagram didn't turn up something in the shape of what I was looking for.
So how can we get the model structure for PyTorch models? The first stop probably is the neat string representation that PyTorch provides for nn.Modules
- even without doing anything, it'll also cover our custom models pretty well. It is, however, not without shortcomings.
Let's look at TorchVision's ResNet18 basic block as an example.
# Instantiate a (randomly initialized) torchvision ResNet18 and show the
# string representation of its first basic block.
import torchvision
m = torchvision.models.resnet18()
m.layer1[0]  # notebook cell: displays the nn.Module repr of the BasicBlock
BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) )
So we have two convs and two batch norms. But how are things connected? Is there one ReLU?
Looking at the forward method (you can get this using Python's inspect module or ?? in IPython), we see some important details not in the summary:
# Print the Python source of the block's forward method; the module repr
# above does not show how the submodules are actually wired together.
import inspect
print(inspect.getsource(m.layer1[0].forward))
def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out
So we missed the entire residual bit. Also, there are two ReLUs. Arguably, it is wrong to re-use stateless modules like this. It'll haunt you when you do things like quantization (because it becomes stateful then due to the quantization parameters) and it's mixing things too much. If you want stateless, use the functional interface.
Still, we can build a visualization based on JITed modules.
We recurse into calls to make subgraphs and we have to take some care that the edges connecting the subgraph to the outer graph need to be part of the outer graph, but other than that, it is very straightforward, even though the details are messy.
import graphviz
def make_graph(mod, classes_to_visit=None, classes_found=None, dot=None, prefix="",
               input_preds=None,
               parent_dot=None):
    """Recursively render the TorchScript graph of `mod` as a graphviz Digraph.

    Submodule calls (prim::CallMethod) that we choose to descend into become
    graphviz cluster subgraphs; "uninteresting" ops (reshapes, list/tuple
    plumbing, ...) are collapsed onto edge labels instead of getting nodes.

    Parameters:
        mod: a traced/scripted module exposing `.graph` (torch._C.Graph).
        classes_to_visit: optional set of class names (short or fully
            qualified) to descend into.  If None, descend into everything
            except non-container `torch.nn` modules.
        classes_found: optional set; collects fully qualified class names of
            every submodule encountered (useful for discovering what to visit).
        dot: the Digraph to draw into; created on the outermost call.
        prefix: dotted-name prefix used to keep node names unique per subgraph.
        input_preds: predecessor info for the subgraph's inputs, supplied by
            the recursive caller.
        parent_dot: the enclosing Digraph; edges crossing the cluster boundary
            must be created on this graph, not on the cluster itself.

    Returns:
        The populated graphviz.Digraph (the same `dot` on recursive calls).
    """
    # preds maps each graph Value to a pair (set of producing node names,
    # set of op labels accumulated while collapsing "unseen" ops onto edges).
    preds = {}

    def find_name(i, self_input, suffix=None):
        # Walk the chain of prim::GetAttr nodes back to `self` to recover the
        # dotted attribute path (e.g. "encoder.layer.0") of a submodule Value.
        if i == self_input:
            return suffix
        cur = i.node().s("name")
        if suffix is not None:
            cur = cur + '.' + suffix
        of = next(i.node().inputs())
        return find_name(of, self_input, suffix=cur)

    gr = mod.graph
    toshow = []  # NOTE(review): unused — kept from an earlier iteration
    # list(traced_model.graph.nodes())[0]
    # The first graph input is the module itself ("self").
    self_input = next(gr.inputs())
    self_type = self_input.type().str().split('.')[-1]
    preds[self_input] = (set(), set())  # inps, ops
    if dot is None:
        # Outermost call: create the graph, titled with the module's class name.
        dot = graphviz.Digraph(format='svg', graph_attr={'label': self_type, 'labelloc': 't'})
        #dot.attr('node', shape='box')
    seen_inpnames = set()
    seen_edges = set()

    def add_edge(dot, n1, n2):
        # De-duplicate edges (collapsed ops can produce the same edge twice).
        if (n1, n2) not in seen_edges:
            seen_edges.add((n1, n2))
            dot.edge(n1, n2)

    def make_edges(pr, inpname, name, op, edge_dot=dot):
        # Connect predecessor nodes `pr` to `name`.  If `op` is non-empty,
        # insert a rounded box (named `inpname`) listing the collapsed op
        # labels, wrapped to roughly 20 characters per line.
        if op:
            if inpname not in seen_inpnames:
                seen_inpnames.add(inpname)
                label_lines = [[]]
                line_len = 0
                for w in op:
                    if line_len >= 20:
                        label_lines.append([])
                        line_len = 0
                    label_lines[-1].append(w)
                    line_len += len(w) + 1
                edge_dot.node(inpname, label='\n'.join([' '.join(w) for w in label_lines]), shape='box', style='rounded')
                for p in pr:
                    add_edge(edge_dot, p, inpname)
            add_edge(edge_dot, inpname, name)
        else:
            for p in pr:
                add_edge(edge_dot, p, name)

    # Graph inputs other than self become ellipse nodes; when we are a
    # subgraph, wire them to the caller's predecessors on the parent graph.
    for nr, i in enumerate(list(gr.inputs())[1:]):
        name = prefix+'inp_'+i.debugName()
        preds[i] = {name}, set()
        dot.node(name, shape='ellipse')
        if input_preds is not None:
            pr, op = input_preds[nr]
            # Edges crossing the cluster boundary must live on parent_dot.
            make_edges(pr, 'inp_'+name, name, op, edge_dot=parent_dot)

    def is_relevant_type(t):
        # Only tensor-carrying values are drawn (directly or inside
        # lists/optionals/tuples); ints, bools etc. are ignored.
        kind = t.kind()
        if kind == 'TensorType':
            return True
        if kind in ('ListType', 'OptionalType'):
            return is_relevant_type(t.getElementType())
        if kind == 'TupleType':
            return any([is_relevant_type(tt) for tt in t.elements()])
        return False

    for n in gr.nodes():
        # expand_as's second input is only a shape donor; don't draw that edge.
        only_first_ops = {'aten::expand_as'}
        rel_inp_end = 1 if n.kind() in only_first_ops else None
        relevant_inputs = [i for i in list(n.inputs())[:rel_inp_end] if is_relevant_type(i.type())]
        relevant_outputs = [o for o in n.outputs() if is_relevant_type(o.type())]
        if n.kind() == 'prim::CallMethod':
            # A submodule call.  Input 0 is the submodule object; its type
            # string carries the (mangled) fully qualified class name.
            fq_submodule_name = '.'.join([nc for nc in list(n.inputs())[0].type().str().split('.') if not nc.startswith('__')])
            submodule_type = list(n.inputs())[0].type().str().split('.')[-1]
            submodule_name = find_name(list(n.inputs())[0], self_input)
            name = prefix+'.'+n.output().debugName()
            label = prefix+submodule_name+' (' + submodule_type + ')'
            if classes_found is not None:
                classes_found.add(fq_submodule_name)
            # Decide whether to recurse: either no whitelist was given and the
            # class is not a plain torch.nn module (containers always recurse),
            # or the class matches the whitelist by short or qualified name.
            if ((classes_to_visit is None and
                 (not fq_submodule_name.startswith('torch.nn') or
                  fq_submodule_name.startswith('torch.nn.modules.container')))
                or (classes_to_visit is not None and
                    (submodule_type in classes_to_visit
                     or fq_submodule_name in classes_to_visit))):
                # go into subgraph
                sub_prefix = prefix+submodule_name+'.'
                with dot.subgraph(name="cluster_"+name) as sub_dot:
                    sub_dot.attr(label=label)
                    # Resolve the actual submodule object by attribute path.
                    submod = mod
                    for k in submodule_name.split('.'):
                        submod = getattr(submod, k)
                    make_graph(submod, dot=sub_dot, prefix=sub_prefix,
                               input_preds=[preds[i] for i in list(n.inputs())[1:]],
                               parent_dot=dot, classes_to_visit=classes_to_visit,
                               classes_found=classes_found)
                # The subgraph's outputs become the call's output predecessors.
                for i, o in enumerate(n.outputs()):
                    preds[o] = {sub_prefix+f'out_{i}'}, set()
            else:
                # Draw the whole submodule call as a single box.
                dot.node(name, label=label, shape='box')
                for i in relevant_inputs:
                    pr, op = preds[i]
                    make_edges(pr, prefix+i.debugName(), name, op)
                for o in n.outputs():
                    preds[o] = {name}, set()
        elif n.kind() == 'prim::CallFunction':
            # A call to a scripted free function; input 0 is the function value.
            funcname = list(n.inputs())[0].type().__repr__().split('.')[-1]
            name = prefix+'.'+n.output().debugName()
            label = funcname
            dot.node(name, label=label, shape='box')
            for i in relevant_inputs:
                pr, op = preds[i]
                make_edges(pr, prefix+i.debugName(), name, op)
            for o in n.outputs():
                preds[o] = {name}, set()
        else:
            # Plumbing ops that we never draw as nodes (they are collapsed
            # onto edge labels, or silently passed through).
            unseen_ops = {'prim::ListConstruct', 'prim::TupleConstruct', 'aten::index',
                          'aten::size', 'aten::slice', 'aten::unsqueeze', 'aten::squeeze',
                          'aten::to', 'aten::view', 'aten::permute', 'aten::transpose', 'aten::contiguous',
                          'aten::permute', 'aten::Int', 'prim::TupleUnpack', 'prim::ListUnpack', 'aten::unbind',
                          'aten::select', 'aten::detach', 'aten::stack', 'aten::reshape', 'aten::split_with_sizes',
                          'aten::cat', 'aten::expand', 'aten::expand_as', 'aten::_shape_as_tensor',
                          }
            # Ops whose outputs should not inherit any predecessors (shape
            # queries don't represent real data flow).
            absorbing_ops = ('aten::size', 'aten::_shape_as_tensor')  # probably also partially absorbing ops. :/
            if False:
                # Disabled alternative: draw every remaining op as its own node.
                print(n.kind())
                #DEBUG['kinds'].add(n.kind())
                #DEBUG[n.kind()] = n
                label = n.kind().split('::')[-1].rstrip('_')
                name = prefix+'.'+relevant_outputs[0].debugName()
                dot.node(name, label=label, shape='box', style='rounded')
                for i in relevant_inputs:
                    pr, op = preds[i]
                    make_edges(pr, prefix+i.debugName(), name, op)
                for o in n.outputs():
                    preds[o] = {name}, set()
            if True:
                # Active behavior: merge the predecessors of all inputs and
                # carry them (plus this op's label, if noteworthy) to outputs.
                label = n.kind().split('::')[-1].rstrip('_')
                pr, op = set(), set()
                for i in relevant_inputs:
                    apr, aop = preds[i]
                    pr |= apr
                    op |= aop
                if pr and n.kind() not in unseen_ops:
                    # Deliberate debug print: surfaces op kinds not yet classified.
                    print(n.kind(), n)
                if n.kind() in absorbing_ops:
                    pr, op = set(), set()
                elif len(relevant_inputs) > 0 and len(relevant_outputs) > 0 and n.kind() not in unseen_ops:
                    op.add(label)
                for o in n.outputs():
                    preds[o] = pr, op

    # Graph outputs become ellipse nodes named so recursive callers can
    # reference them as sub_prefix + 'out_<i>'.
    for i, o in enumerate(gr.outputs()):
        name = prefix+f'out_{i}'
        dot.node(name, shape='ellipse')
        pr, op = preds[o]
        make_edges(pr, 'inp_'+name, name, op)
    return dot
Let's apply it! These are the pictures from my blog post along with the code that generated them.
The following code is from the transformers library (Copyright 2018- The Hugging Face team. Apache Licensed.).
import transformers
from transformers import BertModel, BertTokenizer, BertConfig
import numpy
import torch

# Load the pretrained BERT tokenizer (downloads on first use).
enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
# Segment ids: 0 for the first sentence, 1 for the second.
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()
# Freeze parameters so the traced graph does not record gradient bookkeeping.
for p in model.parameters():
    p.requires_grad_(False)
transformers.__version__  # notebook cell: shows the library version used for this post
'2.11.0'
# Creating the trace
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
traced_model.eval()
for p in traced_model.parameters():
    p.requires_grad_(False)

if 0:
    # resolving functions?  (disabled scratch code — NOTE(review): `fn` is
    # not defined anywhere in this file, so this block cannot run as-is)
    t = fn.type()

    def lookup(fn):
        # Map a TorchScript function type's dotted name back to the live
        # Python object by walking attributes from globals().
        n = str(fn.type()).split('.')[1:]
        res = globals()[n[0]]
        for nc in n[1:]:
            res = getattr(res, nc)
        return res
    lookup(fn).graph

# Visualize the traced BERT model, descending only into BertEncoder.
d = make_graph(traced_model, classes_to_visit={'BertEncoder'})
d.render('bert_model')
d  # display the SVG in the notebook
aten::rsub %668 : Float(1:14, 1:14, 1:14, 14:1) = aten::rsub(%665, %666, %667) # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:395:0 aten::mul %attention_mask : Float(1:14, 1:14, 1:14, 14:1) = aten::mul(%668, %669) # /usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:228:0
# TorchScript module lists need getattr with a string index.
mod = getattr(traced_model.encoder.layer, "0")  # traced_model.encoder.layer[0]
# A single encoder layer, descending into the attention submodules.
d = make_graph(getattr(traced_model.encoder.layer, "0"), classes_to_visit={'BertAttention', 'BertSelfAttention'})
d.render('bert_layer')
d  # display the SVG in the notebook
aten::matmul %attention_scores.1 : Float(1:2352, 12:196, 14:14, 14:1) = aten::matmul(%query_layer.1, %75), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/transformers/modeling_bert.py:236:0 aten::div %attention_scores.2 : Float(1:2352, 12:196, 14:14, 14:1) = aten::div(%attention_scores.1, %77), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/transformers/modeling_bert.py:237:0 aten::add %input.6 : Float(1:2352, 12:196, 14:14, 14:1) = aten::add(%attention_scores.2, %attention_mask, %79), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/transformers/modeling_bert.py:240:0 aten::softmax %input.7 : Float(1:2352, 12:196, 14:14, 14:1) = aten::softmax(%input.6, %81, %82), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1498:0 aten::matmul %context_layer.1 : Float(1:10752, 12:896, 14:64, 64:1) = aten::matmul(%114, %value_layer.1), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/transformers/modeling_bert.py:253:0
import torchvision
m = torchvision.models.resnet18()
# Trace with a dummy ImageNet-sized input (1x3x224x224).
tm = torch.jit.trace(m, [torch.randn(1, 3, 224, 224)])
m = torchvision.models.resnet18()
m.layer1[0]  # notebook cell: shows the BasicBlock repr again for comparison
BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) )
# Show the BasicBlock forward source next to the repr above.
print(inspect.getsource(m.layer1[0].forward))
def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out
# Full ResNet18 graph with default descent rules.
d = make_graph(tm)
d.render("resnet18_full")
d  # display the SVG in the notebook
aten::add_ %input.10 : Float(1:200704, 64:3136, 56:56, 56:1) = aten::add_(%19, %1, %12), scope: __module.layer1/__module.layer1.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0 aten::add_ %input.16 : Float(1:200704, 64:3136, 56:56, 56:1) = aten::add_(%19, %1, %12), scope: __module.layer1/__module.layer1.1 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0 aten::add_ %input.23 : Float(1:100352, 128:784, 28:28, 28:1) = aten::add_(%21, %22, %14), scope: __module.layer2/__module.layer2.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0 aten::add_ %input.29 : Float(1:100352, 128:784, 28:28, 28:1) = aten::add_(%19, %1, %12), scope: __module.layer2/__module.layer2.1 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0 aten::add_ %input.36 : Float(1:50176, 256:196, 14:14, 14:1) = aten::add_(%21, %22, %14), scope: __module.layer3/__module.layer3.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0 aten::add_ %input.42 : Float(1:50176, 256:196, 14:14, 14:1) = aten::add_(%19, %1, %12), scope: __module.layer3/__module.layer3.1 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0 aten::add_ %input.49 : Float(1:25088, 512:49, 7:7, 7:1) = aten::add_(%21, %22, %14), scope: __module.layer4/__module.layer4.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0 aten::add_ %input.55 : Float(1:25088, 512:49, 7:7, 7:1) = aten::add_(%19, %1, %12), scope: __module.layer4/__module.layer4.1 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0 aten::flatten %input : Float(1:512, 512:1) = aten::flatten(%1536, %1182, %1183) # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:214:0
# High-level view: descend only into Sequential containers.
d = make_graph(tm, classes_to_visit={'Sequential'})
d.render("resnet18_highlevel")
d  # display the SVG in the notebook
aten::flatten %input : Float(1:512, 512:1) = aten::flatten(%1536, %1182, %1183) # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:214:0
# Just one BasicBlock (string index because it's a TorchScript module list).
d = make_graph(getattr(tm.layer1, "0"))
d.render("resnet18_basicblock")
d  # display the SVG in the notebook
aten::add_ %input.10 : Float(1:200704, 64:3136, 56:56, 56:1) = aten::add_(%19, %1, %12), scope: __module.layer1/__module.layer1.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0
m = torchvision.models.segmentation.fcn_resnet50()
# strict=False because this model returns a dict, which strict tracing rejects.
tm = torch.jit.trace(m, [torch.randn(1, 3, 224, 224)], strict=False)
d = make_graph(tm, classes_to_visit={'IntermediateLayerGetter', 'FCNHead'})
d.render("segmentation_fcn_high_level")
d  # display the SVG in the notebook
aten::upsample_bilinear2d %3096 : Float(1:1053696, 21:50176, 224:224, 224:1) = aten::upsample_bilinear2d(%3955, %3092, %3093, %3094, %3095) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3163:0 prim::DictConstruct %3098 : Dict(str, Tensor) = prim::DictConstruct(%3097, %3096)
class Detection(torch.nn.Module):
    """Wrapper around Faster R-CNN whose forward returns a flat tuple of
    tensors instead of a list of dicts, so that the tracer can record it."""

    def __init__(self):
        super().__init__()
        # Pretrained-architecture detector in eval mode (inference behavior).
        self.m = torchvision.models.detection.fasterrcnn_resnet50_fpn().eval()

    def forward(self, inp):
        # Restrict to batch size one so the single-result unpack below holds.
        assert inp.shape[0] == 1
        [res] = self.m(inp)
        return res['boxes'], res['labels'], res['scores']
# Trace the wrapper on a dummy 1x3x224x224 input; check_trace=False skips
# the post-trace verification re-run (detection outputs are input-dependent,
# so a re-run comparison would likely fail — NOTE(review): assumption).
tm = torch.jit.trace(Detection(), [torch.randn(1, 3, 224, 224)], check_trace=False)
<ipython-input-16-67b9b1140065>:6: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! assert inp.shape[0] == 1 /usr/local/lib/python3.8/dist-packages/torch/tensor.py:457: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results). warnings.warn('Iterating over a tensor might cause the trace to be incorrect. ' /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3000: UserWarning: The default behavior for interpolate/upsample with float scale_factor will change in 1.6.0 to align with other frameworks/libraries, and use scale_factor directly, instead of relying on the computed output size. If you wish to keep the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details. warnings.warn("The default behavior for interpolate/upsample with float scale_factor will change " /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3009: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). return [(torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:163: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). 
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device), /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:164: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes] /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:125: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device)) /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:127: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device)) /usr/local/lib/python3.8/dist-packages/torchvision/ops/poolers.py:216: UserWarning: This overload of nonzero is deprecated: nonzero(Tensor input, *, Tensor out) Consider using one of the following signatures instead: nonzero(Tensor input, *, bool as_tuple) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:761.) idx_in_level = torch.nonzero(levels == level).squeeze(1) /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). 
torch.tensor(s, dtype=torch.float32, device=boxes.device) / /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:270: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
# High-level view of the detector: visit no submodule classes at all, so
# each top-level child (transform, backbone, rpn, roi_heads) stays a
# single node. The original passed `{}`, which is an empty *dict* literal,
# not a set; membership tests behave identically, but `set()` says what is
# meant and matches the set literals used by the other make_graph calls.
d = make_graph(tm.m, classes_to_visit=set())
d.render('fasterrcnn.highlevel')
d
aten::div %ratio_height : Float() = aten::div(%52, %60), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269:0 aten::div %ratio_width : Float() = aten::div(%69, %77), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269:0 aten::mul %xmin : Float(0:1) = aten::mul(%xmin.1, %ratio_width), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:276:0 aten::mul %xmax : Float(0:1) = aten::mul(%xmax.1, %ratio_width), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:277:0 aten::mul %ymin : Float(0:1) = aten::mul(%ymin.1, %ratio_height), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:278:0 aten::mul %ymax : Float(0:1) = aten::mul(%ymax.1, %ratio_height), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:279:0
# Detailed view: additionally descend into the region-proposal network
# and the RoI heads of the detector.
detail_classes = {'RegionProposalNetwork', 'RoIHeads'}
d = make_graph(tm.m, classes_to_visit=detail_classes)
d.render("fasterrcnn.detail")
d
aten::flatten %objectness.1 : Float(159882:1, 1:1) = aten::flatten(%434, %435, %436), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:248:0 aten::sub %widths.1 : Float(159882:1) = aten::sub(%472, %480, %481), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:192:0 aten::sub %heights.1 : Float(159882:1) = aten::sub(%490, %498, %499), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:193:0 aten::mul %510 : Float(159882:1) = aten::mul(%widths.1, %509), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:194:0 aten::add %ctr_x.1 : Float(159882:1) = aten::add(%508, %510, %511), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:194:0 aten::mul %522 : Float(159882:1) = aten::mul(%heights.1, %521), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:195:0 aten::add %ctr_y.1 : Float(159882:1) = aten::add(%520, %522, %523), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:195:0 aten::div %dx.1 : Float(159882:1, 1:1) = aten::div(%534, %535), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:198:0 aten::div %dy.1 : Float(159882:1, 1:1) = aten::div(%546, %547), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:199:0 aten::div %dw.1 : Float(159882:1, 1:1) = aten::div(%558, %559), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:200:0 aten::div %dh.1 : Float(159882:1, 1:1) = aten::div(%570, %571), scope: __module.m/__module.m.rpn # 
/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:201:0 aten::clamp %dw.2 : Float(159882:1, 1:1) = aten::clamp(%dw.1, %573, %574), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:204:0 aten::clamp %dh.2 : Float(159882:1, 1:1) = aten::clamp(%dh.1, %576, %577), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:205:0 aten::mul %586 : Float(159882:1, 1:1) = aten::mul(%dx.1, %585), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:207:0 aten::add %pred_ctr_x.1 : Float(159882:1, 1:1) = aten::add(%586, %593, %594), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:207:0 aten::mul %603 : Float(159882:1, 1:1) = aten::mul(%dy.1, %602), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:208:0 aten::add %pred_ctr_y.1 : Float(159882:1, 1:1) = aten::add(%603, %610, %611), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:208:0 aten::exp %613 : Float(159882:1, 1:1) = aten::exp(%dw.2), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:209:0 aten::mul %pred_w.1 : Float(159882:1, 1:1) = aten::mul(%613, %620), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:209:0 aten::exp %622 : Float(159882:1, 1:1) = aten::exp(%dh.2), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:210:0 aten::mul %pred_h.1 : Float(159882:1, 1:1) = aten::mul(%622, %629), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:210:0 aten::mul %639 : Float(159882:1, 1:1) = 
aten::mul(%638, %pred_w.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:212:0 aten::sub %pred_boxes1.1 : Float(159882:1, 1:1) = aten::sub(%pred_ctr_x.1, %639, %640), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:212:0 aten::mul %650 : Float(159882:1, 1:1) = aten::mul(%649, %pred_h.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:213:0 aten::sub %pred_boxes2.1 : Float(159882:1, 1:1) = aten::sub(%pred_ctr_y.1, %650, %651), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:213:0 aten::mul %661 : Float(159882:1, 1:1) = aten::mul(%660, %pred_w.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:214:0 aten::add %pred_boxes3.1 : Float(159882:1, 1:1) = aten::add(%pred_ctr_x.1, %661, %662), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:214:0 aten::mul %672 : Float(159882:1, 1:1) = aten::mul(%671, %pred_h.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:215:0 aten::add %pred_boxes4.1 : Float(159882:1, 1:1) = aten::add(%pred_ctr_y.1, %672, %673), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:215:0 aten::flatten %pred_boxes.1 : Float(159882:4, 4:1) = aten::flatten(%677, %678, %679), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:216:0 aten::topk %778 : Float(1:1000, 1000:1), %top_n_idx.1 : Long(1:1000, 1000:1) = aten::topk(%x.54, %774, %775, %776, %777), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0 aten::add %782 : 
Long(1:1000, 1000:1) = aten::add(%top_n_idx.1, %780, %781), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0 aten::topk %808 : Float(1:1000, 1000:1), %top_n_idx.2 : Long(1:1000, 1000:1) = aten::topk(%x.55, %804, %805, %806, %807), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0 aten::add %811 : Long(1:1000, 1000:1) = aten::add(%top_n_idx.2, %offset.1, %810), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0 aten::topk %836 : Float(1:1000, 1000:1), %top_n_idx.3 : Long(1:1000, 1000:1) = aten::topk(%x.56, %832, %833, %834, %835), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0 aten::add %839 : Long(1:1000, 1000:1) = aten::add(%top_n_idx.3, %offset.2, %838), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0 aten::topk %864 : Float(1:1000, 1000:1), %top_n_idx.4 : Long(1:1000, 1000:1) = aten::topk(%x.57, %860, %861, %862, %863), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0 aten::add %867 : Long(1:1000, 1000:1) = aten::add(%top_n_idx.4, %offset.3, %866), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0 aten::topk %892 : Float(1:507, 507:1), %top_n_idx.5 : Long(1:507, 507:1) = aten::topk(%x.58, %888, %889, %890, %891), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0 aten::add %895 : Long(1:507, 507:1) = aten::add(%top_n_idx.5, %offset, %894), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0 aten::max %boxes_x.2 : Float(4507:2, 2:1) = aten::max(%boxes_x.1, %1002), scope: 
__module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:124:0 aten::min %boxes_x.3 : Float(4507:2, 2:1) = aten::min(%boxes_x.2, %1011), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:125:0 aten::max %boxes_y.2 : Float(4507:2, 2:1) = aten::max(%boxes_y.1, %1020), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:126:0 aten::min %boxes_y.3 : Float(4507:2, 2:1) = aten::min(%boxes_y.2, %1029), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:127:0 aten::sub %1061 : Float(4507:1) = aten::sub(%1051, %1059, %1060), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:100:0 aten::sub %1079 : Float(4507:1) = aten::sub(%1069, %1077, %1078), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:100:0 aten::ge %1081 : Bool(4507:1) = aten::ge(%1061, %1080), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0 aten::ge %1083 : Bool(4507:1) = aten::ge(%1079, %1082), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0 aten::__and__ %keep.1 : Bool(4507:1) = aten::__and__(%1081, %1083), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:101:0 aten::nonzero %1085 : Long(4507:1, 1:1) = aten::nonzero(%keep.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:102:0 aten::sub %widths : Float(1000:1) = aten::sub(%61, %69, %70), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:192:0 aten::sub %heights : Float(1000:1) = aten::sub(%79, %87, %88), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:193:0 
aten::mul %99 : Float(1000:1) = aten::mul(%widths, %98), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:194:0 aten::add %ctr_x : Float(1000:1) = aten::add(%97, %99, %100), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:194:0 aten::mul %111 : Float(1000:1) = aten::mul(%heights, %110), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:195:0 aten::add %ctr_y : Float(1000:1) = aten::add(%109, %111, %112), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:195:0 aten::div %dx : Float(1000:91, 91:1) = aten::div(%123, %124), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:198:0 aten::div %dy : Float(1000:91, 91:1) = aten::div(%135, %136), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:199:0 aten::div %dw.3 : Float(1000:91, 91:1) = aten::div(%147, %148), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:200:0 aten::div %dh.3 : Float(1000:91, 91:1) = aten::div(%159, %160), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:201:0 aten::clamp %dw : Float(1000:91, 91:1) = aten::clamp(%dw.3, %162, %163), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:204:0 aten::clamp %dh : Float(1000:91, 91:1) = aten::clamp(%dh.3, %165, %166), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:205:0 aten::mul %175 : Float(1000:91, 91:1) = aten::mul(%dx, %174), scope: __module.m/__module.m.roi_heads # 
/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:207:0 aten::add %pred_ctr_x : Float(1000:91, 91:1) = aten::add(%175, %182, %183), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:207:0 aten::mul %192 : Float(1000:91, 91:1) = aten::mul(%dy, %191), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:208:0 aten::add %pred_ctr_y : Float(1000:91, 91:1) = aten::add(%192, %199, %200), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:208:0 aten::exp %202 : Float(1000:91, 91:1) = aten::exp(%dw), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:209:0 aten::mul %pred_w : Float(1000:91, 91:1) = aten::mul(%202, %209), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:209:0 aten::exp %211 : Float(1000:91, 91:1) = aten::exp(%dh), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:210:0 aten::mul %pred_h : Float(1000:91, 91:1) = aten::mul(%211, %218), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:210:0 aten::mul %228 : Float(1000:91, 91:1) = aten::mul(%227, %pred_w), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:212:0 aten::sub %pred_boxes1 : Float(1000:91, 91:1) = aten::sub(%pred_ctr_x, %228, %229), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:212:0 aten::mul %239 : Float(1000:91, 91:1) = aten::mul(%238, %pred_h), scope: __module.m/__module.m.roi_heads # 
/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:213:0 aten::sub %pred_boxes2 : Float(1000:91, 91:1) = aten::sub(%pred_ctr_y, %239, %240), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:213:0 aten::mul %250 : Float(1000:91, 91:1) = aten::mul(%249, %pred_w), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:214:0 aten::add %pred_boxes3 : Float(1000:91, 91:1) = aten::add(%pred_ctr_x, %250, %251), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:214:0 aten::mul %261 : Float(1000:91, 91:1) = aten::mul(%260, %pred_h), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:215:0 aten::add %pred_boxes4 : Float(1000:91, 91:1) = aten::add(%pred_ctr_y, %261, %262), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:215:0 aten::flatten %pred_boxes : Float(1000:364, 364:1) = aten::flatten(%266, %267, %268), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:216:0 aten::softmax %276 : Float(1000:91, 91:1) = aten::softmax(%18, %274, %275), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1498:0 aten::max %boxes_x.5 : Float(1000:182, 91:2, 2:1) = aten::max(%boxes_x.4, %302), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:124:0 aten::min %boxes_x : Float(1000:182, 91:2, 2:1) = aten::min(%boxes_x.5, %311), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:125:0 aten::max %boxes_y.5 : Float(1000:182, 91:2, 2:1) = aten::max(%boxes_y.4, %320), scope: __module.m/__module.m.roi_heads # 
/usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:126:0 aten::min %boxes_y : Float(1000:182, 91:2, 2:1) = aten::min(%boxes_y.5, %329), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:127:0 aten::gt %399 : Bool(90000:1) = aten::gt(%scores.5, %398), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0 aten::nonzero %400 : Long(0:1, 1:1) = aten::nonzero(%399), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/roi_heads.py:699:0 aten::sub %450 : Float(0:1) = aten::sub(%440, %448, %449), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:100:0 aten::sub %468 : Float(0:1) = aten::sub(%458, %466, %467), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:100:0 aten::ge %470 : Bool(0:1) = aten::ge(%450, %469), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0 aten::ge %472 : Bool(0:1) = aten::ge(%468, %471), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0 aten::__and__ %keep.10 : Bool(0:1) = aten::__and__(%470, %472), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:101:0 aten::nonzero %474 : Long(0:1, 1:1) = aten::nonzero(%keep.10), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:102:0 aten::div %ratio_height : Float() = aten::div(%52, %60), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269:0 aten::div %ratio_width : Float() = aten::div(%69, %77), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269:0 aten::mul %xmin : Float(0:1) = aten::mul(%xmin.1, %ratio_width), scope: 
__module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:276:0 aten::mul %xmax : Float(0:1) = aten::mul(%xmax.1, %ratio_width), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:277:0 aten::mul %ymin : Float(0:1) = aten::mul(%ymin.1, %ratio_height), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:278:0 aten::mul %ymax : Float(0:1) = aten::mul(%ymax.1, %ratio_height), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:279:0