cnellington / Contextualized

An SKLearn-style toolbox for estimating and analyzing models, distributions, and functions with context-specific parameters.
http://contextualized.ml/
GNU General Public License v3.0
65 stars, 9 forks

Moving model and data to and from device #254

Open: alexanderchang1 opened this issue 6 days ago

alexanderchang1 commented 6 days ago

Is there something I'm missing in the documentation about how to move C-ML models and data to and from CUDA/CPU? I keep getting device-mismatch errors when running my code, apparently because tensors such as the one created by torch.eye end up on the wrong device.

alexanderchang1 commented 6 days ago

Code to reproduce

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch  # Import torch to handle tensors and CUDA devices
from contextualized.easy import ContextualizedBayesianNetworks
from sklearn.preprocessing import StandardScaler

# Detect CUDA device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Synthetic data example (replace this with your actual data)
n_samples = 50
n_genes = 10  # Number of genes
n_context_features = 20  # Number of context features (cna + rna)

# Generate random data for illustration purposes
np.random.seed(42)
C_array = np.random.randn(n_samples, n_context_features)
X_array = np.random.randn(n_samples, n_genes)

# Standardize the synthetic data
scaler_C = StandardScaler()
C_array = scaler_C.fit_transform(C_array)

scaler_X = StandardScaler()
X_array = scaler_X.fit_transform(X_array)

# Convert data to torch tensors and move them to the device
C_tensor = torch.tensor(C_array, dtype=torch.float32).to(device)
X_tensor = torch.tensor(X_array, dtype=torch.float32).to(device)

# Initialize the model (no need to move it to the device manually)
cbn = ContextualizedBayesianNetworks(
    encoder_type='mlp',
    num_archetypes=16,
    n_bootstraps=2,
    archetype_dag_loss_type="DAGMA",
    archetype_alpha=0.,
    sample_specific_dag_loss_type="DAGMA",
    sample_specific_alpha=1e-1,
    learning_rate=1e-3
)

# Train the model (the model expects tensors on the same device)
cbn.fit(C_tensor, X_tensor, max_epochs=100)

# Evaluate the model (again with data tensors on the same device)
mses = cbn.measure_mses(C_tensor, X_tensor)
print(f"Mean Squared Error: {np.mean(mses)}")

# Predict and visualize networks (ensure data stays on the correct device)
predicted_networks = cbn.predict_networks(C_tensor)

# Move the predicted network back to CPU for visualization
predicted_networks_cpu = predicted_networks[0].cpu().detach().numpy()

plt.imshow(predicted_networks_cpu)
plt.title("Predicted Network for First Sample")
plt.colorbar()
plt.show()

cnellington commented 5 days ago

Hey @alexanderchang1, thanks for the comment and the reproduction. The easy models handle CPU/CUDA device placement internally, and to keep things simple they expect data as NumPy arrays or pandas DataFrames rather than torch tensors. See the docs here: https://contextualized.ml/docs/source/easy.html#contextualized.easy.ContextualizedNetworks.ContextualizedBayesianNetworks.fit
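Following that advice, tensors that were already moved to a device can be converted back to host arrays before calling the easy API. This is a minimal sketch assuming only PyTorch and NumPy; to_numpy is a hypothetical helper (not part of contextualized), and the cbn.fit call is left commented out so the snippet runs without the library installed.

```python
import numpy as np
import torch


def to_numpy(x):
    """Hypothetical helper: prepare input for the contextualized.easy API.

    Torch tensors (on any device) are detached and copied to host memory.
    Anything else, such as NumPy arrays or pandas DataFrames, is returned
    unchanged because the easy models already accept it.
    """
    if isinstance(x, torch.Tensor):
        # .detach() drops autograd history; .cpu() copies off the GPU if needed
        return x.detach().cpu().numpy()
    return x


# Tensors that were moved to a device, as in the repro above
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
C_tensor = torch.randn(50, 20, device=device)
X_tensor = torch.randn(50, 10, device=device)

C_numpy = to_numpy(C_tensor)
X_numpy = to_numpy(X_tensor)

# With contextualized installed, fit on the host arrays and let the
# wrapper decide where the model and batches live:
# cbn.fit(C_numpy, X_numpy, max_epochs=100)
```

The design choice is to leave device placement entirely to the wrapper and only hand it host-side data.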

If you need more control, for things like multi-device training, you'll need to use the models in contextualized.dags.lightning_modules directly, along with the associated contextualized.dags.trainers. I don't recommend it unless you need it, but if you do, the Lightning docs should cover most of what you need to get started: https://lightning.ai/docs/pytorch/stable/common/trainer.html

alexanderchang1 commented 5 days ago

Thanks!

alexanderchang1 commented 4 days ago

@cnellington, re-opening this issue: I'm still getting the error even after double-checking that my inputs are NumPy arrays. I'm trying to isolate whether the problem is that my input matrices aren't correctly moved to CUDA, or whether it's the tensor created by torch.eye.
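One way to do that isolation is to print the device of every tensor involved in the failing line. A minimal sketch assuming only PyTorch: describe_devices is a hypothetical helper, and the random w merely stands in for one predicted network from the traceback. torch.eye called without a device= argument allocates on the default device (the CPU unless it has been changed), whereas passing device=w.device makes the identity follow w.

```python
import torch


def describe_devices(**tensors):
    """Hypothetical debugging helper: map each named tensor to its device."""
    return {name: str(t.device) for name, t in tensors.items()}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
w = torch.randn(4, 4, device=device)  # stands in for one predicted network

eye_default = torch.eye(w.shape[-1])                   # no device= argument
eye_matched = torch.eye(w.shape[-1], device=w.device)  # follows w

print(describe_devices(w=w, eye_default=eye_default, eye_matched=eye_matched))
```

On a CUDA machine, w and eye_matched report cuda:0 while eye_default reports cpu, which is the same pair of devices named in the RuntimeError.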

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/easy/wrappers/SKLearnWrapper.py:515, in SKLearnWrapper.fit(self, *args, **kwargs)
    514 try:
--> 515     trainer.fit(model, train_dataloader, val_dataloader, **organized_kwargs["fit"])
    516 except:

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    537 self.training = True
--> 538 call._call_and_handle_interrupt(
    539     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    540 )

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     46         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47     return trainer_fn(*args, **kwargs)
     49 except _TunerExitException:

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    569     self.state.fn,
    570     ckpt_path,
    571     model_provided=True,
    572     model_connected=self.lightning_module is not None,
    573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
    576 assert self.state.stopped

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:981, in Trainer._run(self, model, ckpt_path)
    978 # ----------------------------
    979 # RUN THE TRAINER
    980 # ----------------------------
--> 981 results = self._run_stage()
    983 # ----------------------------
    984 # POST-Training CLEAN UP
    985 # ----------------------------

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1023, in Trainer._run_stage(self)
   1022 with isolate_rng():
-> 1023     self._run_sanity_check()
   1024 with torch.autograd.set_detect_anomaly(self._detect_anomaly):

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1052, in Trainer._run_sanity_check(self)
   1051 # run eval step
-> 1052 val_loop.run()
   1054 call._call_callback_hooks(self, "on_sanity_check_end")

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py:178, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    177 with context_manager():
--> 178     return loop_run(self, *args, **kwargs)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/evaluation_loop.py:135, in _EvaluationLoop.run(self)
    134     # run step hooks
--> 135     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
    136 except StopIteration:
    137     # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/evaluation_loop.py:396, in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
    391 step_args = (
    392     self._build_step_args_from_hook_kwargs(hook_kwargs, hook_name)
    393     if not using_dataloader_iter
    394     else (dataloader_iter,)
    395 )
--> 396 output = call._call_strategy_hook(trainer, hook_name, *step_args)
    398 self.batch_progress.increment_processed()

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:319, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    318 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 319     output = fn(*args, **kwargs)
    321 # restore current_fx when nested context

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:411, in Strategy.validation_step(self, *args, **kwargs)
    410     return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
--> 411 return self.lightning_module.validation_step(*args, **kwargs)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/lightning_modules.py:331, in NOTMAD.validation_step(self, batch, batch_idx)
    330 # ignore archetype loss, use constant alpha/rho upper bound for validation
--> 331 dag_term = self.ss_dag_loss(w_pred, **self.val_dag_loss_params).mean()
    332 if self.latent_dim < self.x_dim:

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/losses.py:15, in dag_loss_dagma(W, s, alpha, **kwargs)
     12 """DAG loss on batched networks W using the
     13 DAGMA log-determinant
     14 """
---> 15 sample_losses = torch.Tensor([dag_loss_dagma_indiv(w, s) for w in W])
     16 return alpha * torch.mean(sample_losses)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/losses.py:15, in <listcomp>(.0)
     12 """DAG loss on batched networks W using the
     13 DAGMA log-determinant
     14 """
---> 15 sample_losses = torch.Tensor([dag_loss_dagma_indiv(w, s) for w in W])
     16 return alpha * torch.mean(sample_losses)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/losses.py:7, in dag_loss_dagma_indiv(w, s)
      6 def dag_loss_dagma_indiv(w, s=1):
----> 7     M = s * torch.eye(w.shape[-1]) - w * w
      8     return w.shape[-1] * np.log(s) - torch.slogdet(M)[1]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In[31], line 8
      1 from contextualized.easy import ContextualizedBayesianNetworks
      3 cbn = ContextualizedBayesianNetworks(
      4     encoder_type='mlp', num_archetypes=16,
      5     n_bootstraps=2, archetype_dag_loss_type="DAGMA", archetype_alpha=0.,
      6     sample_specific_dag_loss_type="DAGMA", sample_specific_alpha=1e-1,
      7     learning_rate=1e-3)
----> 8 cbn.fit(C_numpy, X_numpy, max_epochs=10, es_verbose=True)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/easy/wrappers/SKLearnWrapper.py:517, in SKLearnWrapper.fit(self, *args, **kwargs)
    515     trainer.fit(model, train_dataloader, val_dataloader, **organized_kwargs["fit"])
    516 except:
--> 517     trainer.fit(model, train_dataloader, **organized_kwargs["fit"])
    519 if kwargs.get("max_epochs", 1) > 0:
    520     best_checkpoint = torch.load(checkpoint_callback.best_model_path)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    536 self.state.status = TrainerStatus.RUNNING
    537 self.training = True
--> 538 call._call_and_handle_interrupt(
    539     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    540 )

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     45     if trainer.strategy.launcher is not None:
     46         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47     return trainer_fn(*args, **kwargs)
     49 except _TunerExitException:
     50     _call_teardown_hook(trainer)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    567 assert self.state.fn is not None
    568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    569     self.state.fn,
    570     ckpt_path,
    571     model_provided=True,
    572     model_connected=self.lightning_module is not None,
    573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
    576 assert self.state.stopped
    577 self.training = False

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:981, in Trainer._run(self, model, ckpt_path)
    976 self._signal_connector.register_signal_handlers()
    978 # ----------------------------
    979 # RUN THE TRAINER
    980 # ----------------------------
--> 981 results = self._run_stage()
    983 # ----------------------------
    984 # POST-Training CLEAN UP
    985 # ----------------------------
    986 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1025, in Trainer._run_stage(self)
   1023         self._run_sanity_check()
   1024     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1025         self.fit_loop.run()
   1026     return None
   1027 raise RuntimeError(f"Unexpected state {self.state}")

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:205, in _FitLoop.run(self)
    203 try:
    204     self.on_advance_start()
--> 205     self.advance()
    206     self.on_advance_end()
    207     self._restarting = False

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:363, in _FitLoop.advance(self)
    361 with self.trainer.profiler.profile("run_training_epoch"):
    362     assert self._data_fetcher is not None
--> 363     self.epoch_loop.run(self._data_fetcher)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140, in _TrainingEpochLoop.run(self, data_fetcher)
    138 while not self.done:
    139     try:
--> 140         self.advance(data_fetcher)
    141         self.on_advance_end(data_fetcher)
    142         self._restarting = False

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250, in _TrainingEpochLoop.advance(self, data_fetcher)
    247 with trainer.profiler.profile("run_training_batch"):
    248     if trainer.lightning_module.automatic_optimization:
    249         # in automatic optimization, there can only be one optimizer
--> 250         batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
    251     else:
    252         batch_output = self.manual_optimization.run(kwargs)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:190, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
    183         closure()
    185 # ------------------------------
    186 # BACKWARD PASS
    187 # ------------------------------
    188 # gradient update with accumulated gradients
    189 else:
--> 190     self._optimizer_step(batch_idx, closure)
    192 result = closure.consume_result()
    193 if result.loss is None:

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:268, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
    265     self.optim_progress.optimizer.step.increment_ready()
    267 # model hook
--> 268 call._call_lightning_module_hook(
    269     trainer,
    270     "optimizer_step",
    271     trainer.current_epoch,
    272     batch_idx,
    273     optimizer,
    274     train_step_and_backward_closure,
    275 )
    277 if not should_accumulate:
    278     self.optim_progress.optimizer.step.increment_completed()

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:167, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    164 pl_module._current_fx_name = hook_name
    166 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 167     output = fn(*args, **kwargs)
    169 # restore current_fx when nested context
    170 pl_module._current_fx_name = prev_fx_name

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/core/module.py:1306, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
   1275 def optimizer_step(
   1276     self,
   1277     epoch: int,
   (...)
   1280     optimizer_closure: Optional[Callable[[], Any]] = None,
   1281 ) -> None:
   1282     r"""Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls
   1283     the optimizer.
   1284 
   (...)
   1304 
   1305     """
-> 1306     optimizer.step(closure=optimizer_closure)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py:153, in LightningOptimizer.step(self, closure, **kwargs)
    150     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    152 assert self._strategy is not None
--> 153 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
    155 self._on_after_step()
    157 return step_output

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:238, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
    236 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
    237 assert isinstance(model, pl.LightningModule)
--> 238 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision.py:122, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
    120 """Hook to run the optimizer step."""
    121 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 122 return optimizer.step(closure=closure, **kwargs)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/torch/optim/lr_scheduler.py:68, in LRScheduler.__init__.<locals>.with_counter.<locals>.wrapper(*args, **kwargs)
     66 instance._step_count += 1
     67 wrapped = func.__get__(instance, cls)
---> 68 return wrapped(*args, **kwargs)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/torch/optim/optimizer.py:373, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    368         else:
    369             raise RuntimeError(
    370                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    371             )
--> 373 out = func(*args, **kwargs)
    374 self._optimizer_step_code()
    376 # call optimizer step post hooks

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/torch/optim/optimizer.py:76, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
     74     torch.set_grad_enabled(self.defaults['differentiable'])
     75     torch._dynamo.graph_break()
---> 76     ret = func(self, *args, **kwargs)
     77 finally:
     78     torch._dynamo.graph_break()

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/torch/optim/adam.py:143, in Adam.step(self, closure)
    141 if closure is not None:
    142     with torch.enable_grad():
--> 143         loss = closure()
    145 for group in self.param_groups:
    146     params_with_grad = []

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision.py:108, in Precision._wrap_closure(self, model, optimizer, closure)
     95 def _wrap_closure(
     96     self,
     97     model: "pl.LightningModule",
     98     optimizer: Steppable,
     99     closure: Callable[[], Any],
    100 ) -> Any:
    101     """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
    102     hook is called.
    103 
   (...)
    106 
    107     """
--> 108     closure_result = closure()
    109     self._after_closure(model, optimizer)
    110     return closure_result

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:144, in Closure.__call__(self, *args, **kwargs)
    142 @override
    143 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 144     self._result = self.closure(*args, **kwargs)
    145     return self._result.loss

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:129, in Closure.closure(self, *args, **kwargs)
    126 @override
    127 @torch.enable_grad()
    128 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 129     step_output = self._step_fn()
    131     if step_output.closure_loss is None:
    132         self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:317, in _AutomaticOptimization._training_step(self, kwargs)
    306 """Performs the actual train step with the tied hooks.
    307 
    308 Args:
   (...)
    313 
    314 """
    315 trainer = self.trainer
--> 317 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
    318 self.trainer.strategy.post_training_step()  # unused hook - call anyway for backward compatibility
    320 if training_step_output is None and trainer.world_size > 1:

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:319, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    316     return None
    318 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 319     output = fn(*args, **kwargs)
    321 # restore current_fx when nested context
    322 pl_module._current_fx_name = prev_fx_name

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:390, in Strategy.training_step(self, *args, **kwargs)
    388 if self.model != self.lightning_module:
    389     return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
--> 390 return self.lightning_module.training_step(*args, **kwargs)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/lightning_modules.py:277, in NOTMAD.training_step(self, batch, batch_idx)
    267 def training_step(self, batch, batch_idx):
    268     (
    269         loss,
    270         notears,
    271         mse_term,
    272         l1_term,
    273         dag_term,
    274         arch_l1_term,
    275         arch_dag_term,
    276         factor_mat_term,
--> 277     ) = self._batch_loss(batch, batch_idx)
    278     ret = {
    279         "loss": loss,
    280         "train_loss": loss,
   (...)
    286         "train_factor_l1_loss": factor_mat_term,
    287     }
    288     self.log_dict(ret)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/lightning_modules.py:235, in NOTMAD._batch_loss(self, batch, batch_idx)
    233     mse_term = linear_sem_loss(x_true, w_pred)
    234 l1_term = l1_loss(w_pred, self.ss_l1)
--> 235 dag_term = self.ss_dag_loss(w_pred, **self.ss_dag_params)
    236 notears = mse_term + l1_term + dag_term
    237 W_arch = self.explainer.get_archetypes()

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/losses.py:15, in dag_loss_dagma(W, s, alpha, **kwargs)
     11 def dag_loss_dagma(W, s=1, alpha=0.0, **kwargs):
     12     """DAG loss on batched networks W using the
     13     DAGMA log-determinant
     14     """
---> 15     sample_losses = torch.Tensor([dag_loss_dagma_indiv(w, s) for w in W])
     16     return alpha * torch.mean(sample_losses)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/losses.py:15, in <listcomp>(.0)
     11 def dag_loss_dagma(W, s=1, alpha=0.0, **kwargs):
     12     """DAG loss on batched networks W using the
     13     DAGMA log-determinant
     14     """
---> 15     sample_losses = torch.Tensor([dag_loss_dagma_indiv(w, s) for w in W])
     16     return alpha * torch.mean(sample_losses)

File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/losses.py:7, in dag_loss_dagma_indiv(w, s)
      6 def dag_loss_dagma_indiv(w, s=1):
----> 7     M = s * torch.eye(w.shape[-1]) - w * w
      8     return w.shape[-1] * np.log(s) - torch.slogdet(M)[1]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

cnellington commented 3 days ago

It's almost certainly the torch.eye call. A simple torch.eye(w.shape[-1]).to(w.device) should fix this, but I won't be able to make and test a proper fix for a few days. Feel free to make the change and open a PR if you're inclined; otherwise I'll get to it a bit later.
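A minimal sketch of the suggested fix, reconstructed from the two functions visible in the traceback (contextualized/dags/losses.py); it is not the project's actual patch. Passing device= and dtype= to torch.eye allocates the identity directly next to w instead of creating it on the CPU and copying it with .to(w.device). Replacing torch.Tensor([...]) with torch.stack is an additional change assumed here, not something proposed in the thread: it keeps the per-sample losses on W's device and attached to the autograd graph, which a freshly constructed torch.Tensor would not be. math.log stands in for the original np.log.

```python
import math

import torch


def dag_loss_dagma_indiv(w, s=1):
    """DAGMA log-determinant acyclicity penalty for one (d x d) network w."""
    d = w.shape[-1]
    # Identity created on w's device and dtype, so no cpu/cuda mixing
    eye = torch.eye(d, device=w.device, dtype=w.dtype)
    M = s * eye - w * w
    # slogdet returns (sign, log|det|); the penalty is d*log(s) - log|det(M)|
    return d * math.log(s) - torch.slogdet(M)[1]


def dag_loss_dagma(W, s=1, alpha=0.0, **kwargs):
    """Mean DAGMA penalty over the leading batch dimension of W, scaled by alpha."""
    # torch.stack keeps the per-sample losses on W's device and in the graph
    sample_losses = torch.stack([dag_loss_dagma_indiv(w, s) for w in W])
    return alpha * torch.mean(sample_losses)
```

Because every tensor is created from w's device, the same code runs unchanged on the CPU and on cuda:0. An acyclic (for example strictly upper-triangular) network gives a penalty of zero, while a cycle gives a positive one.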