Closed Mactarvish closed 3 years ago
That is certainly a lot of memory!
By "other loss functions are working well", do you mean other loss functions in this library? Which ones specifically work well?
Also what version of this library are you using?
Yes, I've tried ArcFace loss, triplet loss, large margin softmax loss, etc., and they all work well (with batch size 128, image size 256*512).
__version__ = "0.9.98"
It would also be appreciated if there were any example code using Angular loss.
The usage is exactly the same as with TripletMarginLoss.
I will have to look into this to see if there's a memory issue with AngularLoss
Have you found what raises this problem?
Sorry I haven't had time for this yet
I'm happy to take a look. This library does a lot of things that we'd like to do, and it would be nice to dig into it....
@ibebrett That would be great! I'm guessing this line could be the cause of the excessive memory usage: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/9559b21559ca6fbcb46d2d51d7953166e18f9de6/src/pytorch_metric_learning/losses/angular_loss.py#L35-L37
Thanks! Will look this afternoon.
Started looking at this today, but got sidetracked. Target is to have this tomorrow. Sorry for the delay!
One thing to note though is that locally for me:
$ python -m unittest tests/losses/test_angular_loss.py
Testing pytorch_metric_learning version 0.9.99 with pytorch version 1.8.1
TEST_DTYPES=[torch.float16, torch.float32, torch.float64], TEST_DEVICE=cuda, WITH_COLLECT_STATS=True
F.
======================================================================
FAIL: test_angular_loss (tests.losses.test_angular_loss.TestAngularLoss)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/brett/work/pytorch-metric-learning/tests/losses/test_angular_loss.py", line 59, in test_angular_loss
self.assertTrue(torch.isclose(loss, total_loss))
AssertionError: tensor(False, device='cuda:0') is not true
----------------------------------------------------------------------
Ran 2 tests in 1.543s
FAILED (failures=1)
The tests for this loss don't pass. I'm sure it's something on my side; I'll dig in tomorrow. (This is on master.)
Hmm, can you try testing only float32?
TEST_DTYPES=float32 python -m unittest tests/losses/test_angular_loss.py
Also please git checkout the "dev" branch (I'll add this to the readme).
Re: angular loss, this is the equation from the paper:
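The equation is not reproduced in the text here. As a reconstruction, assuming the batch formulation and notation of the angular loss paper (so the symbols below are an assumption, not copied from this thread or from the library code), it reads:

```latex
l_{ang}(\mathcal{B}) = \frac{1}{N} \sum_{x_a \in \mathcal{B}}
  \log\Big[ 1 + \sum_{x_n \in \mathcal{B},\; y_n \neq y_a} \exp\big(f_{a,p,n}\big) \Big],
\qquad
f_{a,p,n} = 4 \tan^2\!\alpha \, (x_a + x_p)^\top x_n \;-\; 2\,(1 + \tan^2\!\alpha)\, x_a^\top x_p
```

Here $x_p$ is the single positive paired with anchor $x_a$, $x_n$ ranges over the negatives of that anchor, and $N$ is the number of anchors in the batch $\mathcal{B}$.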
It should be computing a loss for each "anchor", and then dividing by the total number of anchors. In the paper, they assume a single positive pair per anchor, so for a batch of size N you have N/2 positive pairs. The current AngularLoss seems to work for that case, but I assumed that it would work for arbitrary numbers of positive pairs. For example, it should work when there are 10 positive pairs per element in the batch.
I made a small change to the unit test for AngularLoss and realized that the loss function is computing the loss per positive pair, rather than per anchor. (The new version of the test is on the dev branch and it fails.)
I tried using the PerAnchorReducer, but that doesn't work because it is being applied after the logsumexp step.
I changed the unit test to compute the loss per positive pair instead of per anchor so that it at least passes for now.
I can't think of a straightforward solution right now. To be fair, I don't think there is anything "wrong" about computing the loss per positive pair, and the original paper doesn't cover the case where there are multiple positives per element anyway.
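To make the difference concrete, here is a small plain-Python sketch (no torch, not the library's implementation). It assumes one reading of "per anchor": all (positive, negative) scores of an anchor go into a single log-sum-exp. The score values and the helper names are hypothetical.

```python
import math
from collections import defaultdict


def log1p_sum_exp(values):
    """Return log(1 + sum(exp(v) for v in values)), computed stably."""
    m = max(0.0, max(values))
    return m + math.log(math.exp(-m) + sum(math.exp(v - m) for v in values))


def per_pair_loss(pair_scores):
    """One log-sum-exp per positive pair, averaged over positive pairs."""
    terms = [log1p_sum_exp(scores) for scores in pair_scores.values()]
    return sum(terms) / len(terms)


def per_anchor_loss(pair_scores):
    """One log-sum-exp per anchor (pooling its positives), averaged over anchors."""
    pooled = defaultdict(list)
    for (anchor, _positive), scores in pair_scores.items():
        pooled[anchor].extend(scores)
    terms = [log1p_sum_exp(scores) for scores in pooled.values()]
    return sum(terms) / len(terms)


# Hypothetical f_apn scores keyed by (anchor, positive); one entry per negative.
single_positive = {(0, 1): [0.2, -1.0], (3, 4): [-0.5, 0.1]}
multi_positive = {(0, 1): [0.2, -1.0], (0, 2): [1.5, 0.3], (3, 4): [-0.5, 0.1]}
```

With one positive per anchor (the paper's setting) the two reductions agree; with several positives per anchor they give different values. This also shows why a reducer applied after the log-sum-exp cannot turn one into the other.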
I reran the tests on master and still got a failure:
TEST_DTYPES=float32 python -m unittest tests/losses/test_angular_loss.py
Testing pytorch_metric_learning version 0.9.99 with pytorch version 1.8.1
TEST_DTYPES=[torch.float32], TEST_DEVICE=cuda, WITH_COLLECT_STATS=True
F.
======================================================================
FAIL: test_angular_loss (tests.losses.test_angular_loss.TestAngularLoss)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/brett/work/pytorch-metric-learning/tests/losses/test_angular_loss.py", line 59, in test_angular_loss
self.assertTrue(torch.isclose(loss, total_loss))
AssertionError: tensor(False, device='cuda:0') is not true
----------------------------------------------------------------------
Ran 2 tests in 27.414s
and also tried on the dev branch:
TEST_DTYPES=float32 python -m unittest tests/losses/test_angular_loss.py
Testing pytorch_metric_learning version 1.0.0.dev0 with pytorch version 1.8.1
TEST_DTYPES=[torch.float32], TEST_DEVICE=cuda, WITH_COLLECT_STATS=True
F.
======================================================================
FAIL: test_angular_loss (tests.losses.test_angular_loss.TestAngularLoss)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/brett/work/pytorch-metric-learning/tests/losses/test_angular_loss.py", line 59, in test_angular_loss
self.assertTrue(torch.isclose(loss, total_loss))
AssertionError: tensor(False, device='cuda:0') is not true
----------------------------------------------------------------------
Apologies I lied!
After pulling your changes to dev my tests now pass with all dtypes! Thanks
Started looking at this EOD today and will finish Monday. Initially I'm looking at some profiling that shows a spike in logsumexp.
 5.67G   8.17G  10  def logsumexp(x, keep_mask=None, add_one=True, dim=1):
 5.67G   8.17G  11      if keep_mask is not None:
 8.79G  10.67G  12          x = x.masked_fill(~keep_mask, c_f.neg_inf(x.dtype))
 8.79G  10.67G  13      if add_one:
 8.79G  10.67G  14          zeros = torch.zeros(x.size(dim - 1), dtype=x.dtype, device=x.device).unsqueeze(
 8.79G  10.67G  15              dim
                16          )
11.29G  13.17G  17          x = torch.cat([x, zeros], dim=dim)
                18
13.80G  18.18G  19      output = torch.logsumexp(x, dim=dim, keepdim=True)
 8.80G  18.18G  20      if keep_mask is not None:
 8.80G  18.18G  21          output = output.masked_fill(~torch.any(keep_mask, dim=dim, keepdim=True), 0)
 8.80G  18.18G  22      return output
That being said, my setup for this profile is quite naive: I took your test case and just multiplied some lengths. The embedding dim is 2. A couple of steps look suspicious to me and I have to dig in a bit further, but in this case the matmul step ends up creating a matrix of size [558800, 1200] when my batch size is just 1200.
Here is my "test script"
import torch
from pytorch_metric_learning.losses import AngularLoss
from pytorch_metric_learning.utils import common_functions as c_f

loss_func = AngularLoss(alpha=40)
dev = torch.device('cuda')
N = 200
embedding_angles = [0, 20, 40, 60, 80, 100] * N
embeddings = torch.tensor(
    [c_f.angle_to_coord(a) for a in embedding_angles],
    requires_grad=True,
    dtype=torch.float32,
).to(dev)  # 2D embeddings
labels = torch.LongTensor([0, 0, 1, 1, 2, 0] * N)
loss = loss_func(embeddings, labels)
Anyway, sorry I keep delaying this, but I started getting into it later today. I want to check my understanding of what actually should be summed up, but let me know if my test script resembles the actual failure case. Maybe since the batch size is 1200 it is quite reasonable for this to explode when looking at all of these triples. I appreciate being given the time to work this out!
> That being said my setup for this profile is quite naive.. I took your test case and just multiplied some lengths. The embedding dim is 2. There are a couple of steps that look suspicious to me but I have to dig in a bit further, but for this case the matmul step ends up creating a matrix of size [558800, 1200] when my batch size is just 1200.
Yeah the first dim (558800) is the number of positive pairs, and that can't really be changed since the default behavior should be to use all positive pairs in the batch. The second dim (1200) has unnecessary elements because some of those columns get masked out in the logsumexp step.
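As a quick sanity check on those numbers, the following plain-Python sketch counts ordered (anchor, positive) pairs for the labels in the test script and estimates the size of one float32 matrix of that shape. The helper name `count_positive_pairs` is hypothetical, and counting ordered pairs is an assumption that happens to reproduce the 558800 figure.

```python
from collections import Counter


def count_positive_pairs(labels):
    """Count ordered (anchor, positive) pairs: same label, different index."""
    return sum(c * (c - 1) for c in Counter(labels).values())


N = 200
labels = [0, 0, 1, 1, 2, 0] * N  # same labels as the test script
batch_size = len(labels)  # 1200
num_pos_pairs = count_positive_pairs(labels)  # 558800

bytes_per_float32 = 4
matrix_gib = num_pos_pairs * batch_size * bytes_per_float32 / 2**30
print(f"{num_pos_pairs} x {batch_size} float32 matrix ~ {matrix_gib:.2f} GiB")
```

This prints roughly 2.50 GiB for a single copy of the matrix, which seems consistent with the jumps of about 2.5G at the `torch.cat` and `torch.logsumexp` lines in the profile. Because the pair count grows quadratically with the number of samples per class, larger batches or fewer classes make it blow up quickly.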
> Anyway, sorry I keep delaying this, but I started getting into it later today.
No problem, this isn't super urgent, and any contribution is greatly appreciated.
> I want to check my understanding of what actually should be summed up, but let me know if my test script resembles the actual failure case. Maybe since the batch size is 1200 it is quite reasonable for this to explode when looking at all of these triples. I appreciate being given the time to work this out!
For this issue, I wonder if it would be better to use an actual model with inputs. I'm not sure if the memory usage would spike at a different place because of the autograd graph. I guess it's worth trying.
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_metric_learning.losses import AngularLoss


class Model(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, 1024)
        self.fc2 = nn.Linear(1024, out_dim)

    def forward(self, x):
        x = F.relu(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x


batch_size = 32
in_dim = 2048
out_dim = 128
model = Model(in_dim, out_dim)
data = torch.randn(batch_size, in_dim)  # random stand-in for real inputs
labels = torch.randint(low=0, high=4, size=(batch_size,))
loss_func = AngularLoss(alpha=40)
model.train()
embeddings = model(data)
loss = loss_func(embeddings, labels)
@Mactarvish , would you mind posting a more complete example, perhaps in a gist? I am new to the library and I'm having a hard time reconstructing your problem. It may be obvious with more context though. (Or perhaps @KevinMusgrave , you know that it fits into a particular example or something)?
> Re: angular loss, this is the equation from the paper:
> It should be computing a loss for each "anchor", and then dividing by the total number of anchors. In the paper, they assume a single positive pair per anchor, so for a batch of size N you have N/2 positive pairs. The current AngularLoss seems to work for that case, but I assumed that it would work for arbitrary numbers of positive pairs. For example, it should work when there are 10 positive pairs per element in the batch.
I noticed this as well, and I'm wondering if it would reduce memory consumption to do it per anchor rather than per pair. You end up with the log sum over each pairing, across each negative. We'd go from n^2 to linear in terms of the size of the tensor for that step if we did it per anchor. I suppose if it later gets summed up anyway, that doesn't affect much?
Thanks for letting me work on this issue. I wanted to get into the metric learning stuff and this has been a really nice way to start learning the terminology etc!
Side note also, I believe this code basically assumes that a "negative" is any other embedding, regardless of label. Looking at the loss function, it doesn't look to me like the loss would magically become irrelevant in that case.
I could be completely misinterpreting it though..
> @Mactarvish , would you mind posting a more complete example, perhaps in a gist? I am new to the library and I'm having a hard time reconstructing your problem. It may be obvious with more context though. (Or perhaps @KevinMusgrave , you know that it fits into a particular example or something)?
I assumed the problem involved a convnet with images, but yeah more details from @Mactarvish would definitely help.
> I noticed this as well, and I'm wondering if it would reduce memory consumption to do it per anchor rather than per pair. You end up with the log sum over each pairing, across each negative. We'd go from n^2 to linear in terms of the size of the tensor for that step if we did it per anchor. I suppose if it later gets summed up anyway, that doesn't affect much?
Good point. But I'm not sure how to condense the `f_apn` expression such that the logsumexp step uses an NxN matrix.
> Thanks for letting me work on this issue. I wanted to get into the metric learning stuff and this has been a really nice way to start learning the terminology etc!
Glad to hear that 👍
> Side note also, I believe this code basically assumes that a "negative" is any other embedding, regardless of label. Looking at the loss function, it doesn't look to me like the loss would magically become irrelevant in that case.
Negatives should have a different label from the anchor. This is what `keep_mask` is for:
https://github.com/KevinMusgrave/pytorch-metric-learning/blob/52bb21ad925115651cba5ddc1f88235caf16d116/src/pytorch_metric_learning/losses/angular_loss.py#L43
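As a rough illustration of that idea, here is a plain-Python sketch (not the library's actual implementation; `negatives_mask` is a hypothetical helper name). Each row corresponds to the anchor of one positive pair, and a column is kept only when that sample's label differs from the anchor's label.

```python
def negatives_mask(labels, anchor_indices):
    """Row i is True at column j when sample j has a different label than anchor i."""
    return [
        [labels[j] != labels[a] for j in range(len(labels))]
        for a in anchor_indices
    ]


labels = [0, 0, 1, 1, 2, 0]
anchors = [0, 2]  # anchors of two hypothetical positive pairs
mask = negatives_mask(labels, anchors)
# mask[0] -> [False, False, True, True, True, False]
# mask[1] -> [True, True, False, False, True, True]
```

In the `logsumexp` helper shown in the profile earlier in this thread, entries where the mask is False are filled with negative infinity first, so they contribute nothing to the sum. The masked columns still occupy memory, though.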
Sorry I haven't reviewed this problem until now; I just found what the matter was. The miner mines about 800000 pairs in one batch, so the loss consumes a huge amount of memory. I changed the angle threshold to a bigger value to keep a reasonable number of mined samples, and the problem has vanished. Thanks anyway~
Yeah, we construct a matrix of all the triples. If you are passing in something that has 800000 pairs, we are going to end up with a giant matrix. I'm not so experienced in this library, and maybe there are tricks around this, but at least with how it's implemented now, this is expected.
I'm training a resnet18 using AngularLoss on a dataset containing 100000 images of size (256, 512) across 4922 classes, but I get this error: RuntimeError: CUDA out of memory. Tried to allocate 100.11 GiB (GPU 0; 15.90 GiB total capacity; 5.50 GiB already allocated; 9.43 GiB free; 5.65 GiB reserved in total by PyTorch)
Other loss functions are working well. Here's my main code:
`loss_func = losses.AngularLoss()`
`mining_func = miners.AngularMiner()`
Please show me what's wrong, thanks.