comet_ml.integration.pytorch ¶

load_model ¶

load_model(
    model_uri: str,
    map_location: Any = None,
    pickle_module: Optional[Module] = None,
    **torch_load_args
) -> ModelStateDict

Loads a model's state_dict from an experiment, the registry, or disk, as identified by a URI. The function returns a PyTorch state_dict that you then need to load into your model. The data is read with torch.load.

Parameters:

  • model_uri (str) –

    Required. A URI string that defines the model location. Possible options are:

    • file://data/my-model
    • file:///path/to/my-model
    • registry://workspace/registry_name (takes the last version)
    • registry://workspace/registry_name:version
    • experiment://experiment_key/model_name
    • experiment://workspace/project_name/experiment_name/model_name
  • map_location (Any, default: None ) –

    Passed to torch.load (see torch.load)

  • pickle_module (Optional[Module], default: None ) –

    Passed to torch.load (see torch.load)

  • torch_load_args –

    Passed to torch.load (see torch.load)
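
To make the URI forms listed for model_uri easier to tell apart, here is a minimal sketch that splits each form into its parts. The helper name describe_model_uri and the returned dictionary keys are hypothetical and chosen for illustration. This is not how comet_ml parses URIs internally; it only mirrors the forms documented above.

```python
def describe_model_uri(model_uri: str) -> dict:
    """Split a model URI into named parts (illustrative helper, not part of comet_ml)."""
    scheme, sep, rest = model_uri.partition("://")
    if not sep:
        raise ValueError(f"Not a URI: {model_uri!r}")

    if scheme == "file":
        # file://data/my-model       -> relative path "data/my-model"
        # file:///path/to/my-model   -> absolute path "/path/to/my-model"
        return {"source": "file", "path": rest}

    if scheme == "registry":
        # registry://workspace/registry_name[:version]
        workspace, _, name_and_version = rest.partition("/")
        name, _, version = name_and_version.partition(":")
        # No version given means "take the last version" per the list above.
        return {
            "source": "registry",
            "workspace": workspace,
            "registry_name": name,
            "version": version or None,
        }

    if scheme == "experiment":
        parts = rest.split("/")
        if len(parts) == 2:
            # experiment://experiment_key/model_name
            key, model_name = parts
            return {
                "source": "experiment",
                "experiment_key": key,
                "model_name": model_name,
            }
        if len(parts) == 4:
            # experiment://workspace/project_name/experiment_name/model_name
            workspace, project, experiment_name, model_name = parts
            return {
                "source": "experiment",
                "workspace": workspace,
                "project_name": project,
                "experiment_name": experiment_name,
                "model_name": model_name,
            }

    raise ValueError(f"Unsupported model URI: {model_uri!r}")
```

For example, describe_model_uri("registry://workspace/registry_name") reports a version of None, which corresponds to the "takes the last version" case.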

Example

Here is an example of loading a model from the Model Registry for inference:

import torch.nn as nn
from comet_ml.integration.pytorch import load_model

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        ...

    def forward(self, x):
        ...
        return x

# Initialize model
model = TheModelClass()

# Load the model state dict from Comet Registry
model.load_state_dict(load_model("registry://WORKSPACE/TheModel:1.2.4"))

model.eval()

prediction = model(...)

Here is an example of loading a checkpoint from an Experiment to resume training:

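import torch
from comet_ml.integration.pytorch import load_model

# Initialize the model and optimizer (TheModelClass is defined in the previous example;
# the optimizer shown is illustrative and should match the one that was checkpointed)
model = TheModelClass()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)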

# Load the model state dict from a Comet Experiment
checkpoint = load_model("experiment://e1098c4e1e764ff89881b868e4c70f5/TheModel")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.train()

log_model ¶

log_model(
    experiment,
    model,
    model_name,
    metadata=None,
    pickle_module=None,
    **torch_save_args
)

Logs a PyTorch model to an experiment. The model is serialized with torch.save and stored as an Experiment Model.

The model parameter can be either an instance of torch.nn.Module or any input supported by torch.save. See the tutorial about saving and loading PyTorch models for more details.

Parameters:

  • experiment (Experiment) –

    The Experiment instance to log the model to

  • model (dict | Module) –

    Model to log

  • model_name (str) –

    The name of the model

  • metadata (dict, default: None ) –

    Additional data to attach to the logged model. Must be a JSON-encodable dict

  • pickle_module –

    Passed to torch.save (see torch.save documentation)

  • torch_save_args –

    Passed to torch.save (see torch.save documentation)
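
Because metadata must be a JSON-encodable dict, it can help to validate it before calling log_model. The sketch below uses only the standard library; check_metadata is a hypothetical helper name, and log_model may perform its own validation that differs from this.

```python
import json


def check_metadata(metadata):
    """Return metadata unchanged if it is None or a JSON-encodable dict, else raise."""
    if metadata is None:
        return None
    if not isinstance(metadata, dict):
        raise TypeError("metadata must be a dict or None")
    try:
        # json.dumps raises TypeError for unsupported values (for example a set)
        # and ValueError for circular references.
        json.dumps(metadata)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"metadata is not JSON-encodable: {exc}") from exc
    return metadata
```

A call such as log_model(experiment, model, model_name="TheModel", metadata=check_metadata({"lr": 0.001})) would then fail early, before anything is serialized, if the dictionary contained a value that JSON cannot represent.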

Example

Here is an example of logging a model for inference:

import torch.nn as nn
from comet_ml.integration.pytorch import log_model

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        ...

    def forward(self, x):
        ...
        return x

# Initialize model
model = TheModelClass()

# Train model (train is your own training function)
train(model)

# Save the model for inference (experiment is an existing Experiment instance)
log_model(experiment, model, model_name="TheModel")

Here is an example of logging a checkpoint for resuming training:

model_checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
    # ... any other state needed to resume training
}
log_model(experiment, model_checkpoint, model_name="TheModel")

watch ¶

watch(model: torch.nn.Module, log_step_interval: int = 1000) -> None

Enables automatic logging of each layer's parameters and gradients in the given PyTorch module. These will be logged as histograms. Note that an Experiment must be created before calling this function.

Parameters:

  • model (Module) –

    An instance of torch.nn.Module.

  • log_step_interval (int, default: 1000 ) –

    Determines how often layers are logged (default is every 1000 steps).
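
The ordering requirement above (create an Experiment, then call watch) can be sketched as follows. The function name train_with_histograms and its arguments are hypothetical, and the sketch assumes Comet credentials are already configured so that comet_ml.Experiment() can be created without arguments. The imports sit inside the function only so the snippet can be loaded without those credentials; in a real script they would normally go at the top of the module.

```python
def train_with_histograms(model, batches, optimizer, loss_fn, log_step_interval=1000):
    """Hypothetical training loop that enables histogram logging via watch()."""
    import comet_ml
    from comet_ml.integration.pytorch import watch

    # The Experiment must exist before watch() is called.
    experiment = comet_ml.Experiment()

    # Enable logging of each layer's parameters and gradients as histograms.
    watch(model, log_step_interval=log_step_interval)

    # Plain PyTorch training loop; batches yields (inputs, targets) pairs.
    for inputs, targets in batches:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()

    experiment.end()
```

Lowering log_step_interval produces histograms more often at the cost of more logged data.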

Jan. 17, 2025