
Capturing and Deploying PyTorch Models with torch.export

When we think of an AI/ML model project, most of our attention tends to focus on the big problems, such as creating and curating datasets, designing the best ML architecture, acquiring appropriately large GPU clusters for training, and building an inference solution that meets target quality-of-service (QoS) requirements. However, it’s often the small details that become our Achilles’ heel, leading to unanticipated bugs and significant production delays.

One detail that’s often overlooked is the handoff of a trained model to the inference environment. While this handoff may seem trivial, it can easily become the source of a great deal of frustration. The training and inference environments are rarely identical, with differences ranging from runtime libraries to hardware targets. To navigate these differences, the AI/ML model developer must ensure:

  1. that the model definition, along with its trained weights, is loaded properly into the inference environment, and
  2. that the model’s behavior does not change.

This post will focus on the first challenge — reliable restoration of the model definition and state in an inference environment. We will survey some of the legacy options and their shortcomings. We will then introduce the new torch.export API and demonstrate it in action on a toy model built with HuggingFace’s transformers library (version 4.54.0). For our experiments, we will use an Amazon EC2 g5.xlarge instance (containing an NVIDIA A10G GPU and 4 vCPUs) running a PyTorch (2.7) Deep Learning AMI (DLAMI).

One of the technologies underlying torch.export is Torch Dynamo, a key component of PyTorch’s graph compilation solution, torch.compile. In a recent post we demonstrated the power of torch.compile and its importance in optimizing the runtime performance of AI/ML models. In some ways this post can be viewed as a sequel: we will revisit some of the key concepts — particularly graph-breaks, we will reuse the same toy model, and we will demonstrate the use of torch.compile for just-in-time (JIT) compilation in our inference environment. While we recommend reading the previous post, it is not a prerequisite for this one.

Disclaimers:

As of the time of this writing (PyTorch 2.8.0), torch.export is a “prototype” feature. While the behaviors discussed in this post will likely remain largely unchanged, the API definitions could change. Be sure to align your use of torch.export with the latest API version.

This post will cover only a subset of torch.export’s features and behaviors. It should not be viewed as an alternative to the official PyTorch documentation. There are a number of pages covering the export feature, including an introductory tutorial, an overview of the programming model, and a guide to solving export challenges.

While we will demonstrate its functionality on a toy example, torch.export depends heavily on the model’s details and may exhibit very different behavior in your own project.

The code we will share is for demonstration purposes only and should not be relied on for correctness or optimality. Please do not interpret our choice of platform, framework, or any other tool or library as an endorsement for its use.

A Toy HuggingFace Model

To facilitate our discussion, we define a simple image-to-text generative model using the HuggingFace transformers library. For simplicity, we skip the training phase and assume that the default constructor returns a pretrained model.

import torch

NUM_TOKENS = 1024
MAX_SEQ_LEN = 256
PAD_ID = 0
START_ID = 1
END_ID = 2

# Set up an image-to-text model.
def get_model():

    # import transformers utilities
    from transformers import (
        VisionEncoderDecoderModel,
        VisionEncoderDecoderConfig,
        AutoConfig
    )

    config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
        encoder_config=AutoConfig.for_model("vit"),  # vit encoder
        decoder_config=AutoConfig.for_model("gpt2")  # gpt2 decoder
    )
    config.decoder.vocab_size = NUM_TOKENS
    config.decoder.use_cache = False
    config.decoder_start_token_id = START_ID
    config.pad_token_id = PAD_ID
    config.eos_token_id = END_ID
    config.max_length = MAX_SEQ_LEN

    model = VisionEncoderDecoderModel(config=config)
    model.encoder.pooler = None  # remove unused pooler
    model.eval() # prepare the model for evaluation
    return model

We define an auto-regressive image-to-text generator that uses the encoder and decoder components of the model to produce a caption for the input image. For simplicity, we use a basic implementation, leaving out common optimization techniques, such as KV caching.

# generate the next token
def generate_token(decoder, encoder_hidden_states, sequence):
    outputs = decoder(
        sequence,
        encoder_hidden_states
    )
    logits = outputs[0][:, -1, :]
    return torch.argmax(logits, dim=-1, keepdim=True)

# simple auto-regressive sequence generator
def image_to_text_generator(encoder, decoder, image):
    # run encoder
    encoder_hidden_states = encoder(image)[0]

    # initialize sequence
    generated_ids = torch.ones(
        (image.shape[0], 1),
        dtype=torch.long,
        device=image.device
    ) * START_ID

    for _ in range(MAX_SEQ_LEN):
        # generate next token
        next_token = generate_token(
            decoder,
            encoder_hidden_states,
            generated_ids
        )
        generated_ids = torch.cat([generated_ids, next_token], dim=-1)
        if (next_token == END_ID).all():
            break

    return generated_ids

The following code block demonstrates the use of our generator on a batch of random input images.

import os, time, random, torch

torch.manual_seed(42)
random.seed(42)

BATCH_SIZE = 64
EXPORT_PATH = '/tmp/export/'

def test_inference(model_path=EXPORT_PATH, mode=None, compile=False):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rnd_image = torch.randn(BATCH_SIZE, 3, 224, 224).to(device)
    encoder, decoder = load_model(model_path, mode)
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    if compile:
        encoder = torch.compile(encoder, mode="reduce-overhead")
        decoder = torch.compile(decoder, dynamic=True)

    # optionally enable mixed precision
    with torch.amp.autocast(device, dtype=torch.bfloat16, enabled=True):
        with torch.no_grad():
            if compile:
                # run a few warmup rounds under the same autocast/no_grad
                # settings as the timed run, so compilation happens here
                for _ in range(10):
                    image_to_text_generator(encoder, decoder, rnd_image)

            t0 = time.perf_counter()
            caption = image_to_text_generator(encoder, decoder, rnd_image)
            if device == 'cuda':
                torch.cuda.synchronize()  # wait for all queued GPU work
            total_time = time.perf_counter() - t0

    print(f'batched inference total time: {total_time}')

The encoder and decoder models are loaded via a load_model utility. We define an initial implementation and amend it later based on our choice of model capturing strategy.

To prepare our model for the export process, we define a pass-through wrapper class for the decoder. This wrapper ensures the model can be traced using positional (rather than keyword) arguments, which is a current requirement of torch.export.

class DecoderWrapper(torch.nn.Module):
    def __init__(self, decoder_model):
        super().__init__()
        self.decoder = decoder_model

    def forward(self, input_ids, encoder_hidden_states):
        return self.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=False
        )

def load_model(path=EXPORT_PATH, mode=None):
    model = get_model()
    encoder = model.encoder
    decoder = model.decoder
    return encoder, DecoderWrapper(decoder)

Now that we have defined our toy model, let’s explore different strategies for capturing and deploying it to an inference environment.

Model Capturing and Deployment Strategies

In this section, we’ll review two common methods for capturing and deploying an AI/ML model’s state: weights-only capture and conversion to a serializable intermediate graph representation using TorchScript.

Weights Only Capture

The first option is to use torch.save to capture only the model weights, not the model definition. This requires you to explicitly redefine the PyTorch model in the inference environment. There are several ways to carry over the code definition, including copy-pasting (which is very error-prone), pulling from a shared code repository, or using Python packages. If the model definition relies on specific Python package dependencies, you will need to make sure that those packages exist in the inference environment and that their versions match those in the training environment.

The following code block demonstrates capturing and loading the model weights for our toy model:

def capture_model(model, path=EXPORT_PATH):
    # weights only
    os.makedirs(path, exist_ok=True)  # make sure the export directory exists
    weights_path = os.path.join(path, "weights.pth")
    torch.save(model.state_dict(), weights_path)

def load_model(path=EXPORT_PATH, mode=None):
    if mode == 'weights':
        model = get_model()
        weights_path = os.path.join(path,"weights.pth")
        state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model.encoder, DecoderWrapper(model.decoder)
    else:
        model = get_model()
        return model.encoder, DecoderWrapper(model.decoder)

One advantage of this method is that it provides maximum flexibility for tuning the model’s configuration to the inference environment. For example, you can apply machine-specific optimizations that increase the throughput of the inference workload. This freedom is especially important for very large models that require advanced sharding techniques for model execution.

However, this method assumes you can easily design and configure the inference environment as desired — which cannot be taken for granted. Some inference environments are extremely constrained — with limited control over the runtime libraries and package installations.

The separation between the model definition and the model weights also can be the source of all sorts of ugly bugs. Ensuring appropriate alignment between the source code and the model weights requires disciplined version management. A preferred approach would be to bundle the model definition and weights into a single archive.
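The following minimal sketch illustrates the alignment problem. It uses a small hypothetical TinyNet module as a stand-in for the real model definition, and an in-memory buffer in place of a weights file (both are assumptions for illustration only). Loading the saved state_dict succeeds when the inference-side definition matches the training-side one, and fails when the definition has drifted.

```python
import io
import torch

class TinyNet(torch.nn.Module):
    # hypothetical stand-in for the training-time model definition
    def __init__(self, hidden=16):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, hidden)
        self.fc2 = torch.nn.Linear(hidden, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# capture only the weights (into an in-memory buffer instead of a file)
buffer = io.BytesIO()
torch.save(TinyNet(hidden=16).state_dict(), buffer)

def restore(hidden):
    # the inference side must rebuild the model from source code
    buffer.seek(0)
    state_dict = torch.load(buffer, map_location="cpu")
    model = TinyNet(hidden=hidden)
    model.load_state_dict(state_dict)  # raises on missing keys or shape mismatches
    return model

ok = restore(hidden=16)  # definitions match: loads cleanly

try:
    restore(hidden=32)  # definition drifted: parameter shapes no longer match
    mismatch_detected = False
except RuntimeError as err:
    mismatch_detected = True
    print("load failed:", str(err).splitlines()[0])
```

Note that load_state_dict only catches mismatched parameter names and shapes. A drifted forward function (say, a changed activation) would load without complaint and silently change the model’s behavior.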

TorchScript Variants

For many years, the primary method for capturing a PyTorch model along with its weights was TorchScript. And despite its recent deprecation, it remains widely popular. TorchScript encapsulates two different capturing solutions, torch.jit.script and torch.jit.trace, both of which convert PyTorch models into serializable graph representations. This graph can be loaded into another environment and run as a standalone PyTorch program, with minimal runtime dependencies.
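As a small sketch of this standalone property, the example below scripts a trivial hypothetical Doubler module, serializes it to an in-memory buffer with torch.jit.save, deletes the Python class, and loads the program back with torch.jit.load. The loaded module still runs even though its source definition is no longer available.

```python
import io
import torch

class Doubler(torch.nn.Module):
    # hypothetical module used only for this illustration
    def forward(self, x):
        return x * 2

buffer = io.BytesIO()
torch.jit.save(torch.jit.script(Doubler()), buffer)

del Doubler  # the Python definition is gone from this process
buffer.seek(0)
restored = torch.jit.load(buffer)  # rebuilt from the serialized graph alone

print(restored(torch.ones(2)))
```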

The scripting and tracing functionalities are complementary. Scripting performs ahead-of-time (AOT) static analysis of the source code, whereas tracing records the actual execution of the model on a sample input. Scripting can capture more complexity in the graph, such as conditional control flow and dynamic shapes; tracing, in contrast, captures only the execution path and tensor shapes dictated by the sample input it runs on. On the other hand, torch.jit.trace supports more operations than torch.jit.script (e.g., see here). Often, scripting will succeed where tracing fails, and vice versa. Sometimes, a combination of both methods is required.
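To make the difference concrete, here is a minimal sketch using a hypothetical ClipOrNegate module whose output depends on the sign of its input’s sum. Tracing on a positive sample records only the branch that sample takes, so the traced module returns the wrong result for a negative input, while the scripted module keeps both branches.

```python
import warnings
import torch

class ClipOrNegate(torch.nn.Module):
    # hypothetical module with input-dependent control flow
    def forward(self, x):
        if x.sum() > 0:
            return x.clamp(max=1.0)
        return -x

module = ClipOrNegate()
positive = torch.ones(3)
negative = -torch.ones(3)

with warnings.catch_warnings():
    # tracing warns that converting a tensor to a bool may not generalize
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(module, positive)  # records only the "sum > 0" path

scripted = torch.jit.script(module)  # compiles both branches from the source

print(module(negative))    # eager: takes the negate branch
print(traced(negative))    # traced: still applies clamp, the baked-in branch
print(scripted(negative))  # scripted: takes the negate branch
```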

Let’s now attempt to convert our toy model to TorchScript. Since the sequence generator calls the encoder once and the decoder iteratively, we’ll capture them as separate graphs. The encoder takes fixed-shape inputs and has no input-dependent conditional logic, so we can apply the more flexible tracing option. However, the sequence that we input to the decoder grows by one token at each generation step, so we have no choice but to use the scripting option.

HuggingFace supports TorchScript via a dedicated configuration. See the HuggingFace documentation for more details.

config.decoder.torchscript = True
config.encoder.torchscript = True

In the code block below, we extend our capture and loading utilities with TorchScript support. Note the inclusion of the torch.jit.freeze optimization during capture and the use of the torch.jit.optimize_for_inference optimization on the target device.

def capture_model(model, path=EXPORT_PATH):
    # weights only
    os.makedirs(path, exist_ok=True)  # make sure the export directory exists
    weights_path = os.path.join(path, "weights.pth")
    torch.save(model.state_dict(), weights_path)

    encoder = model.encoder
    decoder = DecoderWrapper(model.decoder)

    # torchscript encoder using trace
    example = torch.randn(1, 3, 224, 224)
    encoder_jit = torch.jit.trace(encoder, example)
    # optionally apply jit.freeze optimization
    encoder_jit = torch.jit.freeze(encoder_jit)
    encoder_path = os.path.join(path, "encoder.pt")
    torch.jit.save(encoder_jit, encoder_path)

    try:
        # torchscript decoder using scripting
        decoder_jit = torch.jit.script(decoder)
        # optionally apply jit.freeze optimization
        decoder_jit = torch.jit.freeze(decoder_jit)
        decoder_path = os.path.join(path, "decoder.pt")
        torch.jit.save(decoder_jit, decoder_path)
    except Exception as e:
        print(f'torch.jit.script(model.decoder) failed\n{e}')

def load_model(path=EXPORT_PATH, mode=None):
    if mode == 'weights':
        model = get_model()
        weights_path = os.path.join(path,"weights.pth")
        state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model.encoder, DecoderWrapper(model.decoder)
    elif mode == 'torchscript':
        encoder_path = os.path.join(path, "encoder.pt")
        decoder_path = os.path.join(path, "decoder.pt")
        encoder = torch.jit.load(encoder_path)
        decoder = torch.jit.load(decoder_path)
        # optionally apply target-device optimization 
        encoder = torch.jit.optimize_for_inference(encoder)
        decoder = torch.jit.optimize_for_inference(decoder)
        return encoder, decoder
    else:
        model = get_model()
        return model.encoder, DecoderWrapper(model.decoder)

Unfortunately, our capture utility fails when trying to script the decoder model. A common problem with TorchScript is that it often fails on complex models. While many issues can be bypassed by using torch.jit.trace, it cannot be applied to models with dynamic shapes or conditional logic, such as our decoder. It is often possible to rewrite the model implementation to be script-compliant, but it can require a lot of painstaking (and ugly) work. In the case of our decoder model, it would require a lot of intrusive patchwork to the transformers library source code.

For more on the topic of TorchScript, please see the official documentation.

Model Capturing with torch.export

The new way of capturing a model for deployment is torch.export. Similar to torch.jit.trace, export works by tracing the model’s execution on sample inputs. However, unlike torch.jit.trace, export includes support for dynamic shapes and conditional control flow. The output of the export utility is an intermediate graph representation called Export IR, which can be loaded and executed in a clean inference environment. One advantage of exported models is that, unlike TorchScript models, they can be optimized for the inference environment using torch.compile. On the other hand, optimizations that require source code changes (e.g., configuring the attn_implementation) cannot be applied.

Overcoming Graph Breaks

A graph break occurs when the export function encounters an “untraceable” portion of Python code (e.g., see here for unsupported operations). We encountered the concept of graph breaks in our previous post on graph compilation. However, unlike torch.compile, which simply falls back to eager mode for the untraceable code, torch.export forbids graph breaks. If export fails on your model due to a graph break, you will have to rewrite your code to bypass it.
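The contrast can be sketched with a hypothetical DataDependentIf module that branches in Python on a tensor’s value. Under torch.compile, the branch causes a graph break and that part runs eagerly; torch.export, which has no such fallback, should raise an error instead. The exact exception type and message vary across PyTorch versions, so the sketch catches Exception broadly.

```python
import torch

class DataDependentIf(torch.nn.Module):
    # hypothetical module with a Python `if` on a tensor *value*
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x - 1

model = DataDependentIf()
x = torch.ones(3)

# torch.compile breaks the graph at the `if` and runs that part eagerly;
# backend="eager" is used only so the sketch needs no compiler toolchain
compiled = torch.compile(model, backend="eager")
print(compiled(x))

# torch.export has no eager fallback, so the same code is rejected
try:
    torch.export.export(model, (x,))
    export_failed = False
except Exception as err:  # exception type differs between versions
    export_failed = True
    print(type(err).__name__)
```

A common remedy is to express such a branch with torch.cond, or to move the data-dependent logic outside the exported module.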

There are several resources at your disposal for overcoming graph breaks, including the Draft Export utility, which generates a detailed report of export issues, ExportDB, which maintains a list of supported and unsupported export cases, and a tutorial on overcoming common export issues.

Debugging an Exported Graph

On some occasions, you may succeed in exporting a graph only to find that running it on the target device either fails with an error or returns incorrect output. A common cause of such errors is that some of the variables from the export environment are treated as constants and baked into the exported graph.

Unfortunately, as of the time of this writing, the tools for debugging Export IR graphs are somewhat limited. Although the exported model is a torch.nn.Module with a forward function, you cannot use a debugger to step into it to find the source of the errors.

We can, however, inspect the contents of the generated graph using GraphModule.print_readable. This will print out all of the graph operations along with comments that point to the source code from which they were generated. Sometimes, this, combined with the output error information, is enough to find the source of the errors and tweak the source code accordingly. See below for an example.

Exporting a HuggingFace Model

To apply torch.export to our toy model, we first update the PyTorch library to the latest version (2.8.0 at the time of this writing). The export utility is under rapid development, and we want the most up-to-date feature support.

In the code block below we revise our capture and load utility functions to support torch.export. Note our use of torch.export.Dim to specify dynamic dimensions:

def capture_model(model, path=EXPORT_PATH):
    os.makedirs(path, exist_ok=True)  # make sure the export directory exists
    encoder = model.encoder
    decoder = DecoderWrapper(model.decoder)

    # define dynamic dimensions
    batch = torch.export.Dim("batch")
    seq_len = torch.export.Dim("seq_len", min=2, max=MAX_SEQ_LEN)

    # export encoder
    # sample input
    example = torch.randn(4, 3, 224, 224)
    encoder_export = torch.export.export(
        encoder,
        (example,),
        dynamic_shapes=((batch,
                         torch.export.Dim.STATIC,
                         torch.export.Dim.STATIC,
                         torch.export.Dim.STATIC),
                        )
    )
    torch.export.save(encoder_export, os.path.join(path, "encoder.pt2"))


    # export decoder
    # get sample input for decoder
    encoder_hidden_states = encoder_export.module()(example)[0]
    decoder_input_ids = torch.ones((4, MAX_SEQ_LEN),
                                   dtype=torch.long)*START_ID


    decoder_export = torch.export.export(
        decoder,
        (decoder_input_ids, encoder_hidden_states),
        dynamic_shapes={
                  'input_ids': (batch,seq_len),
                  'encoder_hidden_states': (batch,
                                            torch.export.Dim.STATIC,
                                            torch.export.Dim.STATIC)
                       }
    )
    torch.export.save(decoder_export, os.path.join(path, "decoder.pt2"))


def load_model(path=EXPORT_PATH, mode=None):
    if mode == 'weights':
        model = get_model()
        weights_path = os.path.join(path,"weights.pth")
        state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model.encoder, DecoderWrapper(model.decoder)
    elif mode == 'export':
        encoder_path = os.path.join(path, "encoder.pt2")
        decoder_path = os.path.join(path, "decoder.pt2")
        encoder = torch.export.load(encoder_path).module()
        decoder = torch.export.load(decoder_path).module()
        return encoder, decoder
    else:
        model = get_model()
        return model.encoder, DecoderWrapper(model.decoder)

Unlike TorchScript, torch.export had no issues capturing our encoder and decoder models. In particular, no graph breaks were encountered during tracing.

Deploying an Exported Model

To complete our demonstration, we’ll test our exported model in a clean inference environment. For this experiment, we’ll use an Amazon EC2 g5.xlarge instance (containing an NVIDIA A10G GPU and 4 vCPUs) running a PyTorch (2.7) Deep Learning AMI (DLAMI). We’ll update to the latest PyTorch version but intentionally won’t install the transformers package.

Unfortunately, our excitement at the success of exporting our model was premature, as running the exported decoder on the GPU results in a runtime error, a portion of which we’ve pasted below:

File ".24", line 306, in forward
File "/opt/pytorch/lib/python3.12/site-packages/torch/_ops.py", line 829, in __call__
  return self._op(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but got index is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA__index_select)

The error indicates that, during the execution of the exported graph, some of the tensors reside on the CPU when they are expected to be on the GPU. Since we explicitly copy the input tensors onto the GPU, we can conclude that these are tensors that the graph creates internally. Since the graph was exported on a CPU device, the use of device="cpu" was baked into the graph, resulting in a runtime error when running on a GPU.

Although the error message points to the faulty line of code (File “.24”, line 306, in forward), there is no actual file that we can add a breakpoint to and debug. We can, however, inspect the contents of the decoder graph and search for places where the use of the CPU device has been inadvertently baked in:

decoder_export.module().print_readable()

Analyzing the output (searching for “cpu” references) and cross-referencing the source code using the embedded comments, we discover four locations where the transformers library (modeling_gpt2.py file) creates tensors on the CPU:

  1. If not specified, the GPT2 model creates a cache_position tensor on the baked-in CPU device using torch.arange. This can be fixed by passing in a user-defined value for cache_position.
  2. On line 861, a torch.Tensor.to operation is performed to ensure proper placement of the position_embeds tensor. While this may be required in the case of model parallelism, we do not need it.
  3. If not specified, the model creates a causal mask using tensors that it creates on the baked-in CPU device. We could bypass this by passing in a user-defined causal mask, but we are perfectly happy keeping this None and relying on the is_causal flag of the sdpa attention function.
  4. If not specified, the model creates an encoder_attention_mask tensor on the baked-in CPU device. Once again, we could specify a value for this tensor, but since the mask would be all True, leaving it None achieves the same purpose.

The following patch summarizes the changes we performed on the modeling_gpt2.py file. Importantly, these changes were specific to our toy model and will not generalize to all use cases. This kind of monkey-patching is ill-advised and requires extensive testing before being used in a production setting.

@@ -861 +861 @@
-        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
+        hidden_states = inputs_embeds + position_embeds
@@ -867,3 +867 @@
-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
+        causal_mask = None
@@ -877 +875 @@
-            if encoder_attention_mask is None:
+            if not _use_sdpa and encoder_attention_mask is None:
@@ -880,3 +878 @@
-                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                    mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
-                )
+                pass

In the following code block, we extend the decoder definitions and export implementation with an explicit value for cache_position.

def generate_token(decoder, encoder_hidden_states, sequence):
    outputs = decoder(
        sequence,
        encoder_hidden_states,
        torch.arange(sequence.shape[1], device=sequence.device)
    )
    logits = outputs[0][:, -1, :]
    return torch.argmax(logits, dim=-1, keepdim=True)

class DecoderWrapper(torch.nn.Module):
    def __init__(self, decoder_model):
        super().__init__()
        self.decoder = decoder_model

    def forward(self, input_ids, encoder_hidden_states, cache_position):
        return self.decoder(
            input_ids=input_ids,
            cache_position=cache_position,
            encoder_hidden_states=encoder_hidden_states,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=False
        )

def capture_model(model, path=EXPORT_PATH):
    os.makedirs(path, exist_ok=True)  # make sure the export directory exists
    encoder = model.encoder
    decoder = DecoderWrapper(model.decoder)

    # define dynamic dimensions
    batch = torch.export.Dim("batch")
    seq_len = torch.export.Dim("seq_len", min=2, max=MAX_SEQ_LEN)

    # export encoder
    # sample tensor
    example = torch.randn(4, 3, 224, 224)
    encoder_export = torch.export.export(
        encoder,
        (example,),
        dynamic_shapes=((batch,
                         torch.export.Dim.STATIC,
                         torch.export.Dim.STATIC,
                         torch.export.Dim.STATIC),
                        )
    )
    torch.export.save(encoder_export, os.path.join(path, "encoder.pt2"))

    # export decoder
    # get sample input for decoder
    encoder_hidden_states = encoder_export.module()(example)[0]
    decoder_input_ids = torch.ones((4, MAX_SEQ_LEN), 
                                    dtype=torch.long)*START_ID

    decoder_args = (
        decoder_input_ids,
        encoder_hidden_states,
        torch.arange(MAX_SEQ_LEN)
    )

    dynamic_shapes = {
        'input_ids': (batch,seq_len),
        'encoder_hidden_states': (batch,
                                  torch.export.Dim.STATIC,
                                  torch.export.Dim.STATIC),
        'cache_position': (seq_len,),
    }

    decoder_export = torch.export.export(
        decoder,
        decoder_args,
        dynamic_shapes=dynamic_shapes
    )
    torch.export.save(decoder_export, os.path.join(path, "decoder.pt2"))

Following these changes, the exported decoder succeeds in generating sequences on the GPU device. We hope (and expect) that as the torch.export feature evolves, these kinds of issues will be handled automatically by the internal tracing mechanism.

We further test the potential for machine-specific optimizations by applying graph compilation. For details on our choice of compilation parameters, see our previous post.

encoder = torch.compile(encoder, mode="reduce-overhead")
decoder = torch.compile(decoder, dynamic=True)

The table below captures the execution time of our model on a batch of random images, with and without torch.compile. (Note that running the original model requires installation of the transformers library, version 4.54.0.)

Runtime Results (by Author)

We can see that exporting the model to a graph representation results in a speed-up of 10.7%. Model compilation, however, had the opposite effect, significantly increasing the execution time. It is likely that this could be fixed through appropriate tuning.

Summary

In this post we explored the new torch.export utility and demonstrated its use in capturing and deploying a toy HuggingFace model. We found that it has a number of powerful and compelling properties, including:

  • Support for complex models: torch.export succeeded in capturing models that failed with TorchScript.
  • Portability: Exported models can be loaded and executed as standalone programs without specific package dependencies.
  • Machine-specific optimizations: exported models are compatible with graph compilation, enabling the application of machine-specific optimizations.

We also encountered some of torch.export’s limitations:

  • Unintended consequences of graph creation: If we are not careful about how we design our model, values from the export environment can be inadvertently baked into the resulting graph, breaking its compatibility with the inference environment.
  • Limited debugging tools: As of this writing, the tools for debugging the execution of exported graphs are limited.

Although it still requires some refinement, torch.export is already a huge improvement over previous capturing solutions like TorchScript. We look forward to seeing it continue to evolve and improve.

For more details on capturing AI/ML models using PyTorch Export and to keep track of API changes, please see the torch.export documentation. If you’re running inference on an edge device, also see the ExecuTorch solution for model deployment, which is based on torch.export.
