Skip to content

Commit

Permalink
option to use torch.compile (#336)
Browse files Browse the repository at this point in the history
  • Loading branch information
vitkl authored Nov 26, 2023
1 parent 3af08a1 commit b2f3894
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions cell2location/models/_cell2location_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
import scanpy
import torch
from anndata import AnnData
from pyro import clear_param_store
from pyro.infer import Trace_ELBO, TraceEnum_ELBO
Expand Down Expand Up @@ -160,6 +161,12 @@ def setup_anndata(
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

def train_compiled(self, compile_mode=None, compile_dynamic=None, **kwargs):
self.train(**kwargs, max_steps=1)
self.module._model = torch.compile(self.module.model, mode=compile_mode, dynamic=compile_dynamic)
self.module._guide = torch.compile(self.module.guide, mode=compile_mode, dynamic=compile_dynamic)
self.train(**kwargs)

def train(
self,
max_epochs: int = 30000,
Expand Down

0 comments on commit b2f3894

Please sign in to comment.