berenslab / t-simcne

Unsupervised visualization of image datasets using contrastive learning
https://t-simcne.readthedocs.io/en/latest/

Definition of `torch.diagonal_copy` #2

Closed: zhijie-yang closed this issue 1 year ago

zhijie-yang commented 1 year ago

Hi,

I'm wondering what the definition of `torch.diagonal_copy` is, as I can't find it anywhere in the official PyTorch documentation.

Many thanks in advance.

https://github.com/berenslab/t-simcne/blob/3e0062f7fb2fa33975c90ebf0a332bc324d627b1/tsimcne/losses/infonce.py#L72C36-L72C49

jnboehm commented 1 year ago

Hi,

good catch, I didn't realize that it's not part of the documentation.

For what it's worth, it is part of torch though (and even has a docstring). In my Python session, I get the following:

In [1]: import torch

In [2]: torch.diagonal_copy?
Docstring:
Performs the same operation as :func:`torch.diagonal`, but all output tensors
are freshly created instead of aliasing the input.
Type:      builtin_function_or_method

In [3]: torch.__version__
Out[3]: '2.0.1'

It was already available before the 2.0 release, but I am not sure which version I first used it with. As for why it's not in the torch documentation, I don't know; it might be worth raising an issue there.
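A minimal sketch of the difference the docstring describes, assuming a recent torch build: `torch.diagonal` returns a view, so an in-place write to it changes the input, while `torch.diagonal_copy` returns a fresh tensor. The `hasattr` guard is there because the function is missing from some older releases.

```python
import torch

# 3x3 matrix whose main diagonal is [0., 4., 8.].
x = torch.arange(9.0).reshape(3, 3)

# torch.diagonal returns a view: an in-place write to it also changes x.
view = torch.diagonal(x)
view.zero_()
print(x[1, 1].item())  # 0.0

# torch.diagonal_copy, where this torch build provides it, returns a fresh
# tensor, so writing into the result leaves the input untouched.
y = torch.arange(9.0).reshape(3, 3)
has_diagonal_copy = hasattr(torch, "diagonal_copy")
if has_diagonal_copy:
    copy = torch.diagonal_copy(y)
    copy.zero_()
    print(y[1, 1].item())  # 4.0
```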

zhijie-yang commented 1 year ago

Hi Jan Niklas,

thanks for the quick reply! With python=3.8.12, pytorch=1.11.0 and cuda=11.7, calling this function raises an AttributeError: module 'torch' has no attribute 'diagonal_copy'.

However, thanks to the hint you've provided, I have made a temporary fix by replacing torch.diagonal_copy(...) with torch.diagonal(torch.clone(...)).
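A sketch of a version-tolerant variant of this workaround, assuming one wants to keep using `torch.diagonal_copy` where it exists. `diagonal_copy_compat` is a hypothetical helper name, not part of torch or t-simcne. The fallback takes the diagonal first and then clones it, which copies only the n diagonal entries rather than the whole matrix as `torch.diagonal(torch.clone(...))` does; both give the same values and neither aliases the input.

```python
import torch

def diagonal_copy_compat(x: torch.Tensor, offset: int = 0,
                         dim1: int = 0, dim2: int = 1) -> torch.Tensor:
    """Hypothetical helper: return a diagonal that does not alias ``x``."""
    if hasattr(torch, "diagonal_copy"):
        # Newer releases (available in 2.0.1, per the thread).
        return torch.diagonal_copy(x, offset=offset, dim1=dim1, dim2=dim2)
    # Older releases such as 1.11: take the diagonal view, then clone it.
    return torch.diagonal(x, offset=offset, dim1=dim1, dim2=dim2).clone()

# Quick equivalence check against the workaround from the thread.
sim = torch.rand(4, 4)
workaround = torch.diagonal(torch.clone(sim))
compat = diagonal_copy_compat(sim)
print(torch.equal(workaround, compat))  # True
```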

I will test whether the two are equivalent, and I can open a pull request if you think it's necessary.

jnboehm commented 1 year ago

I see, so it was probably introduced around PyTorch 1.12. It's still a bit odd that it's not documented. In any case, I think the line could also be changed to

tempered_alignment = torch.diagonal(sim_ab).log().mean()

In that case, the call to `log` creates a new tensor, so no explicit copy is needed (`torch.diagonal` returns a view rather than a copy, IIRC). Sorry, that part of the code hasn't been touched in a while, so it's not optimal.
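A small check of the suggested replacement, assuming `sim_ab` is a strictly positive square similarity matrix with the positive-pair similarities on its diagonal (a random stand-in here, not the actual t-SimCNE tensor). It compares the value and gradient of the view-then-log form with an explicit-copy form, and confirms that the `log()` output does not alias `sim_ab`.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-in for sim_ab: strictly positive n x n similarities with
# the positive-pair similarities on the diagonal (log needs positive inputs).
sim_ab = (torch.rand(5, 5) + 0.1).requires_grad_()

# Suggested form: log() on the diagonal view allocates a new tensor.
via_view = torch.diagonal(sim_ab).log().mean()
# Form with an explicit copy of the diagonal first.
via_copy = torch.diagonal(sim_ab).clone().log().mean()

# Both forms give the same value and the same gradient w.r.t. sim_ab.
grad_view = torch.autograd.grad(via_view, sim_ab)[0]
grad_copy = torch.autograd.grad(via_copy, sim_ab)[0]
print(torch.allclose(via_view, via_copy), torch.allclose(grad_view, grad_copy))

# The log() output does not alias sim_ab: zeroing it leaves sim_ab unchanged.
logs = torch.diagonal(sim_ab.detach()).log()
logs.zero_()
print(bool((torch.diagonal(sim_ab) > 0).all()))  # True
```

Since `mean` and `log` already produce new tensors, the extra `clone()` only costs an allocation without changing the result.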