pachterlab / CGCCP_2023

scVI extension for unspliced RNA
BSD 2-Clause "Simplified" License

Expected all tensors to be on the same device, but found at least two devices #1

Closed. biobai closed this issue 4 weeks ago

biobai commented 2 months ago

I followed the instructions in Demo.ipynb.

In [9]: model.train(max_epochs = max_epochs,
   ...:                 train_size = 0.9,
   ...:                 check_val_every_n_epoch  = 1,
   ...:                 plan_kwargs = plan_kwargs)
/users/999/byh/miniforge3/envs/bivi/lib/python3.12/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/users/999/byh/miniforge3/envs/bivi/lib/python3.12/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/users/999/byh/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:1892: PossibleUserWarning: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Epoch 1/100:   0%|                                                                                                    | 0/100 [00:00<?, ?it/s]<frozen abc>:119: FutureWarning: SparseDataset is deprecated and will be removed in late 2024. It has been replaced by the public classes CSRDataset and CSCDataset.

For instance checks, use `isinstance(X, (anndata.experimental.CSRDataset, anndata.experimental.CSCDataset))` instead.

For creation, use `anndata.experimental.sparse_dataset(X)` instead.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[9], line 1
----> 1 model.train(max_epochs = max_epochs,
      2                 train_size = 0.9,
      3                 check_val_every_n_epoch  = 1,
      4                 plan_kwargs = plan_kwargs)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/scvi/model/base/_training_mixin.py:142, in UnsupervisedTrainingMixin.train(self, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, plan_kwargs, **trainer_kwargs)
    131 trainer_kwargs[es] = (
    132     early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
    133 )
    134 runner = TrainRunner(
    135     self,
    136     training_plan=training_plan,
   (...)
    140     **trainer_kwargs,
    141 )
--> 142 return runner()

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/scvi/train/_trainrunner.py:82, in TrainRunner.__call__(self)
     79 if hasattr(self.data_splitter, "n_val"):
     80     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 82 self.trainer.fit(self.training_plan, self.data_splitter)
     83 self._update_history()
     85 # data splitter only gets these attrs after fit

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/scvi/train/_trainer.py:188, in Trainer.fit(self, *args, **kwargs)
    182 if isinstance(args[0], PyroTrainingPlan):
    183     warnings.filterwarnings(
    184         action="ignore",
    185         category=UserWarning,
    186         message="`LightningModule.configure_optimizers` returned `None`",
    187     )
--> 188 super().fit(*args, **kwargs)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:696, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    677 r"""
    678 Runs the full optimization routine.
    679
   (...)
    693     datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
    694 """
    695 self.strategy.model = model
--> 696 self._call_and_handle_interrupt(
    697     self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    698 )

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:650, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    648         return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
    649     else:
--> 650         return trainer_fn(*args, **kwargs)
    651 # TODO(awaelchli): Unify both exceptions below, where `KeyboardError` doesn't re-raise
    652 except KeyboardInterrupt as exception:

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:735, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    731 ckpt_path = ckpt_path or self.resume_from_checkpoint
    732 self._ckpt_path = self.__set_ckpt_path(
    733     ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
    734 )
--> 735 results = self._run(model, ckpt_path=self.ckpt_path)
    737 assert self.state.stopped
    738 self.training = False

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:1166, in Trainer._run(self, model, ckpt_path)
   1162 self._checkpoint_connector.restore_training_state()
   1164 self._checkpoint_connector.resume_end()
-> 1166 results = self._run_stage()
   1168 log.detail(f"{self.__class__.__name__}: trainer tearing down")
   1169 self._teardown()

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:1252, in Trainer._run_stage(self)
   1250 if self.predicting:
   1251     return self._run_predict()
-> 1252 return self._run_train()

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:1283, in Trainer._run_train(self)
   1280 self.fit_loop.trainer = self
   1282 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1283     self.fit_loop.run()

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/loops/loop.py:200, in Loop.run(self, *args, **kwargs)
    198 try:
    199     self.on_advance_start(*args, **kwargs)
--> 200     self.advance(*args, **kwargs)
    201     self.on_advance_end()
    202     self._restarting = False

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:271, in FitLoop.advance(self)
    267 self._data_fetcher.setup(
    268     dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0)
    269 )
    270 with self.trainer.profiler.profile("run_training_epoch"):
--> 271     self._outputs = self.epoch_loop.run(self._data_fetcher)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/loops/loop.py:200, in Loop.run(self, *args, **kwargs)
    198 try:
    199     self.on_advance_start(*args, **kwargs)
--> 200     self.advance(*args, **kwargs)
    201     self.on_advance_end()
    202     self._restarting = False

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:203, in TrainingEpochLoop.advance(self, data_fetcher)
    200     self.batch_progress.increment_started()
    202     with self.trainer.profiler.profile("run_training_batch"):
--> 203         batch_output = self.batch_loop.run(kwargs)
    205 self.batch_progress.increment_processed()
    207 # update non-plateau LR schedulers
    208 # update epoch-interval ones only when we are at the end of training epoch

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/loops/loop.py:200, in Loop.run(self, *args, **kwargs)
    198 try:
    199     self.on_advance_start(*args, **kwargs)
--> 200     self.advance(*args, **kwargs)
    201     self.on_advance_end()
    202     self._restarting = False

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py:87, in TrainingBatchLoop.advance(self, kwargs)
     83 if self.trainer.lightning_module.automatic_optimization:
     84     optimizers = _get_active_optimizers(
     85         self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
     86     )
---> 87     outputs = self.optimizer_loop.run(optimizers, kwargs)
     88 else:
     89     outputs = self.manual_loop.run(kwargs)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/loops/loop.py:200, in Loop.run(self, *args, **kwargs)
    198 try:
    199     self.on_advance_start(*args, **kwargs)
--> 200     self.advance(*args, **kwargs)
    201     self.on_advance_end()
    202     self._restarting = False

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:201, in OptimizerLoop.advance(self, optimizers, kwargs)
    198 def advance(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> None:  # type: ignore[override]
    199     kwargs = self._build_kwargs(kwargs, self.optimizer_idx, self._hiddens)
--> 201     result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
    202     if result.loss is not None:
    203         # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch
    204         # would be skipped otherwise
    205         self._outputs[self.optimizer_idx] = result.asdict()

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:248, in OptimizerLoop._run_optimization(self, kwargs, optimizer)
    240         closure()
    242 # ------------------------------
    243 # BACKWARD PASS
    244 # ------------------------------
    245 # gradient update with accumulated gradients
    246 else:
    247     # the `batch_idx` is optional with inter-batch parallelism
--> 248     self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
    250 result = closure.consume_result()
    252 if result.loss is not None:
    253     # if no result, user decided to skip optimization
    254     # otherwise update running loss + reset accumulated loss
    255     # TODO: find proper way to handle updating running loss

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:358, in OptimizerLoop._optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    355     self.optim_progress.optimizer.step.increment_ready()
    357 # model hook
--> 358 self.trainer._call_lightning_module_hook(
    359     "optimizer_step",
    360     self.trainer.current_epoch,
    361     batch_idx,
    362     optimizer,
    363     opt_idx,
    364     train_step_and_backward_closure,
    365     on_tpu=isinstance(self.trainer.accelerator, TPUAccelerator),
    366     using_native_amp=(self.trainer.amp_backend == AMPType.NATIVE),
    367     using_lbfgs=is_lbfgs,
    368 )
    370 if not should_accumulate:
    371     self.optim_progress.optimizer.step.increment_completed()

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:1550, in Trainer._call_lightning_module_hook(self, hook_name, pl_module, *args, **kwargs)
   1547 pl_module._current_fx_name = hook_name
   1549 with self.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
-> 1550     output = fn(*args, **kwargs)
   1552 # restore current_fx when nested context
   1553 pl_module._current_fx_name = prev_fx_name

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/core/module.py:1705, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1623 def optimizer_step(
   1624     self,
   1625     epoch: int,
   (...)
   1632     using_lbfgs: bool = False,
   1633 ) -> None:
   1634     r"""
   1635     Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls
   1636     each optimizer.
   (...)
   1703
   1704     """
-> 1705     optimizer.step(closure=optimizer_closure)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py:168, in LightningOptimizer.step(self, closure, **kwargs)
    165     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    167 assert self._strategy is not None
--> 168 step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
    170 self._on_after_step()
    172 return step_output

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:216, in Strategy.optimizer_step(self, optimizer, opt_idx, closure, model, **kwargs)
    206 """Performs the actual optimizer step.
    207
    208 Args:
   (...)
    213     **kwargs: Any extra arguments to ``optimizer.step``
    214 """
    215 model = model or self.lightning_module
--> 216 return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py:153, in PrecisionPlugin.optimizer_step(self, model, optimizer, optimizer_idx, closure, **kwargs)
    151 if isinstance(model, pl.LightningModule):
    152     closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
--> 153 return optimizer.step(closure=closure, **kwargs)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/torch/optim/optimizer.py:484, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    479         else:
    480             raise RuntimeError(
    481                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    482             )
--> 484 out = func(*args, **kwargs)
    485 self._optimizer_step_code()
    487 # call optimizer step post hooks

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/torch/optim/optimizer.py:89, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
     87     torch.set_grad_enabled(self.defaults["differentiable"])
     88     torch._dynamo.graph_break()
---> 89     ret = func(self, *args, **kwargs)
     90 finally:
     91     torch._dynamo.graph_break()

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/torch/optim/adam.py:205, in Adam.step(self, closure)
    203 if closure is not None:
    204     with torch.enable_grad():
--> 205         loss = closure()
    207 for group in self.param_groups:
    208     params_with_grad: List[Tensor] = []

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py:138, in PrecisionPlugin._wrap_closure(self, model, optimizer, optimizer_idx, closure)
    125 def _wrap_closure(
    126     self,
    127     model: "pl.LightningModule",
   (...)
    130     closure: Callable[[], Any],
    131 ) -> Any:
    132     """This double-closure allows makes sure the ``closure`` is executed before the
    133     ``on_before_optimizer_step`` hook is called.
    134
    135     The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is
    136     consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
    137     """
--> 138     closure_result = closure()
    139     self._after_closure(model, optimizer, optimizer_idx)
    140     return closure_result

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:146, in Closure.__call__(self, *args, **kwargs)
    145 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 146     self._result = self.closure(*args, **kwargs)
    147     return self._result.loss

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:132, in Closure.closure(self, *args, **kwargs)
    131 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 132     step_output = self._step_fn()
    134     if step_output.closure_loss is None:
    135         self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:407, in OptimizerLoop._training_step(self, kwargs)
    398 """Performs the actual train step with the tied hooks.
    399
    400 Args:
   (...)
    404     A ``ClosureResult`` containing the training step output.
    405 """
    406 # manually capture logged metrics
--> 407 training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
    408 self.trainer.strategy.post_training_step()
    410 model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:1704, in Trainer._call_strategy_hook(self, hook_name, *args, **kwargs)
   1701     return
   1703 with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1704     output = fn(*args, **kwargs)
   1706 # restore current_fx when nested context
   1707 pl_module._current_fx_name = prev_fx_name

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:358, in Strategy.training_step(self, *args, **kwargs)
    356 with self.precision_plugin.train_step_context():
    357     assert isinstance(self.model, TrainingStep)
--> 358     return self.model.training_step(*args, **kwargs)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/scvi/train/_trainingplans.py:327, in TrainingPlan.training_step(self, batch, batch_idx, optimizer_idx)
    325 if "kl_weight" in self.loss_kwargs:
    326     self.loss_kwargs.update({"kl_weight": self.kl_weight})
--> 327 _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
    328 self.log("train_loss", scvi_loss.loss, on_epoch=True)
    329 self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/scvi/train/_trainingplans.py:265, in TrainingPlan.forward(self, *args, **kwargs)
    263 def forward(self, *args, **kwargs):
    264     """Passthrough to the module's forward method."""
--> 265     return self.module(*args, **kwargs)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/scvi/module/base/_base_module.py:156, in BaseModuleClass.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
    121 @auto_move_data
    122 def forward(
    123     self,
   (...)
    133     Tuple[torch.Tensor, torch.Tensor, LossRecorder],
    134 ]:
    135     """
    136     Forward pass through the network.
    137
   (...)
    154         another return value.
    155     """
--> 156     return _generic_forward(
    157         self,
    158         tensors,
    159         inference_kwargs,
    160         generative_kwargs,
    161         loss_kwargs,
    162         get_inference_input_kwargs,
    163         get_generative_input_kwargs,
    164         compute_loss,
    165     )

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/scvi/module/base/_base_module.py:566, in _generic_forward(module, tensors, inference_kwargs, generative_kwargs, loss_kwargs, get_inference_input_kwargs, get_generative_input_kwargs, compute_loss)
    564 generative_outputs = module.generative(**generative_inputs, **generative_kwargs)
    565 if compute_loss:
--> 566     losses = module.loss(
    567         tensors, inference_outputs, generative_outputs, **loss_kwargs
    568     )
    569     return inference_outputs, generative_outputs, losses
    570 else:

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/scvi/module/_vae.py:446, in VAE.loss(self, tensors, inference_outputs, generative_outputs, kl_weight)
    443 else:
    444     kl_divergence_l = 0.0
--> 446 reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
    448 kl_local_for_warmup = kl_divergence_z
    449 kl_local_no_warmup = kl_divergence_l

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/BIVI/distributions.py:144, in BivariateNegativeBinomial.log_prob(self, value)
    142 if self.use_custom:
    143     calculate_log_nb = log_prob_custom
--> 144     log_nb = calculate_log_nb(value,
    145                               mu1=self.mu1, mu2=self.mu2,
    146                               theta=self.theta, eps=self._eps,
    147                               THETA_IS = self.THETA_IS,
    148                               custom_dist = self.custom_dist)
    149 else:
    150     log_nb = log_prob_NBuncorr(value,
    151                               mu1 = self.mu1, mu2 = self.mu2, eps = self._eps)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/BIVI/distributions.py:167, in log_prob_custom(x, mu1, mu2, theta, THETA_IS, eps, custom_dist, **kwargs)
    161 """
    162 Log likelihood (scalar) of a minibatch according to a bivariate nb model
    163 where individual genes use one of the distributions
    164 """
    166 assert custom_dist is not None, "Input a custom_dist"
--> 167 res = custom_dist(x=x, mu1=mu1, mu2=mu2, theta=theta, eps=eps,  THETA_IS = THETA_IS)
    169 return res

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/BIVI/nnNB_module.py:218, in log_prob_nnNB(x, mu1, mu2, theta, THETA_IS, eps, **kwargs)
    209 pv = torch.column_stack((torch.log10(b).reshape(-1),
    210                          torch.log10(beta).reshape(-1),
    211                          torch.log10(gamma).reshape(-1),
   (...)
    215                          n.reshape(-1)
    216                          ))
    217 # run through model
--> 218 w_,hyp_= model(pv)
    221 n = n.reshape(-1,1)
    222 m = m.reshape(-1,1)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/BIVI/nnNB_module.py:35, in MLP.forward(self, inputs)
     32 def forward(self, inputs):
     33
     34     # pass inputs to first layer, apply sigmoid
---> 35     l_1 = self.sigmoid(self.input(inputs))
     37     # pass to second layer, apply sigmoid
     38     l_2 = self.sigmoid(self.hidden(l_1))

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/miniforge3/envs/bivi/lib/python3.12/site-packages/torch/nn/modules/linear.py:117, in Linear.forward(self, input)
    116 def forward(self, input: Tensor) -> Tensor:
--> 117     return F.linear(input, self.weight, self.bias)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
$ nvidia-smi
Fri Sep  6 13:50:19 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100 80GB PCIe          Off | 00000000:65:00.0 Off |                    0 |
| N/A   48C    P0              74W / 300W |   6168MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          Off | 00000000:E3:00.0 Off |                    0 |
| N/A   48C    P0              71W / 300W |  17704MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

$ conda list

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
_sysroot_linux-64_curr_repodata_hack 3                   h69a702a_16    conda-forge
absl-py                   2.1.0              pyhd8ed1ab_0    conda-forge
anndata                   0.10.9             pyhd8ed1ab_0    conda-forge
aom                       3.9.1                hac33072_0    conda-forge
arpack                    3.9.1           nompi_h77f6705_101    conda-forge
array-api-compat          1.8                pyhd8ed1ab_0    conda-forge
asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
binutils_impl_linux-64    2.40                 ha1999f0_7    conda-forge
binutils_linux-64         2.40                 hb3c18ed_1    conda-forge
bivi                      0.1.0                    pypi_0    pypi
blas                      1.0                         mkl    conda-forge
blosc                     1.21.5               hc2324a3_1    conda-forge
brotli                    1.1.0                hb9d3cd8_2    conda-forge
brotli-bin                1.1.0                hb9d3cd8_2    conda-forge
brotli-python             1.1.0           py312h2ec8cdc_2    conda-forge
bzip2                     1.0.8                h4bc722e_7    conda-forge
c-ares                    1.33.1               heb4867d_0    conda-forge
ca-certificates           2024.8.30            hbcca054_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
certifi                   2024.8.30          pyhd8ed1ab_0    conda-forge
cffi                      1.17.1          py312h06ac9bb_0    conda-forge
charset-normalizer        3.3.2              pyhd8ed1ab_0    conda-forge
chex                      0.1.86             pyhd8ed1ab_0    conda-forge
click                     8.1.7           unix_pyh707e725_0    conda-forge
colorama                  0.4.6              pyhd8ed1ab_0    conda-forge
comm                      0.2.2              pyhd8ed1ab_0    conda-forge
contextlib2               21.6.0             pyhd8ed1ab_0    conda-forge
contourpy                 1.3.0           py312h68727a3_1    conda-forge
cuda-cccl                 12.6.37                       0    nvidia
cuda-cccl_linux-64        12.6.37                       0    nvidia
cuda-crt-dev_linux-64     12.6.20                       0    nvidia
cuda-crt-tools            12.6.20                       0    nvidia
cuda-cudart               12.4.127                      0    nvidia
cuda-cudart-dev           12.4.127                      0    nvidia
cuda-cudart-dev_linux-64  12.6.68                       0    nvidia
cuda-cudart-static_linux-64 12.6.68                       0    nvidia
cuda-cudart_linux-64      12.6.68                       0    nvidia
cuda-cupti                12.4.127                      0    nvidia
cuda-driver-dev_linux-64  12.6.68                       0    nvidia
cuda-libraries            12.4.0                        0    nvidia
cuda-nvcc                 12.6.20                       0    nvidia
cuda-nvcc-dev_linux-64    12.6.20                       0    nvidia
cuda-nvcc-impl            12.6.20                       0    nvidia
cuda-nvcc-tools           12.6.20                       0    nvidia
cuda-nvcc_linux-64        12.6.20                       0    nvidia
cuda-nvrtc                12.4.127                      0    nvidia
cuda-nvtx                 12.4.127                      0    nvidia
cuda-nvvm-dev_linux-64    12.6.20                       0    nvidia
cuda-nvvm-impl            12.6.20                       0    nvidia
cuda-nvvm-tools           12.6.20                       0    nvidia
cuda-opencl               12.6.68                       0    nvidia
cuda-runtime              12.4.0                        0    nvidia
cuda-version              12.6                          3    nvidia
cudnn                     8.9.7.29             h092f7fd_3    conda-forge
cycler                    0.12.1             pyhd8ed1ab_0    conda-forge
dav1d                     1.2.1                hd590300_0    conda-forge
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
docrep                    0.3.2              pyh44b312d_0    conda-forge
et_xmlfile                1.1.0              pyhd8ed1ab_0    conda-forge
etils                     1.9.4              pyhd8ed1ab_0    conda-forge
exceptiongroup            1.2.2              pyhd8ed1ab_0    conda-forge
executing                 2.1.0              pyhd8ed1ab_0    conda-forge
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.15.4             pyhd8ed1ab_0    conda-forge
flax                      0.9.0              pyhd8ed1ab_0    conda-forge
fonttools                 4.53.1          py312h66e93f0_1    conda-forge
freetype                  2.12.1               h267a509_2    conda-forge
fsspec                    2024.9.0           pyhff2d567_0    conda-forge
gcc_impl_linux-64         12.4.0               hb2e57f8_1    conda-forge
gcc_linux-64              12.4.0               h6b7512a_1    conda-forge
get-annotations           0.1.2              pyhd8ed1ab_0    conda-forge
glpk                      5.0                  h445213a_0    conda-forge
gmp                       6.3.0                hac33072_2    conda-forge
gmpy2                     2.1.5           py312h7201bc8_2    conda-forge
gnutls                    3.6.13               h85f3911_1    conda-forge
grpcio                    1.62.2          py312hb06c811_0    conda-forge
gxx_impl_linux-64         12.4.0               h613a52c_1    conda-forge
gxx_linux-64              12.4.0               h8489865_1    conda-forge
h2                        4.1.0              pyhd8ed1ab_0    conda-forge
h5py                      3.11.0          nompi_py312hb7ab980_102    conda-forge
hdf5                      1.14.3          nompi_hdf9ad27_105    conda-forge
hpack                     4.0.0              pyh9f0ad1d_0    conda-forge
hyperframe                6.0.1              pyhd8ed1ab_0    conda-forge
icu                       73.2                 h59595ed_0    conda-forge
idna                      3.8                pyhd8ed1ab_0    conda-forge
igraph                    0.10.13              hef0740d_0    conda-forge
importlib-metadata        8.4.0              pyha770c72_0    conda-forge
importlib_metadata        8.4.0                hd8ed1ab_0    conda-forge
importlib_resources       6.4.4              pyhd8ed1ab_0    conda-forge
ipython                   8.27.0             pyh707e725_0    conda-forge
ipywidgets                8.1.5              pyhd8ed1ab_0    conda-forge
jax                       0.4.27             pyhd8ed1ab_0    conda-forge
jaxlib                    0.4.23          cuda120py312h6027bbc_202    conda-forge
jedi                      0.19.1             pyhd8ed1ab_0    conda-forge
jinja2                    3.1.4              pyhd8ed1ab_0    conda-forge
joblib                    1.4.2              pyhd8ed1ab_0    conda-forge
jupyterlab_widgets        3.0.13             pyhd8ed1ab_0    conda-forge
kernel-headers_linux-64   3.10.0              h4a8ded7_16    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
kiwisolver                1.4.7           py312h68727a3_0    conda-forge
krb5                      1.21.3               h659f571_0    conda-forge
lame                      3.100             h166bdaf_1003    conda-forge
lcms2                     2.16                 hb7c19ff_0    conda-forge
ld_impl_linux-64          2.40                 hf3520f5_7    conda-forge
legacy-api-wrap           1.4                pyhd8ed1ab_1    conda-forge
leidenalg                 0.10.2          py312h30efb56_0    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libabseil                 20240116.2      cxx17_he02047a_1    conda-forge
libaec                    1.1.3                h59595ed_0    conda-forge
libavif16                 1.1.1                h104a339_1    conda-forge
libblas                   3.9.0            16_linux64_mkl    conda-forge
libbrotlicommon           1.1.0                hb9d3cd8_2    conda-forge
libbrotlidec              1.1.0                hb9d3cd8_2    conda-forge
libbrotlienc              1.1.0                hb9d3cd8_2    conda-forge
libcblas                  3.9.0            16_linux64_mkl    conda-forge
libcublas                 12.4.2.65                     0    nvidia
libcufft                  11.2.0.44                     0    nvidia
libcufile                 1.11.1.6                      0    nvidia
libcurand                 10.3.7.68                     0    nvidia
libcurl                   8.8.0                hca28451_1    conda-forge
libcusolver               11.6.0.99                     0    nvidia
libcusparse               12.3.0.142                    0    nvidia
libdeflate                1.20                 hd590300_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libev                     4.33                 hd590300_2    conda-forge
libexpat                  2.6.3                h5888daf_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc                    14.1.0               h77fa898_1    conda-forge
libgcc-devel_linux-64     12.4.0             ha4f9413_101    conda-forge
libgcc-ng                 14.1.0               h69a702a_1    conda-forge
libgfortran               14.1.0               h69a702a_1    conda-forge
libgfortran-ng            14.1.0               h69a702a_1    conda-forge
libgfortran5              14.1.0               hc5f4f2c_1    conda-forge
libgomp                   14.1.0               h77fa898_1    conda-forge
libgrpc                   1.62.2               h15f2491_0    conda-forge
libhwloc                  2.11.1          default_hecaa2ac_1000    conda-forge
libiconv                  1.17                 hd590300_2    conda-forge
libjpeg-turbo             3.0.0                hd590300_1    conda-forge
liblapack                 3.9.0            16_linux64_mkl    conda-forge
libleidenalg              0.11.1               h00ab1b0_0    conda-forge
libllvm14                 14.0.6               hcd5def8_4    conda-forge
libnghttp2                1.58.0               h47da74e_1    conda-forge
libnpp                    12.2.5.2                      0    nvidia
libnsl                    2.0.1                hd590300_0    conda-forge
libnvfatbin               12.6.68                       0    nvidia
libnvjitlink              12.4.99                       0    nvidia
libnvjpeg                 12.3.1.89                     0    nvidia
libpng                    1.6.43               h2797004_0    conda-forge
libprotobuf               4.25.3               h08a7969_0    conda-forge
libre2-11                 2023.09.01           h5a48ba9_2    conda-forge
libsanitizer              12.4.0               h46f95d5_1    conda-forge
libsqlite                 3.46.0               hde9e2c9_0    conda-forge
libssh2                   1.11.0               h0841786_0    conda-forge
libstdcxx                 14.1.0               hc0a3c3a_1    conda-forge
libstdcxx-devel_linux-64  12.4.0             ha4f9413_101    conda-forge
libstdcxx-ng              14.1.0               h4852527_1    conda-forge
libtiff                   4.6.0                h1dd3fc0_3    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libwebp-base              1.4.0                hd590300_0    conda-forge
libxcb                    1.15                 h0b41bf4_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libxml2                   2.12.7               hc051c1a_1    conda-forge
libzlib                   1.2.13               h4ab18f5_6    conda-forge
llvm-openmp               15.0.7               h0cdce71_0    conda-forge
llvmlite                  0.42.0          py312hb06c811_1    conda-forge
loompy                    3.0.6                      py_0    conda-forge
lz4-c                     1.9.4                hcb278e6_0    conda-forge
markdown                  3.6                pyhd8ed1ab_0    conda-forge
markdown-it-py            3.0.0              pyhd8ed1ab_0    conda-forge
markupsafe                2.1.5           py312h66e93f0_1    conda-forge
matplotlib-base           3.9.2           py312h854627b_0    conda-forge
matplotlib-inline         0.1.7              pyhd8ed1ab_0    conda-forge
mdurl                     0.1.2              pyhd8ed1ab_0    conda-forge
mkl                       2022.2.1         h84fe81f_16997    conda-forge
ml-collections            0.1.1              pyhd8ed1ab_0    conda-forge
ml_dtypes                 0.4.0           py312hf9745cd_2    conda-forge
mpc                       1.3.1                h24ddda3_0    conda-forge
mpfr                      4.2.1                h38ae2d0_2    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
msgpack-python            1.0.8           py312h68727a3_1    conda-forge
mudata                    0.3.1              pyhd8ed1ab_0    conda-forge
multipledispatch          0.6.0              pyhd8ed1ab_1    conda-forge
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
natsort                   8.4.0              pyhd8ed1ab_0    conda-forge
nccl                      2.22.3.1             hbc370b7_1    conda-forge
ncurses                   6.5                  he02047a_1    conda-forge
nest-asyncio              1.6.0              pyhd8ed1ab_0    conda-forge
nettle                    3.6                  he412f7d_0    conda-forge
networkx                  3.3                pyhd8ed1ab_1    conda-forge
numba                     0.59.1          py312hacefee8_0    conda-forge
numpy                     1.26.4          py312heda63a1_0    conda-forge
numpy_groupies            0.11.2             pyhd8ed1ab_0    conda-forge
numpyro                   0.15.2             pyhd8ed1ab_0    conda-forge
openh264                  2.1.1                h780b84a_0    conda-forge
openjpeg                  2.5.2                h488ebb8_0    conda-forge
openpyxl                  3.1.5           py312h710cb58_1    conda-forge
openssl                   3.3.2                hb9d3cd8_0    conda-forge
opt-einsum                3.3.0                hd8ed1ab_2    conda-forge
opt_einsum                3.3.0              pyhc1e730c_2    conda-forge
optax                     0.2.2              pyhd8ed1ab_1    conda-forge
orbax-checkpoint          0.4.4              pyhd8ed1ab_0    conda-forge
packaging                 24.1               pyhd8ed1ab_0    conda-forge
pandas                    2.2.2           py312h1d6d2e6_1    conda-forge
parso                     0.8.4              pyhd8ed1ab_0    conda-forge
patsy                     0.5.6              pyhd8ed1ab_0    conda-forge
pexpect                   4.9.0              pyhd8ed1ab_0    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    10.3.0          py312hdcec9eb_0    conda-forge
pip                       24.2               pyh8b19718_1    conda-forge
prompt-toolkit            3.0.47             pyha770c72_0    conda-forge
protobuf                  4.25.3          py312h83439f5_1    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.3              pyhd8ed1ab_0    conda-forge
pybind11-abi              4                    hd8ed1ab_3    conda-forge
pycparser                 2.22               pyhd8ed1ab_0    conda-forge
pydeprecate               0.3.2              pyhd8ed1ab_0    conda-forge
pygments                  2.18.0             pyhd8ed1ab_0    conda-forge
pynndescent               0.5.13             pyhff2d567_0    conda-forge
pyparsing                 3.1.4              pyhd8ed1ab_0    conda-forge
pyro-api                  0.1.2              pyhd8ed1ab_0    conda-forge
pyro-ppl                  1.9.1              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
python                    3.12.3          hab00c5b_0_cpython    conda-forge
python-dateutil           2.9.0              pyhd8ed1ab_0    conda-forge
python-igraph             0.11.5          py312h4f72774_1    conda-forge
python-tzdata             2024.1             pyhd8ed1ab_0    conda-forge
python_abi                3.12                    5_cp312    conda-forge
pytorch                   2.4.1           py3.12_cuda12.4_cudnn9.1.0_0    pytorch
pytorch-cuda              12.4                 hc786d27_6    pytorch
pytorch-lightning         1.7.7              pyhd8ed1ab_0    conda-forge
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2024.1             pyhd8ed1ab_0    conda-forge
pyyaml                    6.0.2           py312h66e93f0_1    conda-forge
qhull                     2020.2               h434a139_5    conda-forge
rav1e                     0.6.6                he8a937b_2    conda-forge
re2                       2023.09.01           h7f4b329_2    conda-forge
readline                  8.2                  h8228510_1    conda-forge
requests                  2.32.3             pyhd8ed1ab_0    conda-forge
rich                      13.7.1             pyhd8ed1ab_0    conda-forge
scanpy                    1.10.2             pyhd8ed1ab_0    conda-forge
scikit-learn              1.5.1           py312h775a589_0    conda-forge
scipy                     1.14.1          py312h7d485d2_0    conda-forge
scvi-tools                0.18.0             pyhd8ed1ab_0    conda-forge
seaborn                   0.13.2               hd8ed1ab_2    conda-forge
seaborn-base              0.13.2             pyhd8ed1ab_2    conda-forge
session-info              1.0.0              pyhd8ed1ab_0    conda-forge
setuptools                73.0.1             pyhd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
snappy                    1.2.1                ha2e4443_0    conda-forge
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
statsmodels               0.14.2          py312h085067d_0    conda-forge
stdlib-list               0.10.0             pyhd8ed1ab_0    conda-forge
svt-av1                   2.2.1                h5888daf_0    conda-forge
sympy                     1.13.2          pypyh2585a3b_103    conda-forge
sysroot_linux-64          2.17                h4a8ded7_16    conda-forge
tbb                       2021.13.0            h84d6215_0    conda-forge
tensorboard               2.17.1             pyhd8ed1ab_0    conda-forge
tensorboard-data-server   0.7.0           py312h241aef2_1    conda-forge
tensorstore               0.1.60          py312h80f44a3_0    conda-forge
texttable                 1.7.0              pyhd8ed1ab_0    conda-forge
threadpoolctl             3.5.0              pyhc1e730c_0    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
toolz                     0.12.1             pyhd8ed1ab_0    conda-forge
torchaudio                2.4.1               py312_cu124    pytorch
torchmetrics              0.11.4             pyhd8ed1ab_0    conda-forge
torchtriton               3.0.0                     py312    pytorch
torchvision               0.19.1              py312_cu124    pytorch
tqdm                      4.66.5             pyhd8ed1ab_0    conda-forge
traitlets                 5.14.3             pyhd8ed1ab_0    conda-forge
typing-extensions         4.12.2               hd8ed1ab_0    conda-forge
typing_extensions         4.12.2             pyha770c72_0    conda-forge
tzdata                    2024a                h8827d51_1    conda-forge
umap-learn                0.5.6           py312h7900ff3_1    conda-forge
urllib3                   2.2.2              pyhd8ed1ab_1    conda-forge
wcwidth                   0.2.13             pyhd8ed1ab_0    conda-forge
werkzeug                  3.0.4              pyhd8ed1ab_0    conda-forge
wheel                     0.44.0             pyhd8ed1ab_0    conda-forge
widgetsnbextension        4.0.13             pyhd8ed1ab_0    conda-forge
xorg-libxau               1.0.11               hd590300_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
yaml                      0.2.5                h7f98852_2    conda-forge
zipp                      3.20.1             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               h4ab18f5_6    conda-forge
zstandard                 0.23.0          py312hef9b889_1    conda-forge
zstd                      1.5.6                ha6fb4c9_0    conda-forge
mtcarilli commented 2 months ago

Can you check what device your model is on, cuda:0 or cuda:1?
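
One way to check could be something like the sketch below. It assumes the usual scvi-tools convention that the trained torch module is exposed as `model.module`; if the attribute is named differently in your setup, adjust accordingly:

  # collect every device the model's parameters live on;
  # more than one entry means something was placed on a second GPU
  devices = {p.device for p in model.module.parameters()}
  print(devices)   # e.g. {device(type='cuda', index=0)} if everything sits on GPU 0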

You might be able to fix this by setting only one visible device at the top of your notebook/script (before defining your model):

  import os
  # set only GPU 0 as visible
  os.environ['CUDA_VISIBLE_DEVICES'] = '0'  
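
Note that CUDA_VISIBLE_DEVICES is only picked up when CUDA is first initialized, so the assignment has to come before any code that touches the GPU. A rough ordering sketch (the torch lines are just there to verify the setting):

  import os
  os.environ['CUDA_VISIBLE_DEVICES'] = '0'   # must run before CUDA is initialized

  import torch
  print(torch.cuda.device_count())           # should now report a single device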

Hope this helps!

biobai commented 2 months ago

That works. Thanks, @mtcarilli!