AdaptiveMotorControlLab / CEBRA

Learnable latent embeddings for joint behavioral and neural analysis - Official implementation of CEBRA
https://cebra.ai

Fix type bug in to() method #55

Status: closed (sofiagilardini closed this 12 months ago)

sofiagilardini commented 1 year ago

Hi all, thank you very much for this great work.

I have been using the latest stable release of CEBRA for a summer research project. I recently switched to a new MacBook, so today I tried the main branch, since it has MPS compatibility. I came across an issue with the recently added to() method of the CEBRA class in the sklearn integration.

From my understanding:

The recently implemented to() method of the CEBRA class in the sklearn integration has bugs related to the types of the class attributes device and device_. The two bugs are:

The proposed fix modifies the type-checking logic and always sets the class attribute self.device (and self.device_, if it exists) to a str. This way, calling the to() method does not break existing working code.
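A minimal sketch of the normalisation described above, assuming device may be passed either as a string or as a device-like object such as torch.device (which exposes a .type attribute and whose str() gives e.g. "cuda:0"). ToyEstimator and its fit() are hypothetical stand-ins, not CEBRA's actual code, and the sketch does not check whether the requested device is actually available.

```python
class ToyEstimator:
    """Hypothetical sklearn-style estimator; NOT CEBRA's actual class."""

    def __init__(self, device="cpu"):
        self.device = device  # user-facing parameter, kept as a str

    def fit(self):
        # By sklearn convention, fitted attributes end with "_".
        self.device_ = self.device
        return self

    def to(self, device):
        """Move the estimator to `device`, always storing the device as a str."""
        if isinstance(device, str):
            name = device
        elif hasattr(device, "type"):
            # Device-like objects (e.g. torch.device) expose .type, and
            # str(device) returns a string such as "cuda:0".
            name = str(device)
        else:
            raise TypeError(
                f"device must be a str or device-like object, got {type(device).__name__}"
            )
        self.device = name
        if hasattr(self, "device_"):
            # Only update the fitted attribute if the model was already fitted.
            self.device_ = name
        return self


est = ToyEstimator().fit().to("mps")
print(est.device, est.device_)  # mps mps
```

Storing a plain str keeps code that later calls string methods on self.device (or compares it to "cpu") working regardless of how the user passed the device.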

cla-bot[bot] commented 1 year ago

Thank you for your contribution. We require contributors to sign our Contributor License Agreement (CLA). We do not have a signed CLA on file for you. In order for us to review and merge your code, please sign our CLA here. After you have signed, you can comment on this PR with @cla-bot check to trigger another check.

sofiagilardini commented 1 year ago

@cla-bot check

cla-bot[bot] commented 1 year ago

Thanks for tagging me. I looked for a signed form under your signature again, and updated the status on this PR. If the check was successful, no further action is needed. If the check was unsuccessful, please see the instructions in my first comment.

stes commented 1 year ago

Hi @sofiagilardini, thanks a lot for flagging (and fixing!!) this.

A good addition to this PR would be a unit test that fails in the old CEBRA version, and works after applying your fix.

Would you be interested in contributing this test yourself, or should we add it to the PR based on your description?

sofiagilardini commented 1 year ago

Hi @stes,

I'd be happy to add it. Would you like the test to go in CEBRA/tests/test_sklearn.py?

Is there a specific pytest decorator you suggest?

stes commented 1 year ago

Hi @sofiagilardini , thanks a lot!

Adding the test to CEBRA/tests/test_sklearn.py is perfect. In case it is not clear from the docs/Makefile, running

python -m pytest tests/test_sklearn.py::test_yours

is probably the quickest way to test on your end. If you want to run more of the test suite, I would suggest using -m "not requires_dataset" to speed things up.

If you can think of multiple test cases, pytest.mark.parametrize might be useful. Otherwise there are no special requirements, as long as the test reproduces your bug in the original version of the code and this PR makes the test pass.
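A sketch of the "fails before, passes after" pattern, under the assumption (suggested by the PR description) that the old to() could leave device as a non-str object that later code then treats as a string. FakeDevice, OldEstimator, FixedEstimator and test_fit_after_moving_to_device_sketch are all hypothetical; this is not the test that was added to CEBRA. A plain loop is used so the sketch runs without pytest; under pytest, the loop could be replaced by decorating the test with @pytest.mark.parametrize("device", DEVICES).

```python
class FakeDevice:
    """Hypothetical stand-in for torch.device (has .type and a str form)."""

    def __init__(self, spec):
        self.spec = spec
        self.type = spec.split(":")[0]

    def __str__(self):
        return self.spec


class OldEstimator:
    """Hypothetical 'before' version: to() stores whatever it receives."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        self.device = device
        return self

    def fit(self):
        # Downstream code assumes self.device is a str, e.g. to get the backend.
        self.backend_ = self.device.split(":")[0]
        return self


class FixedEstimator(OldEstimator):
    """Hypothetical 'after' version: to() always stores a str."""

    def to(self, device):
        self.device = device if isinstance(device, str) else str(device)
        return self


DEVICES = ["cpu", "cuda:0", FakeDevice("cpu"), FakeDevice("mps")]


def run_fit_after_to(estimator_cls, device):
    """Return True if fit() succeeds after moving the estimator to `device`."""
    try:
        estimator_cls().to(device).fit()
    except AttributeError:
        return False
    return True


def test_fit_after_moving_to_device_sketch():
    # Passes with the fixed to(); with OldEstimator the FakeDevice cases fail.
    for device in DEVICES:
        assert run_fit_after_to(FixedEstimator, device)


if __name__ == "__main__":
    test_fit_after_moving_to_device_sketch()
    print(run_fit_after_to(OldEstimator, FakeDevice("mps")))  # False
```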

Thanks again!

sofiagilardini commented 1 year ago

Hi @stes,

I've added the test. It can be run with: python -m pytest tests/test_sklearn.py::test_fit_after_moving_to_device

The code in this PR passes the test.

The same test fails if you check out the previous version of the sklearn integration: git checkout cda6e11eb89828ade70d3c342fff8bb955cb2b69 -- ./cebra/integrations/sklearn/cebra.py

Let me know if there are any issues :)

Sofia

MMathisLab commented 1 year ago

@gonlairo please code review ASAP.

gonlairo commented 1 year ago

lgtm, ready to merge imo

stes commented 12 months ago

Thanks @sofiagilardini for the contribution!