topal-team / rockmate


GPT-2 high measured memory footprint #4

Closed: haoming-codes closed this issue 6 months ago

haoming-codes commented 6 months ago

Hi Theotime, Xunyi and all, this is Haoming from ICML 2023. Hope all is well, and happy holidays! I've recently started to work on rematerialization again, and would like to ask a few questions about the results in the paper.

Are the GPT-2 results in the paper measured (by calling torch.cuda functions) or theoretical (by calling Rockmate.expect_mem())? I'm asking because I am not getting close to the <1 GB (measured) footprint reported for GPT-2-medium in the last table of the appendix. Am I doing something wrong?

model = get_GPT(model="GPT2-medium")
sample = torch.randint(0, 600, (4, 512))

# Baseline: peak memory of the original model over one training step.
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
out = train_batch(model, sample, y_true)
original_measured_memory = torch.cuda.max_memory_allocated()
m_budget = int(original_measured_memory * multiplier)

rk_model = Rockmate(
    model, 
    sample, 
    m_budget)

# Rockmate: peak memory of the rematerialized model over the same training step.
del model
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
rk_out = train_batch(rk_model, sample, y_true)
remat_measured_memory = torch.cuda.max_memory_allocated()
del rk_model

With multiplier = 0.1, I'm getting:

original_measured_memory = 10.21 GB
remat_measured_memory = 5.07 GB
rk_model.expect_mem() = 1.01 GB
XunyiZhao commented 6 months ago

Hi Haoming! Nice to hear from you!

Regarding your question: yes, we measure the results by calling torch.cuda functions. Since we are doing rematerialization, all the memory results and functions (including m_budget and rk_model.expect_mem()) correspond to the activation memory size. In practice, we first measure the size of the model with gradients and then subtract it from the peak memory (the details can be found in test/utils.py).
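
A minimal sketch of that kind of measurement, just to illustrate the idea (param_and_grad_bytes is a made-up helper here, not the actual code in test/utils.py):

def param_and_grad_bytes(model):
    # Parameters plus their gradients (gradients have the same size once backward has run).
    total = 0
    for p in model.parameters():
        total += p.numel() * p.element_size()
        if p.grad is not None:
            total += p.grad.numel() * p.grad.element_size()
    return total

torch.cuda.reset_peak_memory_stats()
train_batch(model, sample, y_true)                    # one forward/backward pass
peak = torch.cuda.max_memory_allocated()
activation_peak = peak - param_and_grad_bytes(model)  # this is what m_budget and expect_mem() refer to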

In your example, I believe torch.cuda.max_memory_allocated() will measure the peak including the model itself. If I'm not wrong, GPT2-medium has about 350M parameters, which costs ~1.4 GB for the weights and another ~1.4 GB for their gradients (assuming there are no optimizer states such as Adam's momentum/variance). This part of the memory cost remains the same within the current Rockmate framework. If the goal is to train it with 1 GB in total, I'm afraid rematerialization will not be enough.
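
For a quick back-of-the-envelope check of those numbers (a sketch assuming fp32 parameters, 4 bytes each, and ignoring buffers), reusing the get_GPT helper from your snippet:

model = get_GPT(model="GPT2-medium")

n_params = sum(p.numel() for p in model.parameters())
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

print(f"{n_params / 1e6:.0f}M parameters")              # roughly 350M
print(f"{param_bytes / 1024**3:.2f} GB for weights")     # roughly 1.4 GB in fp32
print(f"{param_bytes / 1024**3:.2f} GB for gradients")   # same again once .backward() has run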

Since I cannot see what's inside train_batch(), I cannot tell whether Rockmate is called correctly. As a reminder, the current Rockmate model needs rk_model.backward() to be called during each iteration, and missing that call might cause problems. If you could share your definition of train_batch(), I can try to reproduce the problem and explain it in more detail.
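
In the meantime, here is a minimal sketch of what one iteration with a Rockmate-wrapped model could look like (optimizer, criterion and the other names are just placeholders; only the rk_model.backward() call after loss.backward() is Rockmate-specific):

optimizer.zero_grad()
y = rk_model(sample)            # forward through the rematerialized model
loss = criterion(y, y_true)
loss.backward()                 # regular autograd backward on the loss
rk_model.backward()             # extra backward call required by Rockmate
optimizer.step()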

Hope that helps you! Happy holidays and looking forward to discussing more!

haoming-codes commented 6 months ago

Thanks Xunyi, that's very helpful. My train_batch() looks like

import copy

import torch
import torch.nn as nn

def train_batch(model, sample, y_true):
    optimizer = torch.optim.Adam(model.parameters())
    optimizer.zero_grad()
    y = model(sample)
    out = copy.deepcopy(y.detach())
    loss = nn.MSELoss()(y, y_true)
    loss.backward()
    if isinstance(model, Rockmate):
        model.backward()       # extra backward call required by Rockmate
    optimizer.zero_grad()      # gradients are cleared; optimizer.step() is never called
    return out

But I think you answered my question. Basically, I should expect model_size + activation_size (produced by a schedule with or without remat) = total_size (as measured by torch.cuda.max_memory_allocated()). Is this theoretically accurate, or are there other components not accounted for?
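
Concretely, what I'd like to be able to verify is roughly this (a hypothetical check; the names are only for illustration, with expect_mem() standing in for the activation size of whatever schedule is used):

model_plus_grad_bytes = sum(
    p.numel() * p.element_size() * (2 if p.grad is not None else 1)  # parameter + its gradient
    for p in rk_model.parameters()
)
predicted_total = model_plus_grad_bytes + rk_model.expect_mem()

measured_total = torch.cuda.max_memory_allocated()  # measured peak for the same iteration
print(predicted_total / 1024**3, "GB predicted vs", measured_total / 1024**3, "GB measured")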

XunyiZhao commented 6 months ago

Yes, that should be correct if model_size represents all the parameter-related memory. Note that Adam will also cost some memory for optimizer states if you call optimizer.step(). Also, model_size is not constant during the iteration if .zero_grad() is called at the end/beginning, since the parameter gradients are generated and kept during the backward pass. In our experiments we do not call zero_grad() during each iteration, precisely to exclude the effect of this dynamic gradient memory.
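
As a rough illustration of those extra costs for a GPT2-medium-sized model (a sketch assuming fp32 and default Adam, whose state adds two tensors per parameter, exp_avg and exp_avg_sq, after the first optimizer.step()):

param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

grad_bytes = param_bytes             # gradients materialize during backward and have the same size
adam_state_bytes = 2 * param_bytes   # exp_avg + exp_avg_sq, allocated at the first step()

print(f"weights:     {param_bytes / 1024**3:.2f} GB")
print(f"gradients:   {grad_bytes / 1024**3:.2f} GB")
print(f"Adam states: {adam_state_bytes / 1024**3:.2f} GB")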

haoming-codes commented 6 months ago

That all makes sense. Thank you very much!