Stonesjtu / pytorch_memlab

Profiling and inspecting memory in pytorch
MIT License

How to use it for memory consumption of optimizer states (e.g, Adam) and batch data? #61

Closed. TOTTO27149 closed this issue 1 month ago

TOTTO27149 commented 1 month ago

Hi, thank you for this very helpful repo. Just wondering, what would be the ideal way to use it to understand the memory consumption of optimizer states (e.g., Adam) and batch data?

I am quite confused about how memory is consumed during training, e.g., is it mostly taken up by the batch data, or does Adam also occupy a lot?

I'm lucky to have found this repo, but it's not very clear to me how to apply it to this question.

TOTTO27149 commented 1 month ago

Or specifically, what's the difference between using pytorch_memlab and torch.cuda.memory_allocated()?

Stonesjtu commented 1 month ago

Or specifically, what's the difference between using pytorch_memlab and torch.cuda.memory_allocated()?

This tool actually uses the data returned by torch.cuda.memory_allocated, so the totals should be identical.

Furthermore, the MemReporter tries to label each memory partition with a tensor name (the tensor itself or its corresponding gradient).
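
For example, here is a minimal sketch of the difference in practice, assuming a toy linear model on GPU:

import torch
from pytorch_memlab import MemReporter

model = torch.nn.Linear(1024, 1024).cuda()
inp = torch.randn(512, 1024, device='cuda')
out = model(inp).mean()

# memory_allocated: a single number, the total bytes currently held by the caching allocator
print(torch.cuda.memory_allocated())

# MemReporter: the same total, but broken down per named tensor
# (parameters, gradients, and other live tensors)
reporter = MemReporter(model)
reporter.report()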

Stonesjtu commented 1 month ago

If you want to dive into the memory consumption of the Adam optimizer, I think you'll need to hack the MemReporter a little so that it accepts an Optimizer as an input argument.
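
As a rough sketch of what that hack needs to collect: the optimizer already keeps its per-parameter buffers (for Adam: exp_avg, exp_avg_sq and step) in optimizer.state, so a small hypothetical helper (not part of pytorch_memlab) can total them up:

import torch

def optimizer_state_bytes(optimizer):
    """Sum the bytes of the tensors held in an optimizer's state, per buffer name."""
    totals = {}
    for param, state in optimizer.state.items():
        for name, value in state.items():
            if torch.is_tensor(value):
                totals[name] = totals.get(name, 0) + value.numel() * value.element_size()
    return totals

linear = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.Adam(linear.parameters())
linear(torch.randn(512, 1024)).mean().backward()
optimizer.step()  # Adam creates its state lazily on the first step
print(optimizer_state_bytes(optimizer))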

TOTTO27149 commented 1 month ago

Wow, thank you for your prompt reply! I will certainly try.

What I am doing first is trying the PyTorch profiler and seeing how its result compares to torch.cuda.memory_allocated. Do you think that would work somewhat similarly to using MemReporter?
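
Concretely, something along these lines is what I have in mind (a rough sketch; I am assuming profile_memory=True and the memory columns of key_averages() are the right knobs):

import torch
from torch.profiler import profile, ProfilerActivity

linear = torch.nn.Linear(1024, 1024)
inp = torch.randn(512, 1024)

with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
    linear(inp).mean().backward()

# per-operator memory usage, as opposed to the per-tensor view of MemReporter
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))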

Stonesjtu commented 1 month ago

memory_allocated can only give the overall memory consumption, rather than a detailed (per-tensor, per-module) breakdown, I think.

Stonesjtu commented 1 month ago

I've hacked around a bit and got this output:

test/test_mem_reporter.py
Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cpu
weight                                          (1024, 1024)    16.00M
bias                                                 (1024,)    16.00K
Tensor0                                          (512, 1024)     8.00M
Tensor1                                                 (1,)   512.00B
-------------------------------------------------------------------------------
Total Tensors: 1573889  Used Memory: 24.02M
-------------------------------------------------------------------------------
Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cpu
weight                                          (1024, 1024)    16.00M
weight.grad                                     (1024, 1024)    16.00M
bias                                                 (1024,)    16.00K
bias.grad                                            (1024,)    16.00K
Tensor0                                          (512, 1024)     8.00M
Tensor1                                                 (1,)   512.00B
weight.grad                                     (1024, 1024)     0.00B
Adam.weight.exp_avg                             (1024, 1024)    16.00M
Adam.weight.exp_avg_sq                          (1024, 1024)    16.00M
bias.grad                                            (1024,)     0.00B
Adam.bias.exp_avg                                    (1024,)    16.00K
Adam.bias.exp_avg_sq                                 (1024,)    16.00K
Adam.weight.step                                        (1,)   512.00B
Adam.bias.step                                          (1,)   512.00B
-------------------------------------------------------------------------------
Total Tensors: 5772291  Used Memory: 72.06M
-------------------------------------------------------------------------------

which is produced by this test function:

import torch
from pytorch_memlab import MemReporter

def test_reporter_with_optimizer():
    linear = torch.nn.Linear(1024, 1024)
    inp = torch.Tensor(512, 1024)
    optimizer = torch.optim.Adam(linear.parameters())

    out = linear(inp).mean()

    # first report: parameters and live tensors only, no gradients yet
    reporter = MemReporter(linear)
    reporter.report()

    out.backward()
    optimizer.step()

    # second report: gradients plus the Adam state (exp_avg, exp_avg_sq, step),
    # registered through the hacked-in add_optimizer method
    reporter.add_optimizer(optimizer)
    reporter.report()

Is this what you expected?