Open alexanderchang1 opened 6 days ago
Code to reproduce
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch # Import torch to handle tensors and CUDA devices
from contextualized.easy import ContextualizedBayesianNetworks
from sklearn.preprocessing import StandardScaler
# Reproduction script for ContextualizedBayesianNetworks with the DAGMA loss.
#
# The `contextualized.easy` wrappers take NumPy arrays (or pandas DataFrames)
# and handle CPU/CUDA placement internally, so the inputs below are
# deliberately NOT converted to torch tensors or moved with `.to(device)`.

# Informational only: records which accelerator is visible on this machine.
# It is not passed to the model; the easy API chooses the device itself.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Synthetic data example (replace this with your actual data)
n_samples = 50
n_genes = 10  # Number of genes
n_context_features = 20  # Number of context features (cna + rna)

# Generate random data for illustration purposes
np.random.seed(42)
C_array = np.random.randn(n_samples, n_context_features)
X_array = np.random.randn(n_samples, n_genes)

# Standardize the synthetic data; keep float32, the dtype the original
# tensor-based version of this script used.
scaler_C = StandardScaler()
C_array = scaler_C.fit_transform(C_array).astype(np.float32)
scaler_X = StandardScaler()
X_array = scaler_X.fit_transform(X_array).astype(np.float32)

# Initialize the model (no need to move it to a device manually)
cbn = ContextualizedBayesianNetworks(
    encoder_type='mlp',
    num_archetypes=16,
    n_bootstraps=2,
    archetype_dag_loss_type="DAGMA",
    archetype_alpha=0.,
    sample_specific_dag_loss_type="DAGMA",
    sample_specific_alpha=1e-1,
    learning_rate=1e-3,
)

# Train the model on the NumPy arrays
cbn.fit(C_array, X_array, max_epochs=100)

# Evaluate the model on the same NumPy arrays
mses = cbn.measure_mses(C_array, X_array)
print(f"Mean Squared Error: {np.mean(mses)}")

# Predict and visualize networks
predicted_networks = cbn.predict_networks(C_array)

# NOTE(review): predict_networks is documented as returning NumPy arrays, so
# no `.cpu().detach()` is needed; np.asarray is a no-op in that case. Confirm
# against the installed contextualized version.
predicted_networks_cpu = np.asarray(predicted_networks[0])
plt.imshow(predicted_networks_cpu)
plt.title("Predicted Network for First Sample")
plt.colorbar()
plt.show()
Hey @alexanderchang1 thanks for the comment and the reproduction. The easy models handle cpu/cuda device switches internally, and they expect data as numpy arrays or pandas dataframes to make things simple. See the docs here: https://contextualized.ml/docs/source/easy.html#contextualized.easy.ContextualizedNetworks.ContextualizedBayesianNetworks.fit
If you need more control for things like multi-device training, you'll need to use the model in `contextualized.dags.lightning_modules` directly, along with the associated `contextualized.dags.trainers`. I don't recommend it if you don't need it, but if you do, the Lightning docs should have most of what you need to get started: https://lightning.ai/docs/pytorch/stable/common/trainer.html
Thanks!
@cnellington re-opening this issue: I'm still getting an error even after double-checking that the inputs are NumPy arrays. I'm trying to isolate whether the problem is that my input matrices are not being moved to CUDA correctly, or whether it's the `torch.eye` that gets instantiated in the loss.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/easy/wrappers/SKLearnWrapper.py:515, in SKLearnWrapper.fit(self, *args, **kwargs)
514 try:
--> 515 trainer.fit(model, train_dataloader, val_dataloader, **organized_kwargs["fit"])
516 except:
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
537 self.training = True
--> 538 call._call_and_handle_interrupt(
539 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
540 )
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
46 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47 return trainer_fn(*args, **kwargs)
49 except _TunerExitException:
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
569 self.state.fn,
570 ckpt_path,
571 model_provided=True,
572 model_connected=self.lightning_module is not None,
573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
576 assert self.state.stopped
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:981, in Trainer._run(self, model, ckpt_path)
978 # ----------------------------
979 # RUN THE TRAINER
980 # ----------------------------
--> 981 results = self._run_stage()
983 # ----------------------------
984 # POST-Training CLEAN UP
985 # ----------------------------
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1023, in Trainer._run_stage(self)
1022 with isolate_rng():
-> 1023 self._run_sanity_check()
1024 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1052, in Trainer._run_sanity_check(self)
1051 # run eval step
-> 1052 val_loop.run()
1054 call._call_callback_hooks(self, "on_sanity_check_end")
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py:178, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
177 with context_manager():
--> 178 return loop_run(self, *args, **kwargs)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/evaluation_loop.py:135, in _EvaluationLoop.run(self)
134 # run step hooks
--> 135 self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
136 except StopIteration:
137 # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/evaluation_loop.py:396, in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
391 step_args = (
392 self._build_step_args_from_hook_kwargs(hook_kwargs, hook_name)
393 if not using_dataloader_iter
394 else (dataloader_iter,)
395 )
--> 396 output = call._call_strategy_hook(trainer, hook_name, *step_args)
398 self.batch_progress.increment_processed()
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:319, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
318 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 319 output = fn(*args, **kwargs)
321 # restore current_fx when nested context
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:411, in Strategy.validation_step(self, *args, **kwargs)
410 return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
--> 411 return self.lightning_module.validation_step(*args, **kwargs)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/lightning_modules.py:331, in NOTMAD.validation_step(self, batch, batch_idx)
330 # ignore archetype loss, use constant alpha/rho upper bound for validation
--> 331 dag_term = self.ss_dag_loss(w_pred, **self.val_dag_loss_params).mean()
332 if self.latent_dim < self.x_dim:
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/losses.py:15, in dag_loss_dagma(W, s, alpha, **kwargs)
12 """DAG loss on batched networks W using the
13 DAGMA log-determinant
14 """
---> 15 sample_losses = torch.Tensor([dag_loss_dagma_indiv(w, s) for w in W])
16 return alpha * torch.mean(sample_losses)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/losses.py:15, in <listcomp>(.0)
12 """DAG loss on batched networks W using the
13 DAGMA log-determinant
14 """
---> 15 sample_losses = torch.Tensor([dag_loss_dagma_indiv(w, s) for w in W])
16 return alpha * torch.mean(sample_losses)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/losses.py:7, in dag_loss_dagma_indiv(w, s)
6 def dag_loss_dagma_indiv(w, s=1):
----> 7 M = s * torch.eye(w.shape[-1]) - w * w
8 return w.shape[-1] * np.log(s) - torch.slogdet(M)[1]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
Cell In[31], line 8
1 from contextualized.easy import ContextualizedBayesianNetworks
3 cbn = ContextualizedBayesianNetworks(
4 encoder_type='mlp', num_archetypes=16,
5 n_bootstraps=2, archetype_dag_loss_type="DAGMA", archetype_alpha=0.,
6 sample_specific_dag_loss_type="DAGMA", sample_specific_alpha=1e-1,
7 learning_rate=1e-3)
----> 8 cbn.fit(C_numpy, X_numpy, max_epochs=10, es_verbose=True)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/easy/wrappers/SKLearnWrapper.py:517, in SKLearnWrapper.fit(self, *args, **kwargs)
515 trainer.fit(model, train_dataloader, val_dataloader, **organized_kwargs["fit"])
516 except:
--> 517 trainer.fit(model, train_dataloader, **organized_kwargs["fit"])
519 if kwargs.get("max_epochs", 1) > 0:
520 best_checkpoint = torch.load(checkpoint_callback.best_model_path)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
536 self.state.status = TrainerStatus.RUNNING
537 self.training = True
--> 538 call._call_and_handle_interrupt(
539 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
540 )
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
45 if trainer.strategy.launcher is not None:
46 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47 return trainer_fn(*args, **kwargs)
49 except _TunerExitException:
50 _call_teardown_hook(trainer)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
567 assert self.state.fn is not None
568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
569 self.state.fn,
570 ckpt_path,
571 model_provided=True,
572 model_connected=self.lightning_module is not None,
573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
576 assert self.state.stopped
577 self.training = False
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:981, in Trainer._run(self, model, ckpt_path)
976 self._signal_connector.register_signal_handlers()
978 # ----------------------------
979 # RUN THE TRAINER
980 # ----------------------------
--> 981 results = self._run_stage()
983 # ----------------------------
984 # POST-Training CLEAN UP
985 # ----------------------------
986 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1025, in Trainer._run_stage(self)
1023 self._run_sanity_check()
1024 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1025 self.fit_loop.run()
1026 return None
1027 raise RuntimeError(f"Unexpected state {self.state}")
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:205, in _FitLoop.run(self)
203 try:
204 self.on_advance_start()
--> 205 self.advance()
206 self.on_advance_end()
207 self._restarting = False
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:363, in _FitLoop.advance(self)
361 with self.trainer.profiler.profile("run_training_epoch"):
362 assert self._data_fetcher is not None
--> 363 self.epoch_loop.run(self._data_fetcher)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140, in _TrainingEpochLoop.run(self, data_fetcher)
138 while not self.done:
139 try:
--> 140 self.advance(data_fetcher)
141 self.on_advance_end(data_fetcher)
142 self._restarting = False
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250, in _TrainingEpochLoop.advance(self, data_fetcher)
247 with trainer.profiler.profile("run_training_batch"):
248 if trainer.lightning_module.automatic_optimization:
249 # in automatic optimization, there can only be one optimizer
--> 250 batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
251 else:
252 batch_output = self.manual_optimization.run(kwargs)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:190, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
183 closure()
185 # ------------------------------
186 # BACKWARD PASS
187 # ------------------------------
188 # gradient update with accumulated gradients
189 else:
--> 190 self._optimizer_step(batch_idx, closure)
192 result = closure.consume_result()
193 if result.loss is None:
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:268, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
265 self.optim_progress.optimizer.step.increment_ready()
267 # model hook
--> 268 call._call_lightning_module_hook(
269 trainer,
270 "optimizer_step",
271 trainer.current_epoch,
272 batch_idx,
273 optimizer,
274 train_step_and_backward_closure,
275 )
277 if not should_accumulate:
278 self.optim_progress.optimizer.step.increment_completed()
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:167, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
164 pl_module._current_fx_name = hook_name
166 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 167 output = fn(*args, **kwargs)
169 # restore current_fx when nested context
170 pl_module._current_fx_name = prev_fx_name
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/core/module.py:1306, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
1275 def optimizer_step(
1276 self,
1277 epoch: int,
(...)
1280 optimizer_closure: Optional[Callable[[], Any]] = None,
1281 ) -> None:
1282 r"""Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls
1283 the optimizer.
1284
(...)
1304
1305 """
-> 1306 optimizer.step(closure=optimizer_closure)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py:153, in LightningOptimizer.step(self, closure, **kwargs)
150 raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
152 assert self._strategy is not None
--> 153 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
155 self._on_after_step()
157 return step_output
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:238, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
236 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
237 assert isinstance(model, pl.LightningModule)
--> 238 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision.py:122, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
120 """Hook to run the optimizer step."""
121 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 122 return optimizer.step(closure=closure, **kwargs)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/torch/optim/lr_scheduler.py:68, in LRScheduler.__init__.<locals>.with_counter.<locals>.wrapper(*args, **kwargs)
66 instance._step_count += 1
67 wrapped = func.__get__(instance, cls)
---> 68 return wrapped(*args, **kwargs)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/torch/optim/optimizer.py:373, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
368 else:
369 raise RuntimeError(
370 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
371 )
--> 373 out = func(*args, **kwargs)
374 self._optimizer_step_code()
376 # call optimizer step post hooks
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/torch/optim/optimizer.py:76, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
74 torch.set_grad_enabled(self.defaults['differentiable'])
75 torch._dynamo.graph_break()
---> 76 ret = func(self, *args, **kwargs)
77 finally:
78 torch._dynamo.graph_break()
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/torch/optim/adam.py:143, in Adam.step(self, closure)
141 if closure is not None:
142 with torch.enable_grad():
--> 143 loss = closure()
145 for group in self.param_groups:
146 params_with_grad = []
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision.py:108, in Precision._wrap_closure(self, model, optimizer, closure)
95 def _wrap_closure(
96 self,
97 model: "pl.LightningModule",
98 optimizer: Steppable,
99 closure: Callable[[], Any],
100 ) -> Any:
101 """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
102 hook is called.
103
(...)
106
107 """
--> 108 closure_result = closure()
109 self._after_closure(model, optimizer)
110 return closure_result
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:144, in Closure.__call__(self, *args, **kwargs)
142 @override
143 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 144 self._result = self.closure(*args, **kwargs)
145 return self._result.loss
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:129, in Closure.closure(self, *args, **kwargs)
126 @override
127 @torch.enable_grad()
128 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 129 step_output = self._step_fn()
131 if step_output.closure_loss is None:
132 self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:317, in _AutomaticOptimization._training_step(self, kwargs)
306 """Performs the actual train step with the tied hooks.
307
308 Args:
(...)
313
314 """
315 trainer = self.trainer
--> 317 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
318 self.trainer.strategy.post_training_step() # unused hook - call anyway for backward compatibility
320 if training_step_output is None and trainer.world_size > 1:
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:319, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
316 return None
318 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 319 output = fn(*args, **kwargs)
321 # restore current_fx when nested context
322 pl_module._current_fx_name = prev_fx_name
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:390, in Strategy.training_step(self, *args, **kwargs)
388 if self.model != self.lightning_module:
389 return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
--> 390 return self.lightning_module.training_step(*args, **kwargs)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/lightning_modules.py:277, in NOTMAD.training_step(self, batch, batch_idx)
267 def training_step(self, batch, batch_idx):
268 (
269 loss,
270 notears,
271 mse_term,
272 l1_term,
273 dag_term,
274 arch_l1_term,
275 arch_dag_term,
276 factor_mat_term,
--> 277 ) = self._batch_loss(batch, batch_idx)
278 ret = {
279 "loss": loss,
280 "train_loss": loss,
(...)
286 "train_factor_l1_loss": factor_mat_term,
287 }
288 self.log_dict(ret)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/lightning_modules.py:235, in NOTMAD._batch_loss(self, batch, batch_idx)
233 mse_term = linear_sem_loss(x_true, w_pred)
234 l1_term = l1_loss(w_pred, self.ss_l1)
--> 235 dag_term = self.ss_dag_loss(w_pred, **self.ss_dag_params)
236 notears = mse_term + l1_term + dag_term
237 W_arch = self.explainer.get_archetypes()
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/losses.py:15, in dag_loss_dagma(W, s, alpha, **kwargs)
11 def dag_loss_dagma(W, s=1, alpha=0.0, **kwargs):
12 """DAG loss on batched networks W using the
13 DAGMA log-determinant
14 """
---> 15 sample_losses = torch.Tensor([dag_loss_dagma_indiv(w, s) for w in W])
16 return alpha * torch.mean(sample_losses)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/losses.py:15, in <listcomp>(.0)
11 def dag_loss_dagma(W, s=1, alpha=0.0, **kwargs):
12 """DAG loss on batched networks W using the
13 DAGMA log-determinant
14 """
---> 15 sample_losses = torch.Tensor([dag_loss_dagma_indiv(w, s) for w in W])
16 return alpha * torch.mean(sample_losses)
File /bgfs/alee/LO_LAB/Personal/Alexander_Chang/alc376/envs/contextml/lib/python3.9/site-packages/contextualized/dags/losses.py:7, in dag_loss_dagma_indiv(w, s)
6 def dag_loss_dagma_indiv(w, s=1):
----> 7 M = s * torch.eye(w.shape[-1]) - w * w
8 return w.shape[-1] * np.log(s) - torch.slogdet(M)[1]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
It's almost certainly the `torch.eye`. A simple `torch.eye(w.shape[-1]).to(w.device)` should fix this, but I won't be able to make a proper fix and test it for a few days. Feel free to make the change and open a PR if you feel inclined; otherwise I'll get to it a bit later.
Is there something I'm missing in the documentation as to how to move C-ML models and data to and from CUDA/CPU? I keep getting errors when trying to run code based on wrong location of torch.eye, etc. etc.