Torchvista: Building an Interactive Pytorch Visualization Package for Notebooks

In this post, I talk through the motivation, complexities and implementation details of building torchvista, an open-source package to interactively visualize the forward pass of any Pytorch model from within web-based notebooks.

To get a sense of the workings of torchvista while reading this post, you can check out:

  • The GitHub page, if you want to install it via pip and use it from web-based notebooks (Jupyter, Colab, Kaggle, VSCode, etc.)
  • An interactive demo page with various well-known models visualized
  • A Google Colab tutorial

Motivation

Pytorch models can get very large and complex, and making sense of one from the code alone can be a tiresome and even intractable exercise. Having a graph-like visualization of it is just what we need to make this easier.

While there exist tools like Netron, pytorchviz, and torchview that make this easier, my motivation for building torchvista was that I found them lacking in some or all of these requirements:

  • Interaction support: The visualized graph should be interactive and not a static image. It should be a structure you can zoom, drag, expand/collapse, etc. Models can get very large, and if all you see is a gigantic static image of the graph, how can you really explore it?
Drag and zoom to explore a large model
  • Modular exploration: Large Pytorch models are modular in thought and implementation. For example, think of a module that has a Sequential module containing a few Attention blocks, each of which in turn has fully connected blocks that contain Linear layers with activation functions, and so on. The tool should let you tap into this modular structure, not just present a low-level graph of tensor links.
Expanding modules in a modular fashion
  • Notebook support: We tend to prototype and build our models in notebooks. A standalone application that requires you to build your model and then load it into the tool to visualize it makes for too long a feedback loop. So the tool should ideally work from within notebooks.
Visualization within a Jupyter notebook
  • Error debugging support: While building models from scratch, we often run into many errors before the model can run a full forward pass end-to-end. So the visualization tool should be error tolerant and show a partial graph even when there are errors, so that you can debug them.
A sample visualization of when torch.cat failed due to mismatched tensor shapes
  • Forward pass tracing: Pytorch natively exposes a backward pass graph through its autograd system, which the pytorchviz package renders as a graph. However, this is different from the forward pass. When we build, study and imagine models, we think more in terms of the forward pass, so that is the graph most worth visualizing.

Building torchvista

Basic API

The goal was to have a simple API that works with almost any Pytorch model.

import torch
from transformers import XLNetModel
from torchvista import trace_model

model = XLNetModel.from_pretrained("xlnet-base-cased")
example_input = torch.randint(0, 32000, (1, 10))

# Trace it!
trace_model(model, example_input)

A single call to trace_model(model, input) should produce an interactive visualization of the forward pass.

Steps involved

Behind the scenes, torchvista, when called, works in two phases:

  1. Tracing: This is where torchvista extracts a graph data structure from the forward pass of the model. Pytorch does not inherently expose this graph structure (even though it does expose a graph for the backward pass), so torchvista has to build this data structure by itself.
  2. Visualization: Once the graph is extracted, torchvista has to produce the actual visualization as an interactive graph. It does this by loading a template HTML file (with JS embedded in it) and injecting the serialized graph data structures as strings into the template, which the browser engine then loads.
Behind the scenes of trace_model()
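The injection step of phase 2 can be sketched with the standard library alone. This is a minimal sketch and not torchvista's code. The template contents, the `$graph_json` placeholder and the `build_html` helper are hypothetical. The sketch only shows the idea of serializing the graph to JSON and substituting it into an HTML/JS template as a string.

```python
import json
from string import Template

# Hypothetical template; the real one embeds the JS that lays out and draws the graph.
HTML_TEMPLATE = Template("""<div id="graph"></div>
<script>
  const graph = $graph_json;
  // ... layout and rendering code would read `graph` here ...
</script>""")

def build_html(adj_list, hierarchy_map, edge_labels):
    """Serialize the traced graph and inject it into the HTML/JS template."""
    graph = {
        "adj_list": adj_list,
        "hierarchy_map": hierarchy_map,
        # JSON object keys must be strings, so (src, dst) tuples become triples
        "edge_labels": [[src, dst, label] for (src, dst), label in edge_labels.items()],
    }
    payload = json.dumps(graph)
    # Escape "</" so that a node name can never close the <script> tag early
    payload = payload.replace("</", "<\\/")
    return HTML_TEMPLATE.substitute(graph_json=payload)

html = build_html(
    {"input_0": ["linear_1"], "linear_1": ["output_0"]},
    {"linear_1": "Linear_1"},
    {("input_0", "linear_1"): "(1, 16)", ("linear_1", "output_0"): "(1, 32)"},
)
print(html)
```

In a notebook, `from IPython.display import HTML, display` followed by `display(HTML(html))` would render such a string. That is the role the post assigns to IPython.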

Tracing

Tracing is essentially done by (temporarily) wrapping all the important and known tensor operations, and standard Pytorch modules. The goal of wrapping is to modify the functions so that when called, they additionally do the bookkeeping necessary for tracing.

Structure of the graph

The graph we extract from the model is a directed graph where:

  • The nodes are the various tensor operations and built-in Pytorch modules that get called during the forward pass
    • Additionally, input and output tensors, and constant valued tensors are also nodes in the graph.
  • An edge exists from one node to the other for each tensor sent from the former to the latter.
  • The edge label is the shape (dimensions) of the associated tensor.
Example graph with operations and input/output/constant tensors as nodes, and an edge for every tensor that’s sent, with edge label set as the dimensions of the tensor

However, the structure of our graph is more complicated than that, because most Pytorch modules call tensor operations and sometimes other modules’ forward methods. This means we have to maintain a graph structure that holds enough information to explore it visually at any level of depth.

An example of nested modules shown at various depths: TransformerEncoder uses TransformerEncoderLayer, which calls multi_head_attention_forward, dropout, and other operations.

Therefore, the structure that torchvista extracts includes two main data structures:

  • Adjacency list of the lowest level operations/modules that get called.
input_0 -> [ linear ]
linear -> [ __add__ ]
__getitem__ -> [ __add__ ]
__add__ -> [ multi_head_attention_forward ]
multi_head_attention_forward -> [ dropout ]
dropout -> [ __add__ ]
  • Hierarchy map that maps each node to its parent module container (if present)
linear -> Linear
multi_head_attention_forward -> MultiheadAttention
MultiheadAttention -> TransformerEncoderLayer
TransformerEncoderLayer -> TransformerEncoder

With both of these, we can construct any desired view of the forward pass in the visualization layer.
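The following sketch shows how a view at a chosen depth could be derived from the two structures. The function names and the collapsing rule are assumptions for illustration, not torchvista's implementation. Under that rule, every node inside a collapsed module is replaced by its outermost collapsed ancestor, and edges that end up inside the same box are dropped. The sketch uses unique per-call node names (`linear_1`, `dropout_1`), so that two calls to the same operation stay distinct.

```python
def ancestors(node, hierarchy_map):
    """Return the chain of parent containers of `node`, innermost first."""
    chain = []
    while node in hierarchy_map:
        node = hierarchy_map[node]
        chain.append(node)
    return chain

def collapsed_view(adj_list, hierarchy_map, collapsed):
    """Build an adjacency list in which every module in `collapsed` is one node."""
    def representative(node):
        # The outermost collapsed container wins, since it hides everything inside
        rep = node
        for parent in ancestors(node, hierarchy_map):
            if parent in collapsed:
                rep = parent
        return rep

    view = {}
    for src, dsts in adj_list.items():
        for dst in dsts:
            a, b = representative(src), representative(dst)
            if a != b:  # edges internal to a collapsed module disappear
                view.setdefault(a, [])
                if b not in view[a]:
                    view[a].append(b)
    return view

adj_list = {
    "input_0": ["linear_1"],
    "linear_1": ["multi_head_attention_forward_1"],
    "multi_head_attention_forward_1": ["dropout_1"],
    "dropout_1": ["output_0"],
}
hierarchy_map = {
    "linear_1": "Linear_1",
    "multi_head_attention_forward_1": "MultiheadAttention_1",
    "Linear_1": "TransformerEncoderLayer_1",
    "MultiheadAttention_1": "TransformerEncoderLayer_1",
    "dropout_1": "TransformerEncoderLayer_1",
}
print(collapsed_view(adj_list, hierarchy_map, {"TransformerEncoderLayer_1"}))
```

Collapsing `TransformerEncoderLayer_1` reduces the graph to `input_0 -> TransformerEncoderLayer_1 -> output_0`. Collapsing only `MultiheadAttention_1` keeps the other operations visible. Choosing the outermost collapsed ancestor means that collapsing a parent hides its children regardless of their own expanded state.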

Wrapping operations and modules

The idea behind wrapping is to replace each operation with a function that does some bookkeeping before and after calling the actual operation. Whenever the model calls the operation, our wrapper runs instead and carries out the bookkeeping. The goals of bookkeeping are:

  • Record connections between nodes based on tensor references.
  • Record tensor dimensions to show as edge labels.
  • Record the module hierarchy when modules are nested within one another.

Here is a simplified code snippet of how wrapping works:

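import functools

original_operations = {}

def wrap_operation(module, func_name):
  operation = getattr(module, func_name)
  # Keep the original so that it can be restored after tracing
  original_operations[get_hashable_key(module, func_name)] = operation

  @functools.wraps(operation)
  def wrapped_operation(*args, **kwargs):
    # Do the necessary pre-call bookkeeping
    do_pre_call_bookkeeping(module, func_name, args)

    # Call the original operation
    result = operation(*args, **kwargs)

    # Do the necessary post-call bookkeeping
    do_post_call_bookkeeping(module, func_name, result)

    return result

  setattr(module, func_name, wrapped_operation)

# Pairs such as (torch, "cat") or (torch.nn.functional, "linear")
for module, func_name in LONG_LIST_OF_PYTORCH_OPS:
  wrap_operation(module, func_name)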

When trace_model is about to finish, we must restore everything to its original state:

for module, func_name in LONG_LIST_OF_PYTORCH_OPS:
  setattr(module, func_name,
          original_operations[get_hashable_key(module, func_name)])

This is done in the same way for the forward() methods of built-in Pytorch modules like Linear, Conv2d, etc.
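Python functions and methods are ordinary attributes, so the wrap-and-restore cycle can be demonstrated without Pytorch at all. In this minimal sketch, a `types.SimpleNamespace` stands in for a namespace such as `torch.nn.functional`, and a tiny `Linear` class stands in for a built-in module with a `forward()` method. The `Tracer` class and its call log are hypothetical stand-ins for torchvista's bookkeeping. Using a context manager means the originals are restored in `__exit__` even if the traced code raises.

```python
import functools
import types

class Tracer:
    """Wraps (owner, attribute_name) targets on entry and restores them on exit."""

    def __init__(self, targets):
        self.targets = targets
        self.originals = {}
        self.calls = []  # stand-in for the real bookkeeping

    def _wrap(self, owner, name):
        original = getattr(owner, name)
        # Keyed by id() so that owners do not need to be hashable
        self.originals[(id(owner), name)] = (owner, original)

        @functools.wraps(original)
        def wrapped(*args, **kwargs):
            self.calls.append(f"enter {name}")  # pre-call bookkeeping
            result = original(*args, **kwargs)
            self.calls.append(f"exit {name}")   # post-call bookkeeping
            return result

        setattr(owner, name, wrapped)

    def __enter__(self):
        for owner, name in self.targets:
            self._wrap(owner, name)
        return self

    def __exit__(self, exc_type, exc, tb):
        # Runs even if the traced code raised, so nothing stays wrapped
        for (_, name), (owner, original) in self.originals.items():
            setattr(owner, name, original)
        return False  # do not swallow the exception


def relu(x):
    return max(x, 0)

F = types.SimpleNamespace(relu=relu)  # stands in for torch.nn.functional

class Linear:  # stands in for a built-in module with a forward() method
    def __init__(self, w):
        self.w = w

    def forward(self, x):
        return F.relu(self.w * x)

layer = Linear(-2)
original_forward = Linear.forward
with Tracer([(Linear, "forward"), (F, "relu")]) as tracer:
    print(layer.forward(3))  # 0
print(tracer.calls)  # ['enter forward', 'enter relu', 'exit relu', 'exit forward']
```

Wrapping `forward` on the class rather than on an instance means every instance is traced. Because the wrapper is a plain function stored on the class, it still receives `self` as its first argument.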

Connections between nodes

As stated previously, an edge exists between two nodes if a tensor was sent from one to the other. This forms the basis of creating connections between nodes while building the graph.

Here is a simplified code snippet of how this works:

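from collections import defaultdict

adj_list = defaultdict(list)

def do_post_call_bookkeeping(module, operation, tensor_output):
  # Set a "marker" on the output tensor so that whoever consumes it
  # knows which operation produced it
  tensor_output._source_node = get_hashable_key(module, operation)

def do_pre_call_bookkeeping(module, operation, inputs):
  for tensor_input in inputs:
    # Model inputs are marked as input nodes before the forward pass;
    # non-tensor arguments (ints, dims, ...) carry no marker
    source_node = getattr(tensor_input, "_source_node", None)
    if source_node is None:
      continue
    # Add a link from the producer of the tensor to this node (the consumer)
    adj_list[source_node].append(get_hashable_key(module, operation))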
How graph edges are created
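The marker idea can be exercised end to end with a toy tensor class. This sketch rests on assumptions: `FakeTensor` and `traced_op` are hypothetical, and the per-call node names (`linear_1`, `layer_norm_2`, ...) follow the style of the DOT example later in the post. Unique per-call names keep two calls to the same operation from merging into one node. Each producer stamps its output with its node name. Each consumer reads the stamp from its inputs and adds an edge labelled with the tensor's shape.

```python
from collections import defaultdict
from itertools import count

class FakeTensor:
    """Toy stand-in for torch.Tensor: a shape plus the marker attribute."""
    def __init__(self, shape, source_node):
        self.shape = shape
        self._source_node = source_node

adj_list = defaultdict(list)
edge_labels = {}
call_counters = defaultdict(lambda: count(1))  # op name -> 1, 2, 3, ...

def traced_op(name, out_shape):
    """Return a fake operation that records an edge from each tensor input."""
    def op(*inputs):
        node = f"{name}_{next(call_counters[name])}"
        for t in inputs:  # pre-call bookkeeping: producer -> this consumer
            adj_list[t._source_node].append(node)
            edge_labels[(t._source_node, node)] = str(tuple(t.shape))
        return FakeTensor(out_shape, source_node=node)  # post-call: stamp output
    return op

linear = traced_op("linear", (1, 32))
layer_norm = traced_op("layer_norm", (1, 32))
add = traced_op("add", (1, 32))

x = FakeTensor((1, 16), source_node="input_0")  # inputs are marked up front
h = linear(x)
out = add(layer_norm(h), layer_norm(h))
print(dict(adj_list))
```

With real tensors, `tuple(t.shape)` works the same way because `torch.Size` is a tuple subclass.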

Module hierarchy map

When we wrap modules, things have to be done a little differently to build the module hierarchy map. The idea is to maintain a stack of the modules currently being called, so that the top of the stack always represents the immediate parent in the hierarchy map.

Here is a simplified code snippet of how this works:

hierarchy_map = {}
module_call_stack = []

def do_pre_call_bookkeeping_for_module(package, module):
  # Push the module that is about to run onto the stack
  module_call_stack.append(get_hashable_key(package, module))

def do_post_call_bookkeeping_for_module(package, module):
  module_call_stack.pop()
  # Top of the stack (if any) is now the parent of this module
  if module_call_stack:
    hierarchy_map[get_hashable_key(package, module)] = module_call_stack[-1]
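Here is a runnable version of the stack idea. Toy modules use `__call__` in place of a wrapped `forward()`. The `ToyModule` class and the `_1` naming are hypothetical. Besides module-to-module links, the sketch also records the parent of each leaf operation, which is how entries like `linear -> Linear` in the hierarchy map above can arise. The pop happens in `finally`, so the stack stays consistent even if a module raises.

```python
hierarchy_map = {}
module_call_stack = []

def record_op(name):
    """Leaf operation: its parent is whichever module is on top of the stack."""
    if module_call_stack:
        hierarchy_map[name] = module_call_stack[-1]

class ToyModule:
    def __init__(self, name, children=(), ops=()):
        self.name, self.children, self.ops = name, children, ops

    def __call__(self):
        module_call_stack.append(self.name)  # pre-call: push
        try:
            for child in self.children:
                child()
            for op in self.ops:
                record_op(op)
        finally:
            module_call_stack.pop()  # post-call: pop
            if module_call_stack:    # the new top is this module's parent
                hierarchy_map[self.name] = module_call_stack[-1]

encoder = ToyModule("TransformerEncoder_1", children=[
    ToyModule(
        "TransformerEncoderLayer_1",
        children=[
            ToyModule("MultiheadAttention_1", ops=["multi_head_attention_forward_1"]),
            ToyModule("Linear_1", ops=["linear_1"]),
        ],
        ops=["dropout_1"],
    ),
])
encoder()
print(hierarchy_map)
```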

Visualization

This part is handled almost entirely in JavaScript, because the visualization happens in web-based notebooks. The key libraries used here are:

  • graphviz: for generating the layout for the graph (viz-js is the JS port)
  • d3: for drawing the interactive graph on a canvas
  • IPython: to render the HTML contents within a notebook

Graph Layout

Getting the graph layout right is an extremely complex problem. The main goal is for the graph to have a top-to-bottom “flow” of edges and, most importantly, for there to be no overlap between the various nodes, edges, and edge labels.

This is made all the more complex when we are working with a “hierarchical” graph where there are “container” boxes for modules within which the underlying nodes and subcomponents are shown.

A complex layout with a neat top-to-bottom flow and no overlaps

Thankfully, graphviz (through viz-js) comes to the rescue. graphviz uses the “DOT language”, in which we describe the nodes, edges, clusters and their attributes. graphviz then computes the layout from that description.

Here is a sample of the DOT syntax for the above graph:

# Edges and nodes
  "input_0" [width=1.2, height=0.5];
  "output_0" [width=1.2, height=0.5];
  "input_0" -> "linear_1"[label="(1, 16)", fontsize="10", edge_data_id="5623840688" ];
  "linear_1" -> "layer_norm_1"[label="(1, 32)", fontsize="10", edge_data_id="5801314448" ];
  "linear_1" -> "layer_norm_2"[label="(1, 32)", fontsize="10", edge_data_id="5801314448" ];
...

# Module hierarchy specified using clusters
subgraph cluster_FeatureEncoder_1 {
  label="FeatureEncoder_1";
  style=rounded;
  subgraph cluster_MiddleBlock_1 {
    label="MiddleBlock_1";
    style=rounded;
    subgraph cluster_InnerBlock_1 {
      label="InnerBlock_1";
      style=rounded;
      subgraph cluster_LayerNorm_1 {
        label="LayerNorm_1";
        style=rounded;
        "layer_norm_1";
      }
      subgraph cluster_TinyBranch_1 {
        label="TinyBranch_1";
        style=rounded;
        subgraph cluster_MicroBranch_1 {
          label="MicroBranch_1";
          style=rounded;
          subgraph cluster_Linear_2 {
            label="Linear_2";
            style=rounded;
            "linear_2";
          }
...

Once this DOT representation is generated from our adjacency list and hierarchy map, graphviz produces a layout with positions and sizes of all nodes and paths for edges.
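Generating the DOT text from the two structures is mostly string building. The sketch below is an illustration under assumptions, not torchvista's generator, and the name `to_dot` is hypothetical. It emits one `subgraph cluster_...` per module, nests the clusters by following the hierarchy map, and writes each edge with its shape label. Subgraph names must begin with `cluster` for graphviz to draw them as boxes.

```python
from collections import defaultdict

def to_dot(adj_list, hierarchy_map, edge_labels):
    """Turn an adjacency list plus a hierarchy map into DOT text for graphviz."""
    children = defaultdict(list)  # container -> direct members
    for node, parent in hierarchy_map.items():
        children[parent].append(node)

    lines = ["digraph G {", "  rankdir=TB;"]

    def emit(container, depth):
        pad = "  " * depth
        lines.append(f'{pad}subgraph "cluster_{container}" {{')
        lines.append(f'{pad}  label="{container}";')
        lines.append(f"{pad}  style=rounded;")
        for member in children[container]:
            if member in children:  # a nested module: recurse
                emit(member, depth + 1)
            else:                   # a leaf operation node
                lines.append(f'{pad}  "{member}";')
        lines.append(f"{pad}}}")

    # Start from top-level modules, i.e. containers that have no parent themselves
    for container in list(children):
        if container not in hierarchy_map:
            emit(container, 1)

    for src in sorted(adj_list):
        for dst in adj_list[src]:
            label = edge_labels.get((src, dst), "")
            lines.append(f'  "{src}" -> "{dst}" [label="{label}", fontsize="10"];')
    lines.append("}")
    return "\n".join(lines)

dot = to_dot(
    {"input_0": ["linear_1"], "linear_1": ["layer_norm_1"], "layer_norm_1": ["output_0"]},
    {"linear_1": "Linear_1", "Linear_1": "Block_1",
     "layer_norm_1": "LayerNorm_1", "LayerNorm_1": "Block_1"},
    {("input_0", "linear_1"): "(1, 16)", ("linear_1", "layer_norm_1"): "(1, 32)",
     ("layer_norm_1", "output_0"): "(1, 32)"},
)
print(dot)
```

Input and output nodes that belong to no module are declared implicitly by their edges, so they sit outside every cluster. The resulting string is what would be handed to graphviz (viz-js in the browser) for layout.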

Rendering

Once the layout is generated, d3 is used to render the graph visually. Everything is drawn on a canvas (which is easy to make draggable and zoomable), and we set various event handlers to detect user clicks.

When the user expands or collapses a module (using the ‘+’ and ‘-’ buttons), torchvista records which node the action was performed on and re-renders the graph, because the layout has to be recomputed. It then automatically drags and zooms to an appropriate level based on the recorded pre-click position.

Rendering a graph with d3 is a detailed topic and not unique to torchvista, so I leave the details out of this post.

[Bonus] Handling errors in Pytorch models

When users trace their Pytorch models (especially while still developing them), the models sometimes throw errors. It would have been easy for torchvista to simply give up and make the user fix the error before visualizing anything. Instead, torchvista helps with debugging by doing best-effort tracing of the model. The idea is simple: trace as much as possible until the error happens, and render the partial graph with visual indicators showing where the error happened. Then raise the exception, so that the user also sees the stack trace as they normally would.

When an error is thrown, the stack trace is also shown below the partially rendered graph

Here is a simplified code snippet of how this works:

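def trace_model(model, inputs):
  exception = None
  try:
    # All the tracing code
    ...
  except Exception as e:
    # Remember the error but keep the partially built graph
    exception = e
  finally:
    # Do all the necessary cleanups (unwrapping all the operations and modules)
    ...
  # Render whatever graph was built, marking where the error happened
  ...
  if exception is not None:
    # Re-raise so the user sees the usual stack trace
    raise exception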
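The same control flow can be tried in isolation. In this hypothetical sketch, `steps` stands in for the operations of a forward pass and `render_partial` for the visualization. Every step that completed is kept, and the failing step is flagged. Cleanup runs in `finally`, and the original exception object is re-raised, so its traceback is preserved.

```python
def render_partial(completed, failed_at):
    """Stand-in for drawing the partial graph with an error marker."""
    marker = f" -> [ERROR at {failed_at}]" if failed_at else ""
    return " -> ".join(completed) + marker

def best_effort_trace(steps, log):
    completed, failed_at, exception = [], None, None
    try:
        for name, fn in steps:
            failed_at = name  # remembered in case this step raises
            fn()
            completed.append(name)
        failed_at = None
    except Exception as e:
        exception = e
    finally:
        log.append("cleanup: unwrapped all operations")  # always runs
    log.append(render_partial(completed, failed_at))
    if exception is not None:
        raise exception

def bad_cat():
    raise RuntimeError("shape mismatch")

log = []
try:
    best_effort_trace([("linear", lambda: None), ("relu", lambda: None), ("cat", bad_cat)], log)
except RuntimeError as e:
    print("re-raised:", e)
print(log)  # ['cleanup: unwrapped all operations', 'linear -> relu -> [ERROR at cat]']
```

Rendering after the `finally` block, rather than inside `except`, means the same code path draws the graph whether or not an error occurred.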

Wrapping up

This post shed some light on the journey of building a Pytorch visualization package. We first discussed the specific motivation for building such a tool by comparing it with other similar tools. Then we discussed the design and implementation of torchvista in two parts. The first part covered tracing the forward pass of a Pytorch model by (temporarily) wrapping operations and modules. This captures not only the connections between the various operations but also the module hierarchy. The second part went over the visualization layer and the complexities of layout generation, which were solved with the right choice of libraries.

torchvista is open source, and all contributions, including feedback, issues and pull requests, are welcome. I hope torchvista helps people of all levels of expertise in building and visualizing their models (regardless of model size), showcasing their work, and as a tool for educating others about machine learning models.

Future directions

Potential future enhancements to torchvista include:

  • Adding support for “rolling”, where if the same substructure of a model is repeated several times, it is shown just once with a count of how many times it repeats
  • Systematic exploration of state-of-the-art models to ensure all their tensor operations are adequately covered
  • Support for exporting static images of models as png or pdf files
  • Efficiency and speed improvements

All images unless otherwise stated are by the author.
