ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/

Llama-70B causes GPU memory leak without explicit garbage collection #169

Open arunasank opened 2 months ago

arunasank commented 2 months ago

I am working on a project that uses nnsight's trace to implement attribution patching. I found that, in certain cases, triggering an explicit garbage collection after calling nnsight's trace method is the only way to prevent GPU memory leaks. This seems related to circular object references that aren't cleared when the torch cache is emptied or when known variables are explicitly deleted.
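
For concreteness, the pattern I'm describing looks roughly like this. It is a minimal sketch rather than the actual attribution-patching code from the zip: run_patching_step is a hypothetical stand-in and the model path is just an example.

import gc
import torch
from nnsight import LanguageModel

model = LanguageModel("meta-llama/Llama-2-70b-hf", device_map="auto", dispatch=True, torch_dtype=torch.float16)

def run_patching_step(prompt):
    # Hypothetical stand-in for one attribution-patching step: trace the
    # model and save some activations.
    with model.trace(prompt, scan=False, validate=False):
        logits = model.lm_head.output.save()
    return logits.value

for _ in range(100):
    run_patching_step("hello")
    # Without this explicit collection after each trace, reserved GPU
    # memory keeps growing until the process eventually hits an OOM.
    gc.collect()
    torch.cuda.empty_cache()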

I think @saprmarks has also encountered this issue when working on feature circuits. Also cc-ing @Butanium, who mentioned that adding .value fixes it in the scenario they encountered.

nnsight-llama.zip contains the necessary files to reproduce this leak.

I've commented out lines 112-114 because uncommenting them makes no difference to the memory leak.

Please let me know if more info is needed.

cc @JadenFiotto-Kaufman

jkminder commented 2 months ago

Hi, I've encountered a similar issue (I'm the collaborator of @Butanium). What I believe happens is that if you don't call .value, the whole computation graph is sometimes still kept in memory. I don't have a good enough understanding of nnsight's internals to know exactly why, but what helped for me was making sure that no cached activation is a pointer into the computation graph, i.e. storing only the .value results and never the proxies themselves. Skimming through your code, I assume you have a similar issue: your attn_layer_cache_prompt and mlp_layer_cache_prompt store the proxies, not the values. I suspect that inserting a loop after line 89 that goes through all layers in both caches and converts the proxies to tensors (by calling .value), so that only the values are stored in the dicts, would solve the problem.
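
Something along these lines (a rough sketch based on how the dicts appear to be structured in your script):

# After the trace that fills the caches, replace every saved proxy with its
# underlying tensor so that nothing in the dicts keeps the graph alive.
for layer in attn_layer_cache_prompt:
    attn_layer_cache_prompt[layer]["forward"] = attn_layer_cache_prompt[layer]["forward"].value
for layer in mlp_layer_cache_prompt:
    mlp_layer_cache_prompt[layer]["forward"] = mlp_layer_cache_prompt[layer]["forward"].value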

My very non-grounded hypothesis about what happens here: you store the proxies in the dicts. Upon leaving the scope of your function, the garbage collector is not immediately triggered (maybe to reuse the allocated memory in the next loop iteration? I don't know too much about the intricacies of Python's gc). This leaves pointers to the computation graphs alive, which means the graphs are still kept in VRAM. Triggering the gc cleans up the dicts, which also frees the VRAM.
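
As a generic illustration of that gc behaviour (nothing nnsight-specific, just plain Python): objects that are only kept alive by a reference cycle are not freed when they go out of scope; they have to wait for the cyclic collector.

import gc

class Node:
    def __init__(self):
        self.ref = None

def make_cycle():
    a, b = Node(), Node()
    a.ref, b.ref = b, a  # reference cycle: refcounts never drop to zero
    # a and b go out of scope here, but the cycle keeps both objects alive

make_cycle()
# The objects are only reclaimed once the cyclic garbage collector runs,
# either automatically or via an explicit call:
print(gc.collect())  # number of unreachable objects found (typically non-zero here)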

JadenFiotto-Kaufman commented 2 months ago

Hey @arunasank. I recreated part of your script below (to not have to worry about a config or zip file):

from nnsight import LanguageModel
import torch
model = LanguageModel("meta-llama/Llama-2-7b-hf", device_map='auto', dispatch=True, torch_dtype=torch.float16) # Load the model

mlp_effects_cache = torch.zeros((model.config.num_hidden_layers, model.config.hidden_size)).to("cuda")
attn_effects_cache = torch.zeros((model.config.num_hidden_layers, model.config.hidden_size)).to("cuda")

def test():

  attn_layer_cache_prompt = {}

  mlp_layer_cache_prompt = {}

  attn_layer_cache_patch = {}

  mlp_layer_cache_patch = {}

  # First trace: save the attention and MLP outputs and retain their grads
  with model.trace("hello", scan=False, validate=False) as tracer:
      for layer in range(len(model.model.layers)):
          self_attn = model.model.layers[layer].self_attn.o_proj.output
          mlp = model.model.layers[layer].mlp.down_proj.output
          mlp.retain_grad()
          self_attn.retain_grad()

          attn_layer_cache_prompt[layer] = {"forward": self_attn.save()} 
          mlp_layer_cache_prompt[layer] = {"forward": mlp.save()}

      logits = model.lm_head.output.save()

  # Stand-in loss (always zero here), used only to trigger a backward pass
  # and populate the retained grads
  loss = logits[:, -1, 0] - logits[:, -1, 0]
  loss = loss.sum()
  loss.backward()

  # Second trace: save the same activations for the patch run
  with model.trace("hello", scan=False, validate=False) as tracer:
      for layer in range(len(model.model.layers)):
          self_attn = model.model.layers[layer].self_attn.o_proj.output
          mlp = model.model.layers[layer].mlp.down_proj.output

          attn_layer_cache_patch[layer] = {"forward": self_attn.save()}
          mlp_layer_cache_patch[layer] = {"forward": mlp.save()}

  # Per-layer effect: grad * (patch activation - prompt activation)
  for layer in range(len(model.model.layers)):
    mlp_effects = (mlp_layer_cache_prompt[layer]["forward"].value.grad * (mlp_layer_cache_patch[layer]["forward"].value - mlp_layer_cache_prompt[layer]["forward"].value)).detach()
    attn_effects = (attn_layer_cache_prompt[layer]["forward"].value.grad * (attn_layer_cache_patch[layer]["forward"].value - attn_layer_cache_prompt[layer]["forward"].value)).detach()

    mlp_effects = mlp_effects[0, -1, :] # batch, token, hidden_states
    attn_effects = attn_effects[0, -1, :] # batch, token, hidden_states

    mlp_effects_cache[layer] += mlp_effects.to(mlp_effects_cache.get_device())
    attn_effects_cache[layer] += attn_effects.to(attn_effects_cache.get_device())

for i in range(100):

  test()
  print(torch.cuda.memory_reserved(0))

I don't see memory requirements increase. Can you make sure you're using the latest version of nnsight? Otherwise I'm not sure this is an nnsight problem. Is there something I am missing?
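
(A quick way to check the installed version from Python, using only the standard library:)

from importlib.metadata import version
print(version("nnsight"))  # prints the installed nnsight version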

arunasank commented 2 months ago

@JadenFiotto-Kaufman It does seem to happen on the dataset + model combination I shared, so it's definitely not data-independent! Memory usage drops to 40G (about half of what it is without garbage collection) when I add explicit garbage collection. EDIT: To clarify, this dataset caused my GPU to crash with an OOM error, but the leak happens with other datasets as well; it just didn't surface in those cases.

@jkminder Thank you for sharing these thoughts. Hopefully they will help us debug the issue.