suinleelab / contrastiveVI

BSD 3-Clause "New" or "Revised" License

Issue during training of model #19

Open svntrx opened 1 year ago

svntrx commented 1 year ago

Dear contrastiveVI dev team,

First of all thanks a lot for developing this cool model and congrats on the Nature Methods paper on it! I was trying to apply it to a MIBI dataset harboring different drug treatment conditions, but unfortunately ran into an issue I can't seem to figure out myself.

I run the following code (taken from the Alzheimer example), where treated_control is an AnnData object containing my single-cell data and "Drug" is the condition column.

# imports
import numpy as np
from contrastive_vi.model import ContrastiveVI
from pytorch_lightning.utilities.seed import seed_everything

seed_everything(42) # For reproducibility

treated_control = treated_control.copy()
ContrastiveVI.setup_anndata(treated_control) # setup adata for use with this model

model = ContrastiveVI(
    treated_control,
    n_salient_latent=10,
    n_background_latent=10,
    use_observed_lib_size=False
)

background_indices = np.where(treated_control.obs["Drug"] == "CTRL")[0]
target_indices = np.where(treated_control.obs["Drug"] != "CTRL")[0]

model.train(
    check_val_every_n_epoch=1,
    train_size=0.8,
    background_indices=background_indices,
    target_indices=target_indices,
    use_gpu=False,
    early_stopping=True,
    max_epochs=500,
)

Running model.train gives the following error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[20], line 4
      1 background_indices = np.where(treated_control.obs["Drug"] == "CTRL")[0]
      2 target_indices = np.where(treated_control.obs["Drug"] != "CTRL")[0]
----> 4 model.train(
      5     check_val_every_n_epoch=1,
      6     train_size=0.8,
      7     background_indices=background_indices,
      8     target_indices=target_indices,
      9     use_gpu=False,
     10     early_stopping=True,
     11     max_epochs=500,
     12 )

File ~\Anaconda3\envs\ST0036\Lib\site-packages\contrastive_vi\model\base\training_mixin.py:88, in ContrastiveTrainingMixin.train(self, background_indices, target_indices, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, plan_kwargs, **trainer_kwargs)
     77 trainer_kwargs[es] = (
     78     early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
     79 )
     80 runner = TrainRunner(
     81     self,
     82     training_plan=training_plan,
   (...)
     86     **trainer_kwargs,
     87 )
---> 88 return runner()

File ~\Anaconda3\envs\ST0036\Lib\site-packages\scvi\train\_trainrunner.py:74, in TrainRunner.__call__(self)
     71 if hasattr(self.data_splitter, "n_val"):
     72     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 74 self.trainer.fit(self.training_plan, self.data_splitter)
     75 self._update_history()
     77 # data splitter only gets these attrs after fit

File ~\Anaconda3\envs\ST0036\Lib\site-packages\scvi\train\_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
    180 if isinstance(args[0], PyroTrainingPlan):
    181     warnings.filterwarnings(
    182         action="ignore",
    183         category=UserWarning,
    184         message="`LightningModule.configure_optimizers` returned `None`",
    185     )
--> 186 super().fit(*args, **kwargs)

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:740, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
    735     rank_zero_deprecation(
    736         "`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6."
    737         " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
    738     )
    739     train_dataloaders = train_dataloader
--> 740 self._call_and_handle_interrupt(
    741     self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    742 )

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:685, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    675 r"""
    676 Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
    677 as all errors should funnel through them
   (...)
    682     **kwargs: keyword arguments to be passed to `trainer_fn`
    683 """
    684 try:
--> 685     return trainer_fn(*args, **kwargs)
    686 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    687 except KeyboardInterrupt as exception:

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:777, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    775 # TODO: ckpt_path only in v1.7
    776 ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 777 self._run(model, ckpt_path=ckpt_path)
    779 assert self.state.stopped
    780 self.training = False

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:1138, in Trainer._run(self, model, ckpt_path)
   1136 self.call_hook("on_before_accelerator_backend_setup")
   1137 self.accelerator.setup_environment()
-> 1138 self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
   1140 # check if we should delay restoring checkpoint till later
   1141 if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:1438, in Trainer._call_setup_hook(self)
   1435 self.training_type_plugin.barrier("pre_setup")
   1437 if self.datamodule is not None:
-> 1438     self.datamodule.setup(stage=fn)
   1439 self.call_hook("setup", stage=fn)
   1441 self.training_type_plugin.barrier("post_setup")

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\core\datamodule.py:461, in LightningDataModule._track_data_hook_calls.<locals>.wrapped_fn(*args, **kwargs)
    459     else:
    460         attr = f"_has_{name}_{stage}"
--> 461         has_run = getattr(obj, attr)
    462         setattr(obj, attr, True)
    464 elif name == "prepare_data":

AttributeError: 'ContrastiveDataSplitter' object has no attribute '_has_setup_TrainerFn.FITTING'

I ran the package in a fresh conda environment. Any ideas where the issue may lie?

Thanks a ton for your help!

Best regards, Sven
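
One possible explanation for the odd attribute name, not confirmed anywhere in this thread: the traceback shows it being built with attr = f"_has_{name}_{stage}", where stage is a TrainerFn member. Assuming TrainerFn is a str-mixin Enum in pytorch-lightning 1.5.x, Python 3.11 changed how f-strings format such members: they now render as "TrainerFn.FITTING" instead of the member's value, so the getattr lookup misses. The later tracebacks in this thread that show a Python version both come from python3.11 environments. The sketch below uses a hypothetical stand-in class (its value "fit" is an assumption) instead of importing Lightning.

```python
import sys
from enum import Enum


class TrainerFn(str, Enum):
    """Hypothetical stand-in for Lightning's str-mixin TrainerFn enum."""

    FITTING = "fit"  # assumed value, matching attribute names like _has_setup_fit


def hook_attr_name(name, stage):
    # Same f-string as datamodule.py line 460 in the traceback.
    return f"_has_{name}_{stage}"


def hook_attr_name_by_value(name, stage):
    # Formatting the member's value gives the same string on every Python version.
    value = stage.value if isinstance(stage, Enum) else stage
    return f"_has_{name}_{value}"


if __name__ == "__main__":
    print(sys.version_info[:2])
    # Python <= 3.10 prints _has_setup_fit; Python >= 3.11 prints _has_setup_TrainerFn.FITTING
    print(hook_attr_name("setup", TrainerFn.FITTING))
    print(hook_attr_name_by_value("setup", TrainerFn.FITTING))
```

If this is the cause, the standalone contrastive_vi package would be expected to train under Python 3.10 or earlier with the same pinned dependencies.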

gwendolinelecuyer commented 1 year ago

Hello, I have the same issue. Did you find a solution?

Thank you!! Best regards, Gwendoline

ghost commented 1 year ago

Also experiencing the same issue.

I'm guessing it has something to do with this warning: DEPRECATION: pytorch-lightning 1.5.10 has a non-standard dependency specifier torch>=1.7.*. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pytorch-lightning or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063

Looking forward to any solution! Thanks a ton! Seyoon

ethanweinberger commented 1 year ago

Hey all,

Thanks for reporting the issue here. Would you be able to share any additional information about your environments? We haven't seen this issue on our end, and it doesn't appear in the Colab notebooks, so I'm not immediately sure what's causing it. As one sanity check, could you try training a model with one of the AnnData files from the tutorials (as opposed to your own files)? That might help narrow down where exactly the issue is.

Thanks, Ethan
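
One way to gather the version details asked for here is the standard-library importlib.metadata module. The package list is an assumption, chosen from the libraries that show up in the tracebacks in this thread; adjust it as needed.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Assumed list: distributions relevant to the tracebacks in this thread.
PACKAGES = ["contrastive-vi", "scvi-tools", "pytorch-lightning", "torch", "anndata", "numpy"]


def environment_report(packages=PACKAGES):
    """Return {name: version} for the Python interpreter and each distribution."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for name, ver in environment_report().items():
        print(f"{name:<20} {ver}")
```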

gwendolinelecuyer commented 1 year ago

Hello, thank you for your reply. I tried the Zheng 2017 dataset and got the same error. The only difference from your code is that I train on CPU instead of GPU. I found the same error reported here: https://discourse.scverse.org/t/attributeerror-datasplitter-object-has-no-attribute-has-setup-trainerfn-fitting/1518; it seems to be caused by a package update.

Here are my error, environment info, and the code I ran.

Env:

Name                      Version     Build               Channel
_libgcc_mutex             0.1         conda_forge         conda-forge
_openmp_mutex             4.5         2_gnu               conda-forge
absl-py                   1.4.0       pypi_0              pypi
aiohttp                   3.8.5       pypi_0              pypi
aiosignal                 1.3.1       pypi_0              pypi
anndata                   0.9.2       pypi_0              pypi
annotated-types           0.5.0       pypi_0              pypi
anyio                     3.7.1       pypi_0              pypi
arrow                     1.2.3       pypi_0              pypi
asttokens                 2.2.1       pypi_0              pypi
async-timeout             4.0.3       pypi_0              pypi
attrs                     23.1.0      pypi_0              pypi
backcall                  0.2.0       pypi_0              pypi
backoff                   2.2.1       pypi_0              pypi
beautifulsoup4            4.12.2      pypi_0              pypi
blessed                   1.20.0      pypi_0              pypi
bzip2                     1.0.8       h7f98852_4          conda-forge
ca-certificates           2023.7.22   hbcca054_0          conda-forge
cachetools                5.3.1       pypi_0              pypi
certifi                   2023.7.22   pypi_0              pypi
charset-normalizer        3.2.0       pypi_0              pypi
chex                      0.1.7       pypi_0              pypi
click                     8.1.7       pypi_0              pypi
cmake                     3.27.2      pypi_0              pypi
comm                      0.1.4       pypi_0              pypi
contextlib2               21.6.0      pypi_0              pypi
contourpy                 1.1.0       pypi_0              pypi
contrastive-vi            0.2.0       pypi_0              pypi
croniter                  1.4.1       pypi_0              pypi
cycler                    0.11.0      pypi_0              pypi
dateutils                 0.6.12      pypi_0              pypi
decorator                 5.1.1       pypi_0              pypi
deepdiff                  6.3.1       pypi_0              pypi
dm-tree                   0.1.8       pypi_0              pypi
docrep                    0.3.2       pypi_0              pypi
et-xmlfile                1.1.0       pypi_0              pypi
etils                     1.4.1       pypi_0              pypi
executing                 1.2.0       pypi_0              pypi
fastapi                   0.101.1     pypi_0              pypi
filelock                  3.12.2      pypi_0              pypi
flax                      0.7.2       pypi_0              pypi
fonttools                 4.42.1      pypi_0              pypi
frozenlist                1.4.0       pypi_0              pypi
fsspec                    2023.6.0    pypi_0              pypi
future                    0.18.3      pypi_0              pypi
gdown                     4.7.1       pypi_0              pypi
google-auth               2.22.0      pypi_0              pypi
google-auth-oauthlib      1.0.0       pypi_0              pypi
grpcio                    1.57.0      pypi_0              pypi
h11                       0.14.0      pypi_0              pypi
h5py                      3.9.0       pypi_0              pypi
idna                      3.4         pypi_0              pypi
importlib-resources       6.0.1       pypi_0              pypi
inquirer                  3.1.3       pypi_0              pypi
ipython                   8.14.0      pypi_0              pypi
ipywidgets                8.1.0       pypi_0              pypi
itsdangerous              2.1.2       pypi_0              pypi
jax                       0.4.14      pypi_0              pypi
jaxlib                    0.4.14      pypi_0              pypi
jedi                      0.19.0      pypi_0              pypi
jinja2                    3.1.2       pypi_0              pypi
joblib                    1.3.2       pypi_0              pypi
jupyterlab-widgets        3.0.8       pypi_0              pypi
kiwisolver                1.4.4       pypi_0              pypi
ld_impl_linux-64          2.40        h41732ed_0          conda-forge
libexpat                  2.5.0       hcb278e6_1          conda-forge
libffi                    3.4.2       h7f98852_5          conda-forge
libgcc-ng                 13.1.0      he5830b7_0          conda-forge
libgomp                   13.1.0      he5830b7_0          conda-forge
libnsl                    2.0.0       h7f98852_0          conda-forge
libsqlite                 3.42.0      h2797004_0          conda-forge
libuuid                   2.38.1      h0b41bf4_0          conda-forge
libzlib                   1.2.13      hd590300_5          conda-forge
lightning                 2.0.7       pypi_0              pypi
lightning-cloud           0.5.37      pypi_0              pypi
lightning-utilities       0.9.0       pypi_0              pypi
lit                       16.0.6      pypi_0              pypi
llvmlite                  0.40.1      pypi_0              pypi
markdown                  3.4.4       pypi_0              pypi
markdown-it-py            3.0.0       pypi_0              pypi
markupsafe                2.1.3       pypi_0              pypi
matplotlib                3.7.2       pypi_0              pypi
matplotlib-inline         0.1.6       pypi_0              pypi
mdurl                     0.1.2       pypi_0              pypi
ml-collections            0.1.1       pypi_0              pypi
ml-dtypes                 0.2.0       pypi_0              pypi
mpmath                    1.3.0       pypi_0              pypi
msgpack                   1.0.5       pypi_0              pypi
mudata                    0.2.3       pypi_0              pypi
multidict                 6.0.4       pypi_0              pypi
multipledispatch          1.0.0       pypi_0              pypi
natsort                   8.4.0       pypi_0              pypi
ncurses                   6.4         hcb278e6_0          conda-forge
nest-asyncio              1.5.7       pypi_0              pypi
networkx                  3.1         pypi_0              pypi
numba                     0.57.1      pypi_0              pypi
numpy                     1.24.4      pypi_0              pypi
numpyro                   0.12.1      pypi_0              pypi
nvidia-cublas-cu11        11.10.3.66  pypi_0              pypi
nvidia-cuda-cupti-cu11    11.7.101    pypi_0              pypi
nvidia-cuda-nvrtc-cu11    11.7.99     pypi_0              pypi
nvidia-cuda-runtime-cu11  11.7.99     pypi_0              pypi
nvidia-cudnn-cu11         8.5.0.96    pypi_0              pypi
nvidia-cufft-cu11         10.9.0.58   pypi_0              pypi
nvidia-curand-cu11        10.2.10.91  pypi_0              pypi
nvidia-cusolver-cu11      11.4.0.1    pypi_0              pypi
nvidia-cusparse-cu11      11.7.4.91   pypi_0              pypi
nvidia-nccl-cu11          2.14.3      pypi_0              pypi
nvidia-nvtx-cu11          11.7.91     pypi_0              pypi
oauthlib                  3.2.2       pypi_0              pypi
openpyxl                  3.1.2       pypi_0              pypi
openssl                   3.1.2       hd590300_0          conda-forge
opt-einsum                3.3.0       pypi_0              pypi
optax                     0.1.7       pypi_0              pypi
orbax-checkpoint          0.3.5       pypi_0              pypi
ordered-set               4.1.0       pypi_0              pypi
packaging                 23.1        pypi_0              pypi
pandas                    2.0.3       pypi_0              pypi
parso                     0.8.3       pypi_0              pypi
patsy                     0.5.3       pypi_0              pypi
pexpect                   4.8.0       pypi_0              pypi
pickleshare               0.7.5       pypi_0              pypi
pillow                    10.0.0      pypi_0              pypi
pip                       23.2.1      pyhd8ed1ab_0        conda-forge
prompt-toolkit            3.0.39      pypi_0              pypi
protobuf                  3.20.1      pypi_0              pypi
psutil                    5.9.5       pypi_0              pypi
ptyprocess                0.7.0       pypi_0              pypi
pure-eval                 0.2.2       pypi_0              pypi
pyasn1                    0.5.0       pypi_0              pypi
pyasn1-modules            0.3.0       pypi_0              pypi
pydantic                  2.1.1       pypi_0              pypi
pydantic-core             2.4.0       pypi_0              pypi
pydeprecate               0.3.1       pypi_0              pypi
pygments                  2.16.1      pypi_0              pypi
pyjwt                     2.8.0       pypi_0              pypi
pynndescent               0.5.10      pypi_0              pypi
pyparsing                 3.0.9       pypi_0              pypi
pyro-api                  0.1.2       pypi_0              pypi
pyro-ppl                  1.8.6       pypi_0              pypi
pysocks                   1.7.1       pypi_0              pypi
python                    3.11.4      hab00c5b_0_cpython  conda-forge
python-dateutil           2.8.2       pypi_0              pypi
python-editor             1.0.4       pypi_0              pypi
python-multipart          0.0.6       pypi_0              pypi
pytorch-lightning         1.5.10      pypi_0              pypi
pytz                      2023.3      pypi_0              pypi
pyyaml                    6.0.1       pypi_0              pypi
readchar                  4.0.5       pypi_0              pypi
readline                  8.2         h8228510_1          conda-forge
requests                  2.31.0      pypi_0              pypi
requests-oauthlib         1.3.1       pypi_0              pypi
rich                      13.5.2      pypi_0              pypi
rsa                       4.9         pypi_0              pypi
scanpy                    1.9.3       pypi_0              pypi
scikit-learn              1.3.0       pypi_0              pypi
scipy                     1.11.2      pypi_0              pypi
scvi-tools                0.16.1      pypi_0              pypi
seaborn                   0.12.2      pypi_0              pypi
session-info              1.0.0       pypi_0              pypi
setuptools                59.5.0      pypi_0              pypi
six                       1.16.0      pypi_0              pypi
sniffio                   1.3.0       pypi_0              pypi
soupsieve                 2.4.1       pypi_0              pypi
sparse                    0.14.0      pypi_0              pypi
stack-data                0.6.2       pypi_0              pypi
starlette                 0.27.0      pypi_0              pypi
starsessions              1.3.0       pypi_0              pypi
statsmodels               0.14.0      pypi_0              pypi
stdlib-list               0.9.0       pypi_0              pypi
sympy                     1.12        pypi_0              pypi
tensorboard               2.14.0      pypi_0              pypi
tensorboard-data-server   0.7.1       pypi_0              pypi
tensorstore               0.1.41      pypi_0              pypi
threadpoolctl             3.2.0       pypi_0              pypi
tk                        8.6.12      h27826a3_0          conda-forge
toolz                     0.12.0      pypi_0              pypi
torch                     2.0.1       pypi_0              pypi
torchmetrics              1.0.3       pypi_0              pypi
tqdm                      4.66.1      pypi_0              pypi
traitlets                 5.9.0       pypi_0              pypi
triton                    2.0.0       pypi_0              pypi
typing-extensions         4.7.1       pypi_0              pypi
tzdata                    2023.3      pypi_0              pypi
umap-learn                0.5.3       pypi_0              pypi
urllib3                   1.26.16     pypi_0              pypi
uvicorn                   0.23.2      pypi_0              pypi
wcwidth                   0.2.6       pypi_0              pypi
websocket-client          1.6.1       pypi_0              pypi
websockets                11.0.3      pypi_0              pypi
werkzeug                  2.3.7       pypi_0              pypi
wheel                     0.41.1      pyhd8ed1ab_0        conda-forge
widgetsnbextension        4.0.8       pypi_0              pypi
xarray                    2023.8.0    pypi_0              pypi
xz                        5.2.6       h166bdaf_0          conda-forge
yarl                      1.9.2       pypi_0              pypi
zipp                      3.16.2      pypi_0              pypi

Code:

import numpy as np
import scanpy as sc
from contrastive_vi.model import ContrastiveVI

adata = sc.read_h5ad("/groups/irset/glecuyer/projects/treatment_ccRCC/scripts/test_outils/contrastiveVI/env/zheng_2017.h5ad")
ContrastiveVI.setup_anndata(adata, layer="count")
adata.obs['condition']  # inspect the condition labels
background_indices = np.where(adata.obs["condition"] == "healthy")[0]
target_indices = np.where(adata.obs["condition"] != "healthy")[0]
model = ContrastiveVI(adata, n_salient_latent=10, n_background_latent=10, use_observed_lib_size=False)

model.train(
    check_val_every_n_epoch=1,
    train_size=0.8,
    background_indices=background_indices,
    target_indices=target_indices,
    use_gpu=False,
    early_stopping=True,
    max_epochs=500,
)

Error:

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Traceback (most recent call last):
  File "", line 1, in
  File "/groups/irset/glecuyer/projects/treatment_ccRCC/scripts/test_outils/contrastiveVI/env/contrastivevi/lib/python3.11/site-packages/contrastive_vi/model/base/training_mixin.py", line 88, in train
    return runner()
           ^^^^^^^^
  File "/groups/irset/glecuyer/projects/treatment_ccRCC/scripts/test_outils/contrastiveVI/env/contrastivevi/lib/python3.11/site-packages/scvi/train/_trainrunner.py", line 74, in __call__
    self.trainer.fit(self.training_plan, self.data_splitter)
  File "/groups/irset/glecuyer/projects/treatment_ccRCC/scripts/test_outils/contrastiveVI/env/contrastivevi/lib/python3.11/site-packages/scvi/train/_trainer.py", line 186, in fit
    super().fit(*args, **kwargs)
  File "/groups/irset/glecuyer/projects/treatment_ccRCC/scripts/test_outils/contrastiveVI/env/contrastivevi/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 740, in fit
    self._call_and_handle_interrupt(
  File "/groups/irset/glecuyer/projects/treatment_ccRCC/scripts/test_outils/contrastiveVI/env/contrastivevi/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/groups/irset/glecuyer/projects/treatment_ccRCC/scripts/test_outils/contrastiveVI/env/contrastivevi/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/groups/irset/glecuyer/projects/treatment_ccRCC/scripts/test_outils/contrastiveVI/env/contrastivevi/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1138, in _run
    self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/groups/irset/glecuyer/projects/treatment_ccRCC/scripts/test_outils/contrastiveVI/env/contrastivevi/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1438, in _call_setup_hook
    self.datamodule.setup(stage=fn)
  File "/groups/irset/glecuyer/projects/treatment_ccRCC/scripts/test_outils/contrastiveVI/env/contrastivevi/lib/python3.11/site-packages/pytorch_lightning/core/datamodule.py", line 461, in wrapped_fn
    has_run = getattr(obj, attr)
              ^^^^^^^^^^^^^^^^^^
AttributeError: 'ContrastiveDataSplitter' object has no attribute '_has_setup_TrainerFn.FITTING'

Thank you very much for your help!! Sincerely, Gwendoline

ethanweinberger commented 1 year ago

Hey @gwendolinelecuyer,

Thanks for the additional info. Based on the link you provided it seems that (hopefully) this issue should be resolved once contrastiveVI is updated to be compatible with the latest scvi-tools version.

We're currently working on integrating contrastiveVI into the main scvi-tools package, and we plan to have a pull request ready sometime this week. Once the integration is complete, you should (in theory) be able to use contrastiveVI from the latest scvi-tools release without running into this issue.

I'll make another comment here once the integration is ready.

Best, Ethan

hamid-bolouri commented 8 months ago

Hi @ethanweinberger

I just installed 'contrastiveVI' and tried to run your 'alzheimers_response.ipynb' using your data. I get the same error as reported above (details below). Were you able to update the package as discussed above? THANKS! --Hamid

Details: Everything goes fine until 'model.train()', at which point I get:

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[5], line 6
      3 background_indices = np.where(adata.obs["batchCond"] == "ct")[0]
      4 target_indices = np.where(adata.obs["batchCond"] != "ct")[0]
----> 6 model.train(
      7     check_val_every_n_epoch=1,
      8     train_size=0.8,
      9     background_indices=background_indices,
     10     target_indices=target_indices,
     11     use_gpu=True,
     12     early_stopping=True,
     13     max_epochs=500,
     14 )

File ~/mambaforge/envs/contrastiveVI_env/lib/python3.11/site-packages/contrastive_vi/model/base/training_mixin.py:88, in ContrastiveTrainingMixin.train(self, background_indices, target_indices, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, plan_kwargs, **trainer_kwargs)
     77 trainer_kwargs[es] = (
     78     early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
     79 )
     80 runner = TrainRunner(
     81     self,
     82     training_plan=training_plan,
   (...)
     86     **trainer_kwargs,
     87 )
---> 88 return runner()

File ~/mambaforge/envs/contrastiveVI_env/lib/python3.11/site-packages/scvi/train/_trainrunner.py:74, in TrainRunner.__call__(self)
     71 if hasattr(self.data_splitter, "n_val"):
     72     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 74 self.trainer.fit(self.training_plan, self.data_splitter)
     75 self._update_history()
     77 # data splitter only gets these attrs after fit

File ~/mambaforge/envs/contrastiveVI_env/lib/python3.11/site-packages/scvi/train/_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
    180 if isinstance(args[0], PyroTrainingPlan):
    181     warnings.filterwarnings(
    182         action="ignore",
    183         category=UserWarning,
    184         message="`LightningModule.configure_optimizers` returned `None`",
    185     )
--> 186 super().fit(*args, **kwargs)

File ~/mambaforge/envs/contrastiveVI_env/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:740, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
    735     rank_zero_deprecation(
    736         "`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6."
    737         " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
    738     )
    739     train_dataloaders = train_dataloader
--> 740 self._call_and_handle_interrupt(
    741     self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    742 )

File ~/mambaforge/envs/contrastiveVI_env/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:685, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    675 r"""
    676 Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
    677 as all errors should funnel through them
   (...)
    682     **kwargs: keyword arguments to be passed to `trainer_fn`
    683 """
    684 try:
--> 685     return trainer_fn(*args, **kwargs)
    686 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    687 except KeyboardInterrupt as exception:

File ~/mambaforge/envs/contrastiveVI_env/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:777, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    775 # TODO: ckpt_path only in v1.7
    776 ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 777 self._run(model, ckpt_path=ckpt_path)
    779 assert self.state.stopped
    780 self.training = False

File ~/mambaforge/envs/contrastiveVI_env/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:1138, in Trainer._run(self, model, ckpt_path)
   1136 self.call_hook("on_before_accelerator_backend_setup")
   1137 self.accelerator.setup_environment()
-> 1138 self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
   1140 # check if we should delay restoring checkpoint till later
   1141 if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:

File ~/mambaforge/envs/contrastiveVI_env/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:1438, in Trainer._call_setup_hook(self)
   1435 self.training_type_plugin.barrier("pre_setup")
   1437 if self.datamodule is not None:
-> 1438     self.datamodule.setup(stage=fn)
   1439 self.call_hook("setup", stage=fn)
   1441 self.training_type_plugin.barrier("post_setup")

File ~/mambaforge/envs/contrastiveVI_env/lib/python3.11/site-packages/pytorch_lightning/core/datamodule.py:461, in LightningDataModule._track_data_hook_calls.<locals>.wrapped_fn(*args, **kwargs)
    459     else:
    460         attr = f"_has_{name}_{stage}"
--> 461         has_run = getattr(obj, attr)
    462         setattr(obj, attr, True)
    464 elif name == "prepare_data":

AttributeError: 'ContrastiveDataSplitter' object has no attribute '_has_setup_TrainerFn.FITTING'

ethanweinberger commented 8 months ago

Hi @hamid-bolouri,

Thanks for reaching out. We've added contrastiveVI to the main scvi-tools repository (see tutorial notebook here), and are no longer actively maintaining this repo. As for this specific dataset, I've modified the Colab notebook to instead import contrastiveVI from scvi-tools, and it appears to train successfully.

Let me know if you run into any more issues.
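
For readers following this migration, a minimal sketch of the earlier workflow against scvi-tools. It assumes a release that exposes the model as scvi.external.ContrastiveVI (assumed here to be 1.1 or later) and whose train method takes accelerator in place of the old use_gpu flag; split_background_target and train_contrastive_vi are hypothetical helper names introduced for illustration. Check argument names against the documentation of the installed scvi-tools version before relying on it.

```python
import numpy as np


def split_background_target(conditions, background_value):
    """Return (background_indices, target_indices) for a 1-D array of labels."""
    conditions = np.asarray(conditions)
    is_background = conditions == background_value
    background_indices = np.where(is_background)[0]
    target_indices = np.where(~is_background)[0]
    if background_indices.size == 0 or target_indices.size == 0:
        raise ValueError(f"Need both background ({background_value!r}) and target cells.")
    return background_indices, target_indices


def train_contrastive_vi(adata, condition_key, background_value, layer=None, accelerator="cpu"):
    """Hypothetical wrapper: set up, build, and train contrastiveVI from scvi-tools."""
    # Imported lazily so the index helper above works without scvi-tools installed.
    from scvi.external import ContrastiveVI

    ContrastiveVI.setup_anndata(adata, layer=layer)
    model = ContrastiveVI(
        adata,
        n_salient_latent=10,
        n_background_latent=10,
        use_observed_lib_size=False,
    )
    background_indices, target_indices = split_background_target(
        adata.obs[condition_key].to_numpy(), background_value
    )
    model.train(
        background_indices=background_indices,
        target_indices=target_indices,
        max_epochs=500,
        train_size=0.8,
        early_stopping=True,
        accelerator=accelerator,  # assumed replacement for use_gpu in the old package
        check_val_every_n_epoch=1,  # assumed to be forwarded to the Lightning Trainer
    )
    return model
```

For the MIBI data at the top of the thread this would be called as train_contrastive_vi(treated_control, "Drug", "CTRL"). The index helper raises early if either group is empty.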

hamid-bolouri commented 8 months ago

That's great! Thanks for getting back to me. --Hamid

hoeflerb commented 6 months ago

Hi @ethanweinberger,

Do you have any plans to include totalContrastiveVI in scvi-tools as well?