Visualizing JIT Modules

Copyright 2020 by Thomas Viehmann

This is the code for my blog post Visualize PyTorch models. The code has been made available thanks to my single GitHub sponsor at the time of writing. Thank you!

I license this code under the CC-BY-SA 4.0 license. Please link to my blog post or the original GitHub source (linked from the blog post) with the attribution notice.

Introduction

Did you ever wish for a concise picture of your PyTorch model's structure, only to find it surprisingly hard to get?

Recently, I did some work that involved looking at model structure in some detail. For my write-up, I wanted a diagram of some model structures. Even though the model in question is a relatively common one, searching for a diagram didn't turn up anything in the shape of what I was looking for.

So how can we get the model structure of PyTorch models? The first stop is probably the neat string representation that PyTorch provides for nn.Modules - without us doing anything, it also covers our custom models pretty well. It is, however, not without shortcomings.

Let's look at TorchVision's ResNet18 basic block as an example.

In [1]:
import torchvision
m = torchvision.models.resnet18()
m.layer1[0]
Out[1]:
BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

So we have two convolutions and two batch norms. But how are things connected? And is there really just one ReLU?

Looking at the forward method (you can get this using Python's inspect module or ?? in IPython), we see some important details not in the summary:

In [3]:
import inspect
print(inspect.getsource(m.layer1[0].forward))
    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

So we missed the entire residual connection. Also, the single ReLU module is applied twice. Arguably, it is wrong to re-use stateless modules like this: it will haunt you when you do things like quantization (the module becomes stateful then, due to the quantization parameters), and it mixes things up too much. If you want stateless, use the functional interface.
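
For illustration, a hypothetical block written with the functional interface (this is not TorchVision's code, just a sketch of the alternative) could look like this:

import torch
from torch import nn
import torch.nn.functional as F

# Sketch only: each activation is a functional call, so no stateless module is shared.
class MyBlock(nn.Module):
    def __init__(self, channels=64):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))    # first activation
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)                   # residual add, second activation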

So instead, let us build a visualization based on JITed modules - their graphs record what actually happens in forward.

We recurse into calls to make subgraphs, and we have to take some care that the edges connecting a subgraph to the outer graph belong to the outer graph; other than that it is quite straightforward, even if the details are messy.
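
Before the full function, here is roughly what the TorchScript graph API gives us to work with - a minimal sketch using ResNet18 as a stand-in; only .graph, inputs(), nodes() and kind() are needed:

import torch
import torchvision

# Minimal sketch of the graph API make_graph relies on.
model = torchvision.models.resnet18().eval()
traced = torch.jit.trace(model, torch.randn(1, 3, 224, 224))

gr = traced.graph                    # the TorchScript IR graph of forward
self_input = next(gr.inputs())       # the first graph input is the module itself
print(self_input.type())             # the module's class, e.g. ...resnet.ResNet
for node in list(gr.nodes())[:5]:
    print(node.kind())               # prim::GetAttr, prim::CallMethod, aten::... ops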

In [1]:
import graphviz

def make_graph(mod, classes_to_visit=None, classes_found=None, dot=None, prefix="",
               input_preds=None, 
               parent_dot=None):
    preds = {}
    
    # Recover the dotted attribute path of a value (e.g. "layer1.0.conv1") by
    # walking the chain of prim::GetAttr nodes back to the module's "self" input.
    def find_name(i, self_input, suffix=None):
        if i == self_input:
            return suffix
        cur = i.node().s("name")
        if suffix is not None:
            cur = cur + '.' + suffix
        of = next(i.node().inputs())
        return find_name(of, self_input, suffix=cur)

    gr = mod.graph
    toshow = []
    # list(traced_model.graph.nodes())[0]
    self_input = next(gr.inputs())
    self_type = self_input.type().str().split('.')[-1]
    preds[self_input] = (set(), set()) # inps, ops
    
    if dot is None:
        dot = graphviz.Digraph(format='svg', graph_attr={'label': self_type, 'labelloc': 't'})
        #dot.attr('node', shape='box')

    seen_inpnames = set()
    seen_edges = set()
    
    def add_edge(dot, n1, n2):
        if (n1, n2) not in seen_edges:
            seen_edges.add((n1, n2))
            dot.edge(n1, n2)

    # Connect the predecessor nodes pr to node `name`. If elementwise/shape ops
    # were accumulated along the way (`op`), materialize them as a single rounded
    # box between the predecessors and the target.
    def make_edges(pr, inpname, name, op, edge_dot=dot):
        if op:
            if inpname not in seen_inpnames:
                seen_inpnames.add(inpname)
                label_lines = [[]]
                line_len = 0
                for w in op:
                    if line_len >= 20:
                        label_lines.append([])
                        line_len = 0
                    label_lines[-1].append(w)
                    line_len += len(w) + 1
                edge_dot.node(inpname, label='\n'.join([' '.join(w) for w in label_lines]), shape='box', style='rounded')
                for p in pr:
                    add_edge(edge_dot, p, inpname)
            add_edge(edge_dot, inpname, name)
        else:
            for p in pr:
                add_edge(edge_dot, p, name)

    # Graph inputs (other than self) become ellipse nodes; when drawing a subgraph,
    # also connect them to the caller's predecessors in the parent graph.
    for nr, i in enumerate(list(gr.inputs())[1:]):
        name = prefix+'inp_'+i.debugName()
        preds[i] = {name}, set()
        dot.node(name, shape='ellipse')
        if input_preds is not None:
            pr, op = input_preds[nr]
            make_edges(pr, 'inp_'+name, name, op, edge_dot=parent_dot)
        
    # Only Tensor-valued inputs/outputs (also inside lists, optionals, or tuples)
    # are treated as data-flow edges.
    def is_relevant_type(t):
        kind = t.kind()
        if kind == 'TensorType':
            return True
        if kind in ('ListType', 'OptionalType'):
            return is_relevant_type(t.getElementType())
        if kind == 'TupleType':
            return any([is_relevant_type(tt) for tt in t.elements()])
        return False

    # Walk all nodes: submodule calls (prim::CallMethod), function calls
    # (prim::CallFunction), and plain aten/prim ops are handled separately.
    for n in gr.nodes():
        only_first_ops = {'aten::expand_as'}
        rel_inp_end = 1 if n.kind() in only_first_ops else None
            
        relevant_inputs = [i for i in list(n.inputs())[:rel_inp_end] if is_relevant_type(i.type())]
        relevant_outputs = [o for o in n.outputs() if is_relevant_type(o.type())]
        if n.kind() == 'prim::CallMethod':
            # a call into a submodule: either recurse into its graph as a cluster
            # (depending on classes_to_visit) or draw the submodule as a single box
            fq_submodule_name = '.'.join([nc for nc in list(n.inputs())[0].type().str().split('.') if not nc.startswith('__')])
            submodule_type = list(n.inputs())[0].type().str().split('.')[-1]
            submodule_name = find_name(list(n.inputs())[0], self_input)
            name = prefix+'.'+n.output().debugName()
            label = prefix+submodule_name+' (' + submodule_type + ')'
            if classes_found is not None:
                classes_found.add(fq_submodule_name)
            if ((classes_to_visit is None and
                 (not fq_submodule_name.startswith('torch.nn') or 
                  fq_submodule_name.startswith('torch.nn.modules.container')))
                or (classes_to_visit is not None and 
                    (submodule_type in classes_to_visit
                    or fq_submodule_name in classes_to_visit))):
                # go into subgraph
                sub_prefix = prefix+submodule_name+'.'
                with dot.subgraph(name="cluster_"+name) as sub_dot:
                    sub_dot.attr(label=label)
                    submod = mod
                    for k in submodule_name.split('.'):
                        submod = getattr(submod, k)
                    make_graph(submod, dot=sub_dot, prefix=sub_prefix,
                              input_preds = [preds[i] for i in list(n.inputs())[1:]],
                              parent_dot=dot, classes_to_visit=classes_to_visit,
                              classes_found=classes_found)
                for i, o in enumerate(n.outputs()):
                    preds[o] = {sub_prefix+f'out_{i}'}, set()
            else:
                dot.node(name, label=label, shape='box')
                for i in relevant_inputs:
                    pr, op = preds[i]
                    make_edges(pr, prefix+i.debugName(), name, op)
                for o in n.outputs():
                    preds[o] = {name}, set()
        elif n.kind() == 'prim::CallFunction':
            # a call to a free (scripted) function: drawn as a single box
            funcname = list(n.inputs())[0].type().__repr__().split('.')[-1]
            name = prefix+'.'+n.output().debugName()
            label = funcname
            dot.node(name, label=label, shape='box')
            for i in relevant_inputs:
                pr, op = preds[i]
                make_edges(pr, prefix+i.debugName(), name, op)
            for o in n.outputs():
                preds[o] = {name}, set()
        else:
            # Plain aten/prim ops. Ops listed here are not drawn at all: their
            # tensor predecessors simply flow through to their outputs.
            unseen_ops = {'prim::ListConstruct', 'prim::TupleConstruct', 'aten::index',
                          'aten::size', 'aten::slice', 'aten::unsqueeze', 'aten::squeeze',
                          'aten::to', 'aten::view', 'aten::permute', 'aten::transpose', 'aten::contiguous',
                          'aten::Int', 'prim::TupleUnpack', 'prim::ListUnpack', 'aten::unbind',
                          'aten::select', 'aten::detach', 'aten::stack', 'aten::reshape', 'aten::split_with_sizes',
                          'aten::cat', 'aten::expand', 'aten::expand_as', 'aten::_shape_as_tensor',
                          }
        
            absorbing_ops = ('aten::size', 'aten::_shape_as_tensor') # probably also partially absorbing ops. :/
            if False:
                print(n.kind())
                #DEBUG['kinds'].add(n.kind())
                #DEBUG[n.kind()] = n
                label = n.kind().split('::')[-1].rstrip('_')
                name = prefix+'.'+relevant_outputs[0].debugName()
                dot.node(name, label=label, shape='box', style='rounded')
                for i in relevant_inputs:
                    pr, op = preds[i]
                    make_edges(pr, prefix+i.debugName(), name, op)
                for o in n.outputs():
                    preds[o] = {name}, set()
            if True:
                label = n.kind().split('::')[-1].rstrip('_')
                pr, op = set(), set()
                for i in relevant_inputs:
                    apr, aop = preds[i]
                    pr |= apr
                    op |= aop
                if pr and n.kind() not in unseen_ops:
                    print(n.kind(), n)
                if n.kind() in absorbing_ops:
                    pr, op = set(), set()
                elif len(relevant_inputs) > 0 and len(relevant_outputs) > 0 and n.kind() not in unseen_ops:
                    op.add(label)
                for o in n.outputs():
                    preds[o] = pr, op

    # Graph outputs become ellipse nodes; connect their predecessors (possibly
    # through a rounded box of accumulated ops).
    for i, o in enumerate(gr.outputs()):
        name = prefix+f'out_{i}'
        dot.node(name, shape='ellipse')
        pr, op = preds[o]
        make_edges(pr, 'inp_'+name, name, op)
    return dot
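
In short: pass a JITed module; by default everything that is not a plain torch.nn module (plus containers like Sequential) is expanded into a cluster, and with classes_to_visit you instead list exactly which classes (by short or fully qualified name) to expand. A usage sketch, mirroring the ResNet call further down:

import torch
import torchvision

# Usage sketch: expand only the Sequential layers, draw everything else as boxes.
tm = torch.jit.trace(torchvision.models.resnet18().eval(),
                     torch.randn(1, 3, 224, 224))
d = make_graph(tm, classes_to_visit={'Sequential'})
d.render('resnet18_highlevel')   # writes resnet18_highlevel and resnet18_highlevel.svg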

Applications

Let's apply it! These are the pictures from my blog post along with the code that generated them.

The following code is from the transformers library (Copyright 2018- The Hugging Face team. Apache Licensed.).

In [2]:
import transformers

from transformers import BertModel, BertTokenizer, BertConfig
import numpy

import torch

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

model.eval()
for p in model.parameters():
    p.requires_grad_(False)

transformers.__version__
Out[2]:
'2.11.0'
In [3]:
# Creating the trace
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
traced_model.eval()
for p in traced_model.parameters():
    p.requires_grad_(False)
In [5]:
if 0:
    # resolving functions? (unused experiment; fn is not defined here)
    t = fn.type()
    def lookup(fn):
        n = str(fn.type()).split('.')[1:]
        res = globals()[n[0]]
        for nc in n[1:]:
            res = getattr(res, nc)
        return res
    lookup(fn).graph
In [6]:
d = make_graph(traced_model, classes_to_visit={'BertEncoder'})
d.render('bert_model')
d
aten::rsub %668 : Float(1:14, 1:14, 1:14, 14:1) = aten::rsub(%665, %666, %667) # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:395:0

aten::mul %attention_mask : Float(1:14, 1:14, 1:14, 14:1) = aten::mul(%668, %669) # /usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:228:0

Out[6]:
[Rendered graph: BertModel with the encoder (BertEncoder) expanded - embeddings (BertEmbeddings) feed the twelve BertLayer blocks in sequence, the attention mask passes through rsub/mul into every layer, and the last layer's output goes to the pooler (BertPooler) and the model output.]
In [7]:
mod = getattr(traced_model.encoder.layer, "0") # traced_model.encoder.layer[0]
d = make_graph(getattr(traced_model.encoder.layer, "0"), classes_to_visit={'BertAttention', 'BertSelfAttention'})
d.render('bert_layer')
d
aten::matmul %attention_scores.1 : Float(1:2352, 12:196, 14:14, 14:1) = aten::matmul(%query_layer.1, %75), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/transformers/modeling_bert.py:236:0

aten::div %attention_scores.2 : Float(1:2352, 12:196, 14:14, 14:1) = aten::div(%attention_scores.1, %77), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/transformers/modeling_bert.py:237:0

aten::add %input.6 : Float(1:2352, 12:196, 14:14, 14:1) = aten::add(%attention_scores.2, %attention_mask, %79), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/transformers/modeling_bert.py:240:0

aten::softmax %input.7 : Float(1:2352, 12:196, 14:14, 14:1) = aten::softmax(%input.6, %81, %82), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1498:0

aten::matmul %context_layer.1 : Float(1:10752, 12:896, 14:64, 64:1) = aten::matmul(%114, %value_layer.1), scope: __module.encoder/__module.encoder.layer.0/__module.encoder.layer.0.attention/__module.encoder.layer.0.attention.self # /usr/local/lib/python3.8/dist-packages/transformers/modeling_bert.py:253:0

Out[7]:
[Rendered graph: a single BertLayer with attention (BertAttention) and attention.self (BertSelfAttention) expanded - the query/key/value Linear layers feed matmul/div/add/softmax and dropout, then BertSelfOutput, intermediate (BertIntermediate) and output (BertOutput).]
In [8]:
import torchvision
In [9]:
m = torchvision.models.resnet18()
tm = torch.jit.trace(m, [torch.randn(1, 3, 224, 224)])
In [10]:
m = torchvision.models.resnet18()
m.layer1[0]
Out[10]:
BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
In [11]:
print(inspect.getsource(m.layer1[0].forward))
    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

In [12]:
d = make_graph(tm)
d.render("resnet18_full")
d
aten::add_ %input.10 : Float(1:200704, 64:3136, 56:56, 56:1) = aten::add_(%19, %1, %12), scope: __module.layer1/__module.layer1.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0

aten::add_ %input.16 : Float(1:200704, 64:3136, 56:56, 56:1) = aten::add_(%19, %1, %12), scope: __module.layer1/__module.layer1.1 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0

aten::add_ %input.23 : Float(1:100352, 128:784, 28:28, 28:1) = aten::add_(%21, %22, %14), scope: __module.layer2/__module.layer2.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0

aten::add_ %input.29 : Float(1:100352, 128:784, 28:28, 28:1) = aten::add_(%19, %1, %12), scope: __module.layer2/__module.layer2.1 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0

aten::add_ %input.36 : Float(1:50176, 256:196, 14:14, 14:1) = aten::add_(%21, %22, %14), scope: __module.layer3/__module.layer3.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0

aten::add_ %input.42 : Float(1:50176, 256:196, 14:14, 14:1) = aten::add_(%19, %1, %12), scope: __module.layer3/__module.layer3.1 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0

aten::add_ %input.49 : Float(1:25088, 512:49, 7:7, 7:1) = aten::add_(%21, %22, %14), scope: __module.layer4/__module.layer4.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0

aten::add_ %input.55 : Float(1:25088, 512:49, 7:7, 7:1) = aten::add_(%19, %1, %12), scope: __module.layer4/__module.layer4.1 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0

aten::flatten %input : Float(1:512, 512:1) = aten::flatten(%1536, %1182, %1183) # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:214:0

Out[12]:
[Rendered graph: the full traced ResNet - conv1, bn1, relu, maxpool, then layer1-layer4 (Sequential) expanded into their BasicBlocks, including the downsample branches of layer2-4, followed by avgpool, flatten and fc.]
In [13]:
d = make_graph(tm, classes_to_visit={'Sequential'})
d.render("resnet18_highlevel")
d
aten::flatten %input : Float(1:512, 512:1) = aten::flatten(%1536, %1182, %1183) # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:214:0

Out[13]:
[Rendered graph: the high-level ResNet - layer1-layer4 (Sequential) expanded only down to their BasicBlock boxes, with conv1/bn1/relu/maxpool before and avgpool/flatten/fc after.]
In [14]:
d = make_graph(getattr(tm.layer1, "0"))
d.render("resnet18_basicblock")
d
aten::add_ %input.10 : Float(1:200704, 64:3136, 56:56, 56:1) = aten::add_(%19, %1, %12), scope: __module.layer1/__module.layer1.0 # /usr/local/lib/python3.8/dist-packages/torchvision/models/resnet.py:69:0

Out[14]:
[Rendered graph: a single BasicBlock - conv1, bn1, relu, conv2, bn2, joined with the block input by an add, then the final relu.]
In [15]:
m = torchvision.models.segmentation.fcn_resnet50()
tm = torch.jit.trace(m, [torch.randn(1, 3, 224, 224)], strict=False)
d = make_graph(tm, classes_to_visit={'IntermediateLayerGetter', 'FCNHead'})
d.render("segmentation_fcn_high_level")
d
aten::upsample_bilinear2d %3096 : Float(1:1053696, 21:50176, 224:224, 224:1) = aten::upsample_bilinear2d(%3955, %3092, %3093, %3094, %3095) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3163:0

prim::DictConstruct %3098 : Dict(str, Tensor) = prim::DictConstruct(%3097, %3096)

Out[15]:
[Rendered graph: the FCN segmentation model - backbone (IntermediateLayerGetter) expanded into conv1/bn1/relu/maxpool and layer1-4, feeding the classifier (FCNHead) expanded into Conv2d/BatchNorm2d/ReLU/Dropout/Conv2d, followed by upsample_bilinear2d to the output.]
In [16]:
class Detection(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.m = torchvision.models.detection.fasterrcnn_resnet50_fpn().eval()
    def forward(self, inp):
        assert inp.shape[0] == 1
        res, = self.m(inp)
        return res['boxes'], res['labels'], res['scores']

tm = torch.jit.trace(Detection(), [torch.randn(1, 3, 224, 224)], check_trace=False)
<ipython-input-16-67b9b1140065>:6: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert inp.shape[0] == 1
/usr/local/lib/python3.8/dist-packages/torch/tensor.py:457: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  warnings.warn('Iterating over a tensor might cause the trace to be incorrect. '
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3000: UserWarning: The default behavior for interpolate/upsample with float scale_factor will change in 1.6.0 to align with other frameworks/libraries, and use scale_factor directly, instead of relying on the computed output size. If you wish to keep the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details. 
  warnings.warn("The default behavior for interpolate/upsample with float scale_factor will change "
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:3009: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return [(torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i],
/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:163: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:164: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
/usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:125: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
/usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:127: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
/usr/local/lib/python3.8/dist-packages/torchvision/ops/poolers.py:216: UserWarning: This overload of nonzero is deprecated:
	nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
	nonzero(Tensor input, *, bool as_tuple) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:761.)
  idx_in_level = torch.nonzero(levels == level).squeeze(1)
/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  torch.tensor(s, dtype=torch.float32, device=boxes.device) /
/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:270: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
In [17]:
d = make_graph(tm.m, classes_to_visit={})
d.render('fasterrcnn.highlevel')
d
aten::div %ratio_height : Float() = aten::div(%52, %60), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269:0

aten::div %ratio_width : Float() = aten::div(%69, %77), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269:0

aten::mul %xmin : Float(0:1) = aten::mul(%xmin.1, %ratio_width), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:276:0

aten::mul %xmax : Float(0:1) = aten::mul(%xmax.1, %ratio_width), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:277:0

aten::mul %ymin : Float(0:1) = aten::mul(%ymin.1, %ratio_height), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:278:0

aten::mul %ymax : Float(0:1) = aten::mul(%ymax.1, %ratio_height), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:279:0

Out[17]:
[Rendered graph: the high-level FasterRCNN - transform (GeneralizedRCNNTransform), backbone (BackboneWithFPN), rpn (RegionProposalNetwork) and roi_heads (RoIHeads) as boxes, with the final div/mul resizing ops on the way to the output.]
In [18]:
d = make_graph(tm.m, classes_to_visit={'RegionProposalNetwork', 'RoIHeads'})
d.render("fasterrcnn.detail")
d
aten::flatten %objectness.1 : Float(159882:1, 1:1) = aten::flatten(%434, %435, %436), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:248:0

aten::sub %widths.1 : Float(159882:1) = aten::sub(%472, %480, %481), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:192:0

aten::sub %heights.1 : Float(159882:1) = aten::sub(%490, %498, %499), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:193:0

aten::mul %510 : Float(159882:1) = aten::mul(%widths.1, %509), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:194:0

aten::add %ctr_x.1 : Float(159882:1) = aten::add(%508, %510, %511), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:194:0

aten::mul %522 : Float(159882:1) = aten::mul(%heights.1, %521), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:195:0

aten::add %ctr_y.1 : Float(159882:1) = aten::add(%520, %522, %523), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:195:0

aten::div %dx.1 : Float(159882:1, 1:1) = aten::div(%534, %535), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:198:0

aten::div %dy.1 : Float(159882:1, 1:1) = aten::div(%546, %547), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:199:0

aten::div %dw.1 : Float(159882:1, 1:1) = aten::div(%558, %559), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:200:0

aten::div %dh.1 : Float(159882:1, 1:1) = aten::div(%570, %571), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:201:0

aten::clamp %dw.2 : Float(159882:1, 1:1) = aten::clamp(%dw.1, %573, %574), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:204:0

aten::clamp %dh.2 : Float(159882:1, 1:1) = aten::clamp(%dh.1, %576, %577), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:205:0

aten::mul %586 : Float(159882:1, 1:1) = aten::mul(%dx.1, %585), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:207:0

aten::add %pred_ctr_x.1 : Float(159882:1, 1:1) = aten::add(%586, %593, %594), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:207:0

aten::mul %603 : Float(159882:1, 1:1) = aten::mul(%dy.1, %602), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:208:0

aten::add %pred_ctr_y.1 : Float(159882:1, 1:1) = aten::add(%603, %610, %611), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:208:0

aten::exp %613 : Float(159882:1, 1:1) = aten::exp(%dw.2), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:209:0

aten::mul %pred_w.1 : Float(159882:1, 1:1) = aten::mul(%613, %620), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:209:0

aten::exp %622 : Float(159882:1, 1:1) = aten::exp(%dh.2), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:210:0

aten::mul %pred_h.1 : Float(159882:1, 1:1) = aten::mul(%622, %629), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:210:0

aten::mul %639 : Float(159882:1, 1:1) = aten::mul(%638, %pred_w.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:212:0

aten::sub %pred_boxes1.1 : Float(159882:1, 1:1) = aten::sub(%pred_ctr_x.1, %639, %640), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:212:0

aten::mul %650 : Float(159882:1, 1:1) = aten::mul(%649, %pred_h.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:213:0

aten::sub %pred_boxes2.1 : Float(159882:1, 1:1) = aten::sub(%pred_ctr_y.1, %650, %651), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:213:0

aten::mul %661 : Float(159882:1, 1:1) = aten::mul(%660, %pred_w.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:214:0

aten::add %pred_boxes3.1 : Float(159882:1, 1:1) = aten::add(%pred_ctr_x.1, %661, %662), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:214:0

aten::mul %672 : Float(159882:1, 1:1) = aten::mul(%671, %pred_h.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:215:0

aten::add %pred_boxes4.1 : Float(159882:1, 1:1) = aten::add(%pred_ctr_y.1, %672, %673), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:215:0

aten::flatten %pred_boxes.1 : Float(159882:4, 4:1) = aten::flatten(%677, %678, %679), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:216:0

aten::topk %778 : Float(1:1000, 1000:1), %top_n_idx.1 : Long(1:1000, 1000:1) = aten::topk(%x.54, %774, %775, %776, %777), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0

aten::add %782 : Long(1:1000, 1000:1) = aten::add(%top_n_idx.1, %780, %781), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0

aten::topk %808 : Float(1:1000, 1000:1), %top_n_idx.2 : Long(1:1000, 1000:1) = aten::topk(%x.55, %804, %805, %806, %807), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0

aten::add %811 : Long(1:1000, 1000:1) = aten::add(%top_n_idx.2, %offset.1, %810), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0

aten::topk %836 : Float(1:1000, 1000:1), %top_n_idx.3 : Long(1:1000, 1000:1) = aten::topk(%x.56, %832, %833, %834, %835), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0

aten::add %839 : Long(1:1000, 1000:1) = aten::add(%top_n_idx.3, %offset.2, %838), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0

aten::topk %864 : Float(1:1000, 1000:1), %top_n_idx.4 : Long(1:1000, 1000:1) = aten::topk(%x.57, %860, %861, %862, %863), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0

aten::add %867 : Long(1:1000, 1000:1) = aten::add(%top_n_idx.4, %offset.3, %866), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0

aten::topk %892 : Float(1:507, 507:1), %top_n_idx.5 : Long(1:507, 507:1) = aten::topk(%x.58, %888, %889, %890, %891), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:373:0

aten::add %895 : Long(1:507, 507:1) = aten::add(%top_n_idx.5, %offset, %894), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/rpn.py:374:0

aten::max %boxes_x.2 : Float(4507:2, 2:1) = aten::max(%boxes_x.1, %1002), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:124:0

aten::min %boxes_x.3 : Float(4507:2, 2:1) = aten::min(%boxes_x.2, %1011), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:125:0

aten::max %boxes_y.2 : Float(4507:2, 2:1) = aten::max(%boxes_y.1, %1020), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:126:0

aten::min %boxes_y.3 : Float(4507:2, 2:1) = aten::min(%boxes_y.2, %1029), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:127:0

aten::sub %1061 : Float(4507:1) = aten::sub(%1051, %1059, %1060), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:100:0

aten::sub %1079 : Float(4507:1) = aten::sub(%1069, %1077, %1078), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:100:0

aten::ge %1081 : Bool(4507:1) = aten::ge(%1061, %1080), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0

aten::ge %1083 : Bool(4507:1) = aten::ge(%1079, %1082), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0

aten::__and__ %keep.1 : Bool(4507:1) = aten::__and__(%1081, %1083), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:101:0

aten::nonzero %1085 : Long(4507:1, 1:1) = aten::nonzero(%keep.1), scope: __module.m/__module.m.rpn # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:102:0

aten::sub %widths : Float(1000:1) = aten::sub(%61, %69, %70), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:192:0

aten::sub %heights : Float(1000:1) = aten::sub(%79, %87, %88), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:193:0

aten::mul %99 : Float(1000:1) = aten::mul(%widths, %98), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:194:0

aten::add %ctr_x : Float(1000:1) = aten::add(%97, %99, %100), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:194:0

aten::mul %111 : Float(1000:1) = aten::mul(%heights, %110), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:195:0

aten::add %ctr_y : Float(1000:1) = aten::add(%109, %111, %112), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:195:0

aten::div %dx : Float(1000:91, 91:1) = aten::div(%123, %124), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:198:0

aten::div %dy : Float(1000:91, 91:1) = aten::div(%135, %136), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:199:0

aten::div %dw.3 : Float(1000:91, 91:1) = aten::div(%147, %148), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:200:0

aten::div %dh.3 : Float(1000:91, 91:1) = aten::div(%159, %160), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:201:0

aten::clamp %dw : Float(1000:91, 91:1) = aten::clamp(%dw.3, %162, %163), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:204:0

aten::clamp %dh : Float(1000:91, 91:1) = aten::clamp(%dh.3, %165, %166), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:205:0

aten::mul %175 : Float(1000:91, 91:1) = aten::mul(%dx, %174), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:207:0

aten::add %pred_ctr_x : Float(1000:91, 91:1) = aten::add(%175, %182, %183), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:207:0

aten::mul %192 : Float(1000:91, 91:1) = aten::mul(%dy, %191), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:208:0

aten::add %pred_ctr_y : Float(1000:91, 91:1) = aten::add(%192, %199, %200), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:208:0

aten::exp %202 : Float(1000:91, 91:1) = aten::exp(%dw), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:209:0

aten::mul %pred_w : Float(1000:91, 91:1) = aten::mul(%202, %209), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:209:0

aten::exp %211 : Float(1000:91, 91:1) = aten::exp(%dh), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:210:0

aten::mul %pred_h : Float(1000:91, 91:1) = aten::mul(%211, %218), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:210:0

aten::mul %228 : Float(1000:91, 91:1) = aten::mul(%227, %pred_w), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:212:0

aten::sub %pred_boxes1 : Float(1000:91, 91:1) = aten::sub(%pred_ctr_x, %228, %229), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:212:0

aten::mul %239 : Float(1000:91, 91:1) = aten::mul(%238, %pred_h), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:213:0

aten::sub %pred_boxes2 : Float(1000:91, 91:1) = aten::sub(%pred_ctr_y, %239, %240), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:213:0

aten::mul %250 : Float(1000:91, 91:1) = aten::mul(%249, %pred_w), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:214:0

aten::add %pred_boxes3 : Float(1000:91, 91:1) = aten::add(%pred_ctr_x, %250, %251), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:214:0

aten::mul %261 : Float(1000:91, 91:1) = aten::mul(%260, %pred_h), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:215:0

aten::add %pred_boxes4 : Float(1000:91, 91:1) = aten::add(%pred_ctr_y, %261, %262), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:215:0

aten::flatten %pred_boxes : Float(1000:364, 364:1) = aten::flatten(%266, %267, %268), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/_utils.py:216:0

aten::softmax %276 : Float(1000:91, 91:1) = aten::softmax(%18, %274, %275), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1498:0

aten::max %boxes_x.5 : Float(1000:182, 91:2, 2:1) = aten::max(%boxes_x.4, %302), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:124:0

aten::min %boxes_x : Float(1000:182, 91:2, 2:1) = aten::min(%boxes_x.5, %311), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:125:0

aten::max %boxes_y.5 : Float(1000:182, 91:2, 2:1) = aten::max(%boxes_y.4, %320), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:126:0

aten::min %boxes_y : Float(1000:182, 91:2, 2:1) = aten::min(%boxes_y.5, %329), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:127:0

aten::gt %399 : Bool(90000:1) = aten::gt(%scores.5, %398), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0

aten::nonzero %400 : Long(0:1, 1:1) = aten::nonzero(%399), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/roi_heads.py:699:0

aten::sub %450 : Float(0:1) = aten::sub(%440, %448, %449), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:100:0

aten::sub %468 : Float(0:1) = aten::sub(%458, %466, %467), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:100:0

aten::ge %470 : Bool(0:1) = aten::ge(%450, %469), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0

aten::ge %472 : Bool(0:1) = aten::ge(%468, %471), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torch/tensor.py:23:0

aten::__and__ %keep.10 : Bool(0:1) = aten::__and__(%470, %472), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:101:0

aten::nonzero %474 : Long(0:1, 1:1) = aten::nonzero(%keep.10), scope: __module.m/__module.m.roi_heads # /usr/local/lib/python3.8/dist-packages/torchvision/ops/boxes.py:102:0

aten::div %ratio_height : Float() = aten::div(%52, %60), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269:0

aten::div %ratio_width : Float() = aten::div(%69, %77), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:269:0

aten::mul %xmin : Float(0:1) = aten::mul(%xmin.1, %ratio_width), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:276:0

aten::mul %xmax : Float(0:1) = aten::mul(%xmax.1, %ratio_width), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:277:0

aten::mul %ymin : Float(0:1) = aten::mul(%ymin.1, %ratio_height), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:278:0

aten::mul %ymax : Float(0:1) = aten::mul(%ymax.1, %ratio_height), scope: __module.m # /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/transform.py:279:0

Out[18]:
[Rendered graph: FasterRCNN with rpn (RegionProposalNetwork) and roi_heads (RoIHeads) expanded - the RPN's head (RPNHead) and anchor_generator (AnchorGenerator) feed box decoding and batched_nms; the proposals go to box_roi_pool (MultiScaleRoIAlign), box_head (TwoMLPHead) and box_predictor (FastRCNNPredictor), followed by decoding, softmax, filtering and batched_nms, and finally the transform's resizing ops on the way to the output.]