KevinMusgrave / pytorch-metric-learning

The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
https://kevinmusgrave.github.io/pytorch-metric-learning/
MIT License

100 GB of GPU memory used by AngularLoss #329

Closed. Mactarvish closed this 3 years ago.

Mactarvish commented 3 years ago

I'm training a resnet18 with AngularLoss on a dataset of 100,000 images of size (256, 512) spanning 4,922 classes, but I get this error: `RuntimeError: CUDA out of memory. Tried to allocate 100.11 GiB (GPU 0; 15.90 GiB total capacity; 5.50 GiB already allocated; 9.43 GiB free; 5.65 GiB reserved in total by PyTorch)`

Other loss functions work well. Here's my main code:

```python
from pytorch_metric_learning import losses, miners

loss_func = losses.AngularLoss()
mining_func = miners.AngularMiner()

...
        for key in output.keys():
            if key == "embedding":
                # mine hard tuples, then pass them to the loss
                indices_tuple = self.hard_case_miner(output["embedding"], batch["labels"])
                loss = self.criterion[key](output["embedding"], batch["labels"], indices_tuple)
...
        loss.backward()
        self.optimizer.step()
```

Could you point out what's wrong? Thanks.

KevinMusgrave commented 3 years ago

That is certainly a lot of memory!

By "other loss functions are working well", do you mean other loss functions in this library? Which ones specifically work well?

Also what version of this library are you using?

Mactarvish commented 3 years ago

Yes, I've tried ArcFace loss, triplet loss, large-margin softmax loss, etc., and they all work well (with batch size 128, image size 256×512).

`__version__ = "0.9.98"`

Mactarvish commented 3 years ago

Alternatively, I'd appreciate any example code that uses AngularLoss.

KevinMusgrave commented 3 years ago

The usage is exactly the same as with TripletMarginLoss.

I will have to look into this to see if there's a memory issue with AngularLoss.

Mactarvish commented 3 years ago

Have you found what raises this problem?

KevinMusgrave commented 3 years ago

Sorry, I haven't had time for this yet.

ibebrett commented 3 years ago

I'm happy to take a look. This library does a lot of things that we'd like to do, and it would be nice to dig into it.

KevinMusgrave commented 3 years ago

@ibebrett That would be great! I'm guessing this line could be the cause of the excessive memory usage: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/9559b21559ca6fbcb46d2d51d7953166e18f9de6/src/pytorch_metric_learning/losses/angular_loss.py#L35-L37
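If those lines compute (anchors + positives)·x_n for every embedding with a broadcasting batched `torch.matmul` (an assumption; the linked code isn't reproduced here), the same [P, N] result can be obtained with a plain 2D matmul. A minimal sketch comparing the two forms on toy sizes, with illustrative names: in the batched form the [P, D] operand is broadcast across N batch entries, which in some PyTorch versions can materialize an expanded [N, P, D] copy, whereas the 2D form only allocates the [P, N] output.

```python
import torch

torch.manual_seed(0)
N, P, D = 6, 10, 4                  # toy batch size, number of positive pairs, embedding dim
embeddings = torch.randn(N, D)
ap_sum = torch.randn(P, D)          # stands in for (anchors + positives)

# Batched form: [P, D] broadcast against [N, D, 1] -> [N, P, 1] -> [P, N]
batched = torch.matmul(ap_sum, embeddings.unsqueeze(2)).squeeze(2).t()

# Plain 2D form: [P, D] @ [D, N] -> [P, N]; entry [p, n] is ap_sum[p] . embeddings[n]
plain = ap_sum @ embeddings.t()

print(batched.shape)                              # torch.Size([10, 6])
print(torch.allclose(batched, plain, atol=1e-6))  # True
```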

ibebrett commented 3 years ago

Thanks! Will look this afternoon.

ibebrett commented 3 years ago

Started looking at this today, but got sidetracked. Target is to have this tomorrow. Sorry for the delay!

ibebrett commented 3 years ago

One thing to note, though: this is what I get locally:

```
$ python -m unittest tests/losses/test_angular_loss.py
Testing pytorch_metric_learning version 0.9.99 with pytorch version 1.8.1
TEST_DTYPES=[torch.float16, torch.float32, torch.float64], TEST_DEVICE=cuda, WITH_COLLECT_STATS=True
F.
======================================================================
FAIL: test_angular_loss (tests.losses.test_angular_loss.TestAngularLoss)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/brett/work/pytorch-metric-learning/tests/losses/test_angular_loss.py", line 59, in test_angular_loss
    self.assertTrue(torch.isclose(loss, total_loss))
AssertionError: tensor(False, device='cuda:0') is not true

----------------------------------------------------------------------
Ran 2 tests in 1.543s

FAILED (failures=1)
```

The tests for this loss don't pass (this is on master). I'm sure it's something on my side; I'll dig in tomorrow.

KevinMusgrave commented 3 years ago

Hmm, can you try testing only float32?

```
TEST_DTYPES=float32 python -m unittest tests/losses/test_angular_loss.py
```

Also, please check out the `dev` branch (`git checkout dev`). I'll add this to the README.

KevinMusgrave commented 3 years ago

Re: angular loss, this is the equation from the paper: [equation image]

It should be computing a loss for each "anchor", and then dividing by the total number of anchors. In the paper, they assume a single positive pair per anchor, so for a batch of size N you have N/2 positive pairs. The current AngularLoss seems to work for that case, but I assumed that it would work for arbitrary numbers of positive pairs. For example, it should work when there are 10 positive pairs per element in the batch.
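The equation image isn't preserved. For reference, here is a transcription of the paper's N-pair angular loss that has not been checked against the image, so treat the exact form as an assumption: each anchor $x_a$ in a batch $\mathcal{B}$ of $N$ anchors has a single positive $x_p$, $\alpha$ is the angle parameter, and $x_n$ ranges over samples whose label differs from the anchor's (and hence the positive's).

```math
\begin{aligned}
f_{a,p,n} &= 4\tan^2\alpha\,(x_a + x_p)^\top x_n - 2\,(1 + \tan^2\alpha)\,x_a^\top x_p \\
\ell_{ang}(\mathcal{B}) &= \frac{1}{N} \sum_{x_a \in \mathcal{B}} \log\Big[1 + \sum_{x_n \in \mathcal{B},\ y_n \ne y_a, y_p} \exp\big(f_{a,p,n}\big)\Big]
\end{aligned}
```

The log sits inside the average over anchors, so with exactly one positive per anchor, averaging per anchor and averaging per positive pair give the same value.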

I made a small change to the unit test for AngularLoss and realized that the loss function is computing the loss per positive pair, rather than per anchor. (The new version of the test is on the dev branch and it fails.)

I tried using the PerAnchorReducer, but that doesn't work because it is being applied after the logsumexp step.
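A minimal sketch of the per-pair computation and two possible reductions, written from the formula $f_{a,p,n} = 4\tan^2\alpha\,(x_a + x_p)^\top x_n - 2(1 + \tan^2\alpha)\,x_a^\top x_p$ (our transcription) rather than from the library's code. All function names are hypothetical, and it assumes every ordered same-label pair is a positive pair and that negatives are samples whose label differs from the anchor's. `angular_pair_losses` builds the same kind of [P, N] matrix discussed later in the thread, one row per positive pair.

```python
import math
import torch

def angular_pair_losses(embeddings, labels, alpha_deg=40.0):
    """Hypothetical sketch: one angular-loss value per ordered (anchor, positive) pair."""
    n = embeddings.size(0)
    same = labels.unsqueeze(0) == labels.unsqueeze(1)              # [N, N]
    eye = torch.eye(n, dtype=torch.bool, device=embeddings.device)
    a_idx, p_idx = torch.where(same & ~eye)                        # P positive pairs
    anchors, positives = embeddings[a_idx], embeddings[p_idx]      # [P, D] each

    sq_tan = math.tan(math.radians(alpha_deg)) ** 2
    ap_dot = (anchors * positives).sum(dim=1, keepdim=True)        # [P, 1]
    apn = (anchors + positives) @ embeddings.t()                   # [P, N]
    f_apn = 4 * sq_tan * apn - 2 * (1 + sq_tan) * ap_dot           # [P, N]
    f_apn = f_apn.masked_fill(same[a_idx], float("-inf"))          # keep only negatives
    # log(1 + sum_n exp(f_apn)) == logsumexp of the row with a 0 appended
    zeros = f_apn.new_zeros(f_apn.size(0), 1)
    pair_losses = torch.logsumexp(torch.cat([f_apn, zeros], dim=1), dim=1)  # [P]
    return pair_losses, a_idx

def per_pair_mean(pair_losses):
    """Average over positive pairs."""
    return pair_losses.mean()

def per_anchor_mean(pair_losses, a_idx, n):
    """Hypothetical per-anchor reduction: mean of each anchor's pair losses, then mean over anchors."""
    sums = pair_losses.new_zeros(n).index_add(0, a_idx, pair_losses)
    counts = pair_losses.new_zeros(n).index_add(0, a_idx, torch.ones_like(pair_losses))
    has_pos = counts > 0
    return (sums[has_pos] / counts[has_pos]).mean()
```

With exactly one positive per anchor (the paper's setting) the two reductions coincide. With several positives per anchor, the per-pair mean weights each anchor by its number of positives. Note that `per_anchor_mean` averages values that have already gone through the log, which is why a reducer applied after the logsumexp step can't reproduce a "sum over positives inside the log" variant.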

KevinMusgrave commented 3 years ago

I changed the unit test to compute the loss per positive pair instead of per anchor so that it at least passes for now.

I can't think of a straightforward solution right now. To be fair, I don't think there is anything "wrong" about computing the loss per positive pair, and the original paper doesn't cover the case where there are multiple positives per element anyway.

ibebrett commented 3 years ago

I reran the tests on master and still got a failure:

```
TEST_DTYPES=float32 python -m unittest tests/losses/test_angular_loss.py

Testing pytorch_metric_learning version 0.9.99 with pytorch version 1.8.1
TEST_DTYPES=[torch.float32], TEST_DEVICE=cuda, WITH_COLLECT_STATS=True
F.
======================================================================
FAIL: test_angular_loss (tests.losses.test_angular_loss.TestAngularLoss)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/brett/work/pytorch-metric-learning/tests/losses/test_angular_loss.py", line 59, in test_angular_loss
    self.assertTrue(torch.isclose(loss, total_loss))
AssertionError: tensor(False, device='cuda:0') is not true

----------------------------------------------------------------------
Ran 2 tests in 27.414s
```

I also tried on the dev branch:

```
TEST_DTYPES=float32 python -m unittest tests/losses/test_angular_loss.py
Testing pytorch_metric_learning version 1.0.0.dev0 with pytorch version 1.8.1
TEST_DTYPES=[torch.float32], TEST_DEVICE=cuda, WITH_COLLECT_STATS=True
F.
======================================================================
FAIL: test_angular_loss (tests.losses.test_angular_loss.TestAngularLoss)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/brett/work/pytorch-metric-learning/tests/losses/test_angular_loss.py", line 59, in test_angular_loss
    self.assertTrue(torch.isclose(loss, total_loss))
AssertionError: tensor(False, device='cuda:0') is not true

----------------------------------------------------------------------
```
ibebrett commented 3 years ago

Apologies, I lied!

After pulling your changes to dev, my tests now pass with all dtypes! Thanks.

ibebrett commented 3 years ago

Started looking at this at EOD today and will finish Monday. Initially, I'm looking at some profiling that shows a memory spike in logsumexp.

```
       5.67G          8.17G   10 def logsumexp(x, keep_mask=None, add_one=True, dim=1):
       5.67G          8.17G   11     if keep_mask is not None:
       8.79G         10.67G   12         x = x.masked_fill(~keep_mask, c_f.neg_inf(x.dtype))
       8.79G         10.67G   13     if add_one:
       8.79G         10.67G   14         zeros = torch.zeros(x.size(dim - 1), dtype=x.dtype, device=x.device).unsqueeze(
       8.79G         10.67G   15             dim
                              16         )
      11.29G         13.17G   17         x = torch.cat([x, zeros], dim=dim)
                              18
      13.80G         18.18G   19     output = torch.logsumexp(x, dim=dim, keepdim=True)
       8.80G         18.18G   20     if keep_mask is not None:
       8.80G         18.18G   21         output = output.masked_fill(~torch.any(keep_mask, dim=dim, keepdim=True), 0)
       8.80G         18.18G   22     return output
```

That said, my setup for this profile is quite naive: I took your test case and just multiplied some lengths. The embedding dim is 2. A couple of steps look suspicious to me, but I have to dig in a bit further. For this case, the matmul step ends up creating a matrix of size [558800, 1200] when my batch size is just 1200.

Here is my "test script":

```python
import torch
from pytorch_metric_learning.losses import AngularLoss
from pytorch_metric_learning.utils import common_functions as c_f

loss_func = AngularLoss(alpha=40)
dev = torch.device("cuda")

N = 200
embedding_angles = [0, 20, 40, 60, 80, 100] * N
embeddings = torch.tensor(
    [c_f.angle_to_coord(a) for a in embedding_angles],
    requires_grad=True,
    dtype=torch.float32,
).to(dev)  # 2D embeddings
labels = torch.LongTensor([0, 0, 1, 1, 2, 0] * N).to(dev)

loss = loss_func(embeddings, labels)
```

Anyway, sorry I keep delaying this; I only started getting into it late today. I want to check my understanding of what actually should be summed up. Let me know if my test script resembles the actual failure case. With a batch size of 1200, maybe it's quite reasonable for this to explode when looking at all of these triples. I appreciate being given the time to work this out!
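In the profile above, the `torch.cat` that appends a zero column for `add_one` adds another roughly 2.5G copy (8.79G → 11.29G). A hypothetical variant, not the library's `logsumexp` helper, avoids that concatenated [P, N+1] copy with the identity log(1 + Σ exp x) = logaddexp(logsumexp(x), 0). The sketch masks with `torch.finfo(dtype).min`, a large finite negative number, so rows with no kept entries give 0 and their gradients stay finite.

```python
import torch

def log1p_sum_exp_cat(x, keep_mask):
    """log(1 + sum_j exp(x_j)) per row, by appending a zero column (allocates a [P, N+1] copy)."""
    x = x.masked_fill(~keep_mask, torch.finfo(x.dtype).min)
    zeros = x.new_zeros(x.size(0), 1)
    return torch.logsumexp(torch.cat([x, zeros], dim=1), dim=1, keepdim=True)

def log1p_sum_exp_logaddexp(x, keep_mask):
    """Same quantity via logaddexp(logsumexp(x), 0); no concatenation."""
    x = x.masked_fill(~keep_mask, torch.finfo(x.dtype).min)
    lse = torch.logsumexp(x, dim=1, keepdim=True)          # [P, 1]
    return torch.logaddexp(lse, torch.zeros_like(lse))     # [P, 1]; 0 for rows with nothing kept
```

`masked_fill` still makes one [P, N] copy; the saving is only the extra concatenated tensor.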

KevinMusgrave commented 3 years ago

> ...for this case the matmul step ends up creating a matrix of size [558800, 1200] when my batch size is just 1200.

Yeah, the first dim (558800) is the number of positive pairs, and that can't really be changed since the default behavior should be to use all positive pairs in the batch. The second dim (1200) has unnecessary elements because some of those columns get masked out in the logsumexp step.
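Both dimensions follow from the labels in the test script: classes 0, 1 and 2 have 600, 400 and 200 samples, so the number of ordered positive pairs is 600·599 + 400·399 + 200·199 = 558,800. A quick pure-Python check (helper names are hypothetical) also gives the size of one dense [P, N] float32 tensor, about 2.5 GiB, which is roughly the size of each jump in the profile (for example 8.79G → 11.29G on the `torch.cat` line).

```python
from collections import Counter

def num_positive_pairs(labels):
    """Ordered (anchor, positive) pairs: same label, different index."""
    return sum(c * (c - 1) for c in Counter(labels).values())

def gib(num_elements, bytes_per_element=4):
    """Size in GiB of a dense tensor (float32 by default)."""
    return num_elements * bytes_per_element / 2**30

labels = [0, 0, 1, 1, 2, 0] * 200       # labels from the test script
P, N = num_positive_pairs(labels), len(labels)
print(P, N)                  # 558800 1200
print(round(gib(P * N), 2))  # 2.5
```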

> Anyway, sorry I keep delaying this, but I started getting into it later today.

No problem, this isn't super urgent, and any contribution is greatly appreciated.

> I want to check my understanding of what actually should be summed up, but let me know if my test script resembles the actual failure case.

For this issue, I wonder if it would be better to use an actual model with inputs. I'm not sure if the memory usage would spike at a different place because of the autograd graph. I guess it's worth trying.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_metric_learning.losses import AngularLoss

class Model(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, 1024)
        self.fc2 = nn.Linear(1024, out_dim)

    def forward(self, x):
        x = F.relu(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

batch_size = 32
in_dim = 2048
out_dim = 128
model = Model(in_dim, out_dim)
data = torch.randn(batch_size, in_dim)  # random stand-in for input features
labels = torch.randint(low=0, high=4, size=(batch_size,))
loss_func = AngularLoss(alpha=40)

model.train()
embeddings = model(data)
loss = loss_func(embeddings, labels)
loss.backward()  # run backward too, so memory held by the autograd graph is exercised
```

ibebrett commented 3 years ago

@Mactarvish, would you mind posting a more complete example, perhaps in a gist? I'm new to the library and I'm having a hard time reconstructing your problem; more context might make it obvious. (Or perhaps, @KevinMusgrave, you know whether it fits a particular example?)

> Re: angular loss, this is the equation from the paper: [equation image]
>
> It should be computing a loss for each "anchor", and then dividing by the total number of anchors. In the paper, they assume a single positive pair per anchor, so for a batch of size N you have N/2 positive pairs. The current AngularLoss seems to work for that case, but I assumed that it would work for arbitrary numbers of positive pairs. For example, it should work when there are 10 positive pairs per element in the batch.

I noticed this as well, and I'm wondering whether doing it per anchor rather than per pair would reduce memory consumption. Right now you end up with a log-sum over the negatives for each pair. If we did it per anchor, the tensor for that step would go from O(n²) rows (pairs) to O(n) rows (anchors). Then again, if it all gets summed up later anyway, maybe that doesn't affect much?

Thanks for letting me work on this issue. I wanted to get into the metric learning stuff, and this has been a really nice way to start learning the terminology, etc.!

ibebrett commented 3 years ago

Also, a side note: I believe this code basically treats any other embedding as a "negative", regardless of label. Looking at the loss function, I don't see how the loss would magically become irrelevant in that case.

I could be completely misinterpreting it, though.

KevinMusgrave commented 3 years ago

> @Mactarvish , would you mind posting a more complete example, perhaps in a gist? I am new to the library and I'm having a hard time reconstructing your problem. It may be obvious with more context though. (Or perhaps @KevinMusgrave , you know that it fits into a particular example or something)?

I assumed the problem involved a convnet with images, but yeah more details from @Mactarvish would definitely help.

> I noticed this as well, and I'm wondering if it would reduce memory consumption to do it per anchor rather than per pair. You end up with the log sum over each pairing, across each negative. We'd go from n^2 to linear in terms of the size of the tensor for that step if we did it per anchor. I suppose if later it gets summed up anyway that doesn't effect much?

Good point. But I'm not sure how to condense the f_apn expression such that the logsumexp step uses an NxN matrix.

> Thanks for letting me work on this issue. I wanted to get into the metric learning stuff and this has been a really nice way to start learning the terminology etc!

Glad to hear that 👍

> Side note also, I believe this code basically assumes that a "negative" is any other embedding, regardless of label. Looking at the loss function, it doesn't look to me like the loss would magically become irrelevant in that case.

Negatives should have a different label from the anchor. This is what keep_mask is for: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/52bb21ad925115651cba5ddc1f88235caf16d116/src/pytorch_metric_learning/losses/angular_loss.py#L43
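To illustrate the rule, here is a minimal sketch of a per-pair negative mask built from scratch (an illustration only; the library's actual `keep_mask` construction may differ). For a pair whose anchor has label y_a, column j is kept only if `labels[j] != y_a`, so the anchor, its positive, and every other same-class sample are excluded.

```python
import torch

labels = torch.tensor([0, 0, 1, 1, 2])
same = labels.unsqueeze(0) == labels.unsqueeze(1)            # [N, N]
eye = torch.eye(len(labels), dtype=torch.bool)
a_idx, p_idx = torch.where(same & ~eye)                       # ordered (anchor, positive) pairs
# keep_mask[k, j] is True when sample j is a valid negative for pair k
keep_mask = labels[a_idx].unsqueeze(1) != labels.unsqueeze(0)  # [P, N]

print(a_idx.tolist(), p_idx.tolist())  # [0, 1, 2, 3] [1, 0, 3, 2]
print(keep_mask.int().tolist())
# [[0, 0, 1, 1, 1], [0, 0, 1, 1, 1], [1, 1, 0, 0, 1], [1, 1, 0, 0, 1]]
```

Sample 4 has no positive, so it never appears as an anchor, but it is a valid negative for every pair.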

Mactarvish commented 3 years ago

Sorry, I hadn't revisited this problem until now, but I just found what the matter is. The miner mines about 800,000 pairs per batch, so the loss consumes a huge amount of memory. I changed the angle threshold to a larger value to keep the number of mined samples reasonable, and the problem vanished. Thanks anyway~
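A back-of-the-envelope check of those numbers, not a confirmed diagnosis. Assuming the batch size of 128 reported earlier for the other losses was also used here, a batch has at most 128 · 127 = 16,256 distinct ordered (anchor, positive) pairs, so about 800,000 mined entries must contain many repeats. A [P, N] float32 matrix of that size is well under 1 GiB, but an intermediate that also spans the embedding dimension, [N, P, D], is far larger. The thread doesn't state D, so the 256 below is purely hypothetical.

```python
def gib(num_elements, bytes_per_element=4):
    """Size in GiB of a dense tensor (float32 by default)."""
    return num_elements * bytes_per_element / 2**30

batch = 128       # assumption: the batch size reported for the other losses
pairs = 800_000   # "about 800000 pairs" per batch, as reported
dim = 256         # hypothetical embedding size; not stated in the thread

max_distinct_pairs = batch * (batch - 1)   # ordered (anchor, positive) with anchor != positive
pn_gib = gib(pairs * batch)                # one [P, N] matrix
npd_gib = gib(batch * pairs * dim)         # one [N, P, D] intermediate

print(max_distinct_pairs)   # 16256
print(round(pn_gib, 2))     # 0.38
print(round(npd_gib, 2))    # 97.66
```

Under these assumed values, the [N, P, D] case lands in the same range as the 100.11 GiB allocation in the original error. That is consistent with, but does not prove, a 3D intermediate being what failed to allocate.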

ibebrett commented 3 years ago

Yeah, we construct a matrix over all the triples. If you pass in something with 800,000 pairs, we're going to end up with a giant matrix. I'm not very experienced with this library, and maybe there are tricks around this, but with the current implementation, this is expected.