Integrate with PyTorch¶

PyTorch is a popular open source machine learning framework based on the Torch library, used for applications such as computer vision and natural language processing.

PyTorch enables fast, flexible experimentation and efficient production through a user-friendly front-end, distributed training, and an ecosystem of tools and libraries.

Instrument PyTorch with Comet to manage experiments, create dataset versions, and track hyperparameters for faster and easier reproducibility and collaboration.

Note: If you are using PyTorch with TensorBoard, see our TensorBoard Integration.

Note: This integration also supports PyTorch Distributed Data Parallel. See the PyTorch Distributed Data Parallel section below.

Start logging¶

Connect Comet to your existing code by adding in a simple Comet Experiment.

Add the following lines of code to your script or notebook:

import comet_ml
from comet_ml.integration.pytorch import watch
import torch

experiment = comet_ml.start()

# Your code here

# get_model() is a placeholder for however you build your torch.nn.Module
model = get_model()
watch(model)

Note

There are other ways to configure Comet. See more here.

Log automatically¶

After an Experiment has been created, Comet automatically logs the following PyTorch items, by default, with no additional configuration:

  • Model and graph description
  • Training loss

You can easily turn the automatic logging on and off for any or all items. See Configure Comet for PyTorch for more details.

Note

Don't see what you need to log here? We have your back. You can manually log any kind of data to Comet using the Experiment object. For example, use experiment.log_image to log images, or experiment.log_audio to log audio.
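As a rough illustration of manual logging, the sketch below wraps experiment.log_image in a small helper that logs a batch of images under predictable names. The helper name log_image_batch and its naming scheme are hypothetical and not part of Comet; only the log_image method and its name and step arguments come from the Experiment object.

```python
def log_image_batch(experiment, images, step, prefix="sample"):
    """Log each image in `images` to Comet under a predictable name.

    `experiment` is any object exposing Comet's `log_image` method,
    for example the value returned by `comet_ml.start()`.
    Returns the list of names that were used.
    """
    names = []
    for index, image in enumerate(images):
        # Hypothetical naming scheme: <prefix>_<step>_<position in batch>
        name = f"{prefix}_{step}_{index}"
        experiment.log_image(image, name=name, step=step)
        names.append(name)
    return names
```

Inside a training loop you might call log_image_batch(experiment, images[:4], step=step) every few hundred steps to keep the number of uploaded images small.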

End-to-end example¶

The following is a basic example of using Comet with PyTorch.

If you can't wait, check out the results of this example PyTorch experiment for a preview of what's to come.

Install dependencies¶

python -m pip install "comet_ml>=3.44.0" torch torchvision tqdm

Run the example¶

# coding: utf-8
from comet_ml import start, login
from comet_ml.integration.pytorch import log_model, watch

import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

hyper_params = {
    "sequence_length": 28,
    "input_size": 28,
    "hidden_size": 128,
    "num_layers": 2,
    "num_classes": 10,
    "batch_size": 100,
    "num_epochs": 5,
    "learning_rate": 0.01,
}


# Login to Comet if needed
login()

experiment = start(project_name="comet-example-pytorch-doc")
experiment.log_parameters(hyper_params)


# MNIST Dataset
train_dataset = dsets.MNIST(
    root="./data/", train=True, transform=transforms.ToTensor(), download=True
)

test_dataset = dsets.MNIST(root="./data/", train=False, transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=hyper_params["batch_size"], shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=hyper_params["batch_size"], shuffle=False
)


# RNN Model (Many-to-One)
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Set initial states
        h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
        c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))

        # Forward propagate RNN
        out, _ = self.lstm(x, (h0, c0))

        # Decode hidden state of last time step
        out = self.fc(out[:, -1, :])
        return out


rnn = RNN(
    hyper_params["input_size"],
    hyper_params["hidden_size"],
    hyper_params["num_layers"],
    hyper_params["num_classes"],
)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=hyper_params["learning_rate"])

# Train the Model
total_steps = len(train_dataset) // hyper_params["batch_size"]

with experiment.train():
    watch(rnn)

    step = 0
    for epoch in range(hyper_params["num_epochs"]):
        experiment.log_current_epoch(epoch)
        correct = 0
        total = 0
        for i, (images, labels) in enumerate(train_loader):
            images = Variable(
                images.view(
                    -1, hyper_params["sequence_length"], hyper_params["input_size"]
                )
            )
            labels = Variable(labels)

            # Forward + Backward + Optimize
            optimizer.zero_grad()
            outputs = rnn(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Compute train accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += float((predicted == labels.data).sum())

            # Log the running training accuracy (in percent) to Comet
            experiment.log_metric("accuracy", 100 * correct / total, step=step)
            step += 1

            if (i + 1) % 100 == 0:
                print(
                    "Epoch [%d/%d], Step [%d/%d], Loss: %.4f"
                    % (
                        epoch + 1,
                        hyper_params["num_epochs"],
                        i + 1,
                        total_steps,
                        loss.item(),
                    )
                )


with experiment.test():
    # Test the Model; gradients are not needed for evaluation
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = Variable(
                images.view(
                    -1, hyper_params["sequence_length"], hyper_params["input_size"]
                )
            )
            outputs = rnn(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += float((predicted == labels).sum())

    # Log as a percentage so it matches the training accuracy metric
    experiment.log_metric("accuracy", 100 * correct / total)
    print(
        "Test Accuracy of the model on the 10000 test images: %d %%"
        % (100 * correct / total)
    )

# Log the model to Comet for easy tracking and deployment
log_model(experiment, rnn, "Pytorch-Mnist")

Weights/Biases and Gradients logging¶

You can log your PyTorch model's weights, biases, and gradients during training with one additional line of code, using watch:

from comet_ml import start, login
from comet_ml.integration.pytorch import watch

import torch.nn as nn

login()
experiment = start()

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        ...

    def forward(self, x):
        ...
        return x

# Initialize model
model = TheModelClass()
watch(model)

This logs the weights, biases, and gradients of all your model's layers regularly during training. They are logged as histograms on the Experiment Histograms tab.

[Figure: Comet Experiment Management, Single Experiment page, Histograms tab]

Check out the reference documentation for more details.

PyTorch model saving and loading¶

Comet provides helpers that let you save your models and load them back.

Saving a model¶

To save a PyTorch model, you can use the comet_ml.integration.pytorch.log_model helper like this:

from comet_ml import start, login
from comet_ml.integration.pytorch import log_model

import torch.nn as nn

login()
experiment = start()

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        ...

    def forward(self, x):
        ...
        return x

# Initialize model
model = TheModelClass()

# Train model (train is a placeholder for your own training loop)
train(model)

# Save the model for inference
log_model(experiment, model, model_name="TheModel")

The model file will be saved as an Experiment Model which is visible in the Experiment assets tab. From there you will be able to register it in the Model Registry.

The previous code snippet is tailored for inference needs. If you want to log a general checkpoint so that you can resume training later, replace the last line of the snippet with:

# Save the model for resuming training
# (epoch, optimizer and loss come from your own training loop)
model_checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
    # ... any other state you want to restore
}
log_model(experiment, model_checkpoint, model_name="TheModel")
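The checkpoint dictionary relies on epoch, optimizer and loss already existing in your training loop. One way to keep that tidy is a small helper that assembles the dictionary; the sketch below is only an illustration, and the name build_checkpoint is hypothetical rather than part of Comet or PyTorch. It assumes nothing beyond the state_dict() method that torch.nn.Module and torch.optim.Optimizer provide.

```python
def build_checkpoint(model, optimizer, epoch, loss):
    """Collect what is needed to resume training into one dictionary.

    `model` and `optimizer` only need a `state_dict()` method, which
    `torch.nn.Module` and `torch.optim.Optimizer` both provide.
    """
    return {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
    }


# Hypothetical usage at the end of an epoch:
# checkpoint = build_checkpoint(model, optimizer, epoch, loss)
# log_model(experiment, checkpoint, model_name="TheModel")
```

Keeping the keys in one place makes it harder for the saving and loading code to drift apart.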

comet_ml.integration.pytorch.log_model uses torch.save under the hood. Consult the official PyTorch documentation for more details and for instructions on more advanced use cases.

Check out the reference documentation for more details.

Loading a model¶

Once you have saved a model using comet_ml.integration.pytorch.log_model, you can load it back with its counterpart comet_ml.integration.pytorch.load_model.

Here is how you can load a model from the Model Registry for Inference:

from comet_ml.integration.pytorch import load_model

import torch.nn as nn

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        ...

    def forward(self, x):
        ...
        return x

# Initialize model
model = TheModelClass()

# Load the model state dict from Comet Registry
model.load_state_dict(load_model("registry://WORKSPACE/TheModel:1.2.4"))

model.eval()

prediction = model(...)

You can load a PyTorch model from various sources:

  • file://data/my-model, load the state_dict from the file path data/my-model (relative path)
  • file:///path/to/my-model, load the state_dict from the file path /path/to/my-model (absolute path)
  • registry://<workspace>/<registry_name>, load the state_dict from the Model Registry identified by the workspace and registry name, using its latest version.
  • registry://<workspace>/<registry_name>:<version>, load the state_dict from the Model Registry identified by the workspace, registry name and explicit version.
  • experiment://<experiment_key>/<model_name>, load the state_dict from an Experiment, identified by the Experiment key and the model_name.
  • experiment://<workspace>/<project_name>/<experiment_name>/<model_name>, load the state_dict from an Experiment, identified by the workspace name, project name, experiment name and the model_name.
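To make the URI forms listed above easier to tell apart, here is a minimal sketch that splits a URI into the pieces the list describes. The function describe_model_uri is hypothetical and is not how Comet parses these strings internally; it only mirrors the formats documented above and can be handy for validating a URI before passing it to load_model.

```python
def describe_model_uri(uri):
    """Break a model URI into the components described in the list above."""
    scheme, separator, rest = uri.partition("://")
    if not separator:
        raise ValueError(f"Missing '://' in model URI: {uri!r}")

    if scheme == "file":
        # file:///abs/path keeps its leading slash; file://rel/path does not
        kind = "absolute" if rest.startswith("/") else "relative"
        return {"source": "file", "path": rest, "kind": kind}

    if scheme == "registry":
        location, _, version = rest.partition(":")
        workspace, _, registry_name = location.partition("/")
        return {
            "source": "registry",
            "workspace": workspace,
            "registry_name": registry_name,
            "version": version or None,  # None stands for "latest version"
        }

    if scheme == "experiment":
        parts = rest.split("/")
        if len(parts) == 2:
            key, model_name = parts
            return {"source": "experiment", "experiment_key": key,
                    "model_name": model_name}
        if len(parts) == 4:
            workspace, project_name, experiment_name, model_name = parts
            return {"source": "experiment", "workspace": workspace,
                    "project_name": project_name,
                    "experiment_name": experiment_name,
                    "model_name": model_name}
        raise ValueError(f"Unexpected experiment URI: {uri!r}")

    raise ValueError(f"Unknown scheme {scheme!r} in model URI: {uri!r}")
```

For example, describe_model_uri("registry://WORKSPACE/TheModel:1.2.4") separates the workspace, the registry name and the explicit version, while the same URI without the ":1.2.4" suffix yields a version of None.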

The previous code snippet is tailored for inference needs. If you want to load a general checkpoint to resume training, replace the loading code at the end of the snippet with:

# Initialize model; `optimizer` must also be created the same way as for training
model = TheModelClass()

# Load the model state dict from a Comet Experiment
checkpoint = load_model("experiment://e1098c4e1e764ff89881b868e4c70f5/TheModel")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.train()
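The restore steps can also be grouped into one function so the training script resumes from a single call. This is a sketch under assumptions: the name restore_checkpoint is hypothetical, the checkpoint is assumed to use the same keys as the snippet above, and it assumes the stored epoch had finished, so training continues from the next one.

```python
def restore_checkpoint(checkpoint, model, optimizer):
    """Load model and optimizer state from a checkpoint dictionary.

    Returns the epoch to resume from and the last recorded loss.
    Assumes the saved epoch was completed before the checkpoint was logged.
    """
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    return start_epoch, checkpoint["loss"]


# Hypothetical usage, with `checkpoint` obtained from load_model(...):
# start_epoch, last_loss = restore_checkpoint(checkpoint, model, optimizer)
# model.train()
```

If your checkpoints are written in the middle of an epoch, return checkpoint["epoch"] unchanged instead of adding one.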

comet_ml.integration.pytorch.load_model uses torch.load under the hood. Consult the official PyTorch documentation for more details and for instructions on more advanced use cases.

Check out the reference documentation for more details.

PyTorch Distributed Data Parallel¶

Are you running distributed training with PyTorch? There is an example for logging PyTorch DDP with Comet in the comet-example repository.

Configure Comet for PyTorch¶

You can control which PyTorch items are logged automatically. Use any of the following methods:

experiment = comet_ml.start(
    log_graph=True, # Can be True or False.
    auto_metric_logging=True # Can be True or False
)

Add or remove these fields from your .comet.config file under the [comet_auto_log] section to enable or disable logging.

[comet_auto_log]
graph=true # can be true or false
metrics=true # can be true or false

Alternatively, set the corresponding environment variables:

export COMET_AUTO_LOG_GRAPH=true # Can be true or false
export COMET_AUTO_LOG_METRICS=true # Can be true or false
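In notebooks, where a shell export is awkward, the same environment variables can be set from Python. The sketch below is a minimal illustration: set_auto_logging is a hypothetical helper, and it rests on the assumption that Comet reads these variables when the Experiment is created, so they need to be set before calling comet_ml.start().

```python
import os


def set_auto_logging(graph, metrics):
    """Set Comet's auto-logging environment variables for this process."""
    os.environ["COMET_AUTO_LOG_GRAPH"] = "true" if graph else "false"
    os.environ["COMET_AUTO_LOG_METRICS"] = "true" if metrics else "false"


# Disable graph logging but keep automatic metric logging
set_auto_logging(graph=False, metrics=True)

# Assumption: create the experiment only after the variables are set, e.g.
# experiment = comet_ml.start()
```

Variables set this way apply only to the current process and any child processes it starts.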

For more information about configuring Comet, see Configure Comet.

Jan. 17, 2025