sofiagilardini closed this 12 months ago.
Thank you for your contribution. We require contributors to sign our Contributor License Agreement (CLA). We do not have a signed CLA on file for you. In order for us to review and merge your code, please sign our CLA here. After you have signed, you can comment on this PR with @cla-bot check to trigger another check.
@cla-bot check
Thanks for tagging me. I looked for a signed form under your signature again, and updated the status on this PR. If the check was successful, no further action is needed. If the check was unsuccessful, please see the instructions in my first comment.
Hi @sofiagilardini, thanks a lot for flagging (and fixing!!) this.
A good addition to this PR would be a unit test that fails in the old CEBRA version, and works after applying your fix.
Would you be interested in contributing this test yourself, or should we add it to the PR based on your description?
Hi @stes,
I'd be happy to add it. Would you like the test to go in CEBRA/tests/test_sklearn.py?
Is there a specific pytest decorator you suggest?
Hi @sofiagilardini, thanks a lot!
Adding it to CEBRA/tests/test_sklearn.py is perfect. In case it is not clear from the docs/Makefile, this command:
python -m pytest tests/test_sklearn.py::test_yours
is probably the quickest way to test on your end. If you want to run more of the test suite, I would suggest adding -m "not requires_dataset" for some speed gains.
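For context, -m selects tests by marker expression, so -m "not requires_dataset" deselects every test tagged with that marker. A minimal sketch, assuming the slow tests are tagged with a pytest marker named requires_dataset; the two test functions and the marker_names helper are hypothetical and only illustrate the mechanism:

```python
import pytest


# Hypothetical slow test that needs a downloaded dataset. The marker is what
# `-m "not requires_dataset"` matches on to deselect it.
@pytest.mark.requires_dataset
def test_needs_dataset():
    pass


# Hypothetical fast unit test with no marker: it still runs under
# `-m "not requires_dataset"`.
def test_fast_unit():
    assert sorted([3, 1, 2]) == [1, 2, 3]


def marker_names(test_fn):
    """Return the names of the pytest marks attached to a test function."""
    return [mark.name for mark in getattr(test_fn, "pytestmark", [])]


print(marker_names(test_needs_dataset))  # ['requires_dataset']
print(marker_names(test_fast_unit))  # []
```

In a real suite the marker would normally also be registered under markers in the pytest configuration; otherwise pytest emits an unknown-marker warning during collection.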
If you can think of multiple test cases, pytest.mark.parametrize might be useful; otherwise there are no special requirements, as long as the test reproduces your bug in the original version of the code and this PR makes the test pass.
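As an illustration of the parametrize suggestion, a minimal sketch: is_supported_device is a hypothetical helper (not CEBRA's sklearn_utils.check_device), and the device list is only an example. Each tuple in the list becomes a separate test case, so one test function covers several inputs:

```python
import pytest


def is_supported_device(device):
    """Hypothetical check: accept 'cpu', 'mps', or anything starting with
    'cuda' (e.g. 'cuda:1').

    str() lets it handle plain strings as well as device objects whose
    string form is the device spec.
    """
    device = str(device)
    return device in ("cpu", "mps") or device.startswith("cuda")


# pytest generates one test case per (device, expected) tuple.
@pytest.mark.parametrize(
    "device, expected",
    [
        ("cpu", True),
        ("mps", True),
        ("cuda", True),
        ("cuda:1", True),
        ("tpu", False),
    ],
)
def test_is_supported_device(device, expected):
    assert is_supported_device(device) == expected
```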
Thanks again!
Hi @stes,
I've added the test; it can be run with:
python -m pytest tests/test_sklearn.py::test_fit_after_moving_to_device
The code in this PR passes the test.
The same test fails if you check out the previous version of the sklearn integration:
git checkout cda6e11eb89828ade70d3c342fff8bb955cb2b69 -- ./cebra/integrations/sklearn/cebra.py
Let me know if there are any issues :)
Sofia
@gonlairo please code review ASAP.
lgtm, ready to merge imo
Thanks @sofiagilardini for the contribution!
Hi all, thank you very much for this great work.
I have been using the latest stable release of CEBRA for my work on a summer research project. I recently switched to a new MacBook, so today I tried using the main branch, as it has MPS compatibility. I came across an issue with the recently added to() method of the CEBRA class in the sklearn integration. From my understanding:
The recently implemented to() method of the CEBRA class in the sklearn integration has bugs related to the type of the attributes device and device_. The two bugs are:
1. device.startswith() assumes that device is of type str, yet the input type hint is device: Union[str, torch.device].
2. The method sets self.device (and self.device_, if it exists) to a torch.device object. This causes downstream errors when calling other methods such as partial_fit() after having called to(), because _prepare_fit() calls sklearn_utils.check_device(self.device), which fails when self.device is not of type str.
The proposed fix modifies the type-checking logic and always sets self.device (and self.device_, if it exists) to a str. This way, calling the to() method does not break existing working code.
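The difference between storing a device object and storing a str can be sketched in plain Python. This is a minimal, self-contained sketch, not CEBRA's code: DeviceLike, check_device and EstimatorSketch are hypothetical stand-ins for torch.device, sklearn_utils.check_device and the CEBRA estimator, and the str-only check_device is an assumption that mirrors the failure described above. The key step is converting the device to a str before calling any str method on it or storing it, so later calls such as partial_fit() always see a str:

```python
class DeviceLike:
    """Hypothetical stand-in for torch.device: not a str, has a .type
    attribute, and str() returns a spec such as 'cuda:0'."""

    def __init__(self, spec):
        self.type = spec.split(":")[0]
        self._spec = spec

    def __str__(self):
        return self._spec


def check_device(device):
    """Hypothetical validator that, like the check described above,
    only accepts str devices."""
    if not isinstance(device, str):
        raise TypeError(f"device must be a str, got {type(device).__name__}")
    if not (device in ("cpu", "mps") or device.startswith("cuda")):
        raise ValueError(f"unsupported device: {device!r}")
    return device


class EstimatorSketch:
    """Hypothetical sklearn-style estimator: `device` is the constructor
    parameter, `device_` is only set once the estimator has been fitted."""

    def __init__(self, device="cpu"):
        self.device = device

    def fit(self):
        self.device_ = check_device(self.device)
        return self

    def to(self, device):
        # Accept a str or a device-like object (duck-typed via .type here,
        # standing in for an isinstance(device, torch.device) check).
        if not isinstance(device, str) and not hasattr(device, "type"):
            raise TypeError("device must be a str or a device object")
        # Normalize to str *before* validating or storing it.
        device = check_device(str(device))
        self.device = device
        if hasattr(self, "device_"):
            self.device_ = device
        # A real implementation would also move the model parameters here.
        return self

    def partial_fit(self):
        # Safe after to(): self.device is always a str at this point.
        check_device(self.device)
        return self


model = EstimatorSketch().fit().to(DeviceLike("cuda:0"))
model.partial_fit()
print(model.device, model.device_)  # cuda:0 cuda:0
```

Had to() stored the DeviceLike object itself, the str-only check_device call inside partial_fit() would raise a TypeError, which is the kind of downstream failure the fix avoids.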