Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

increase of GPU memory footprint #216

Closed. mpatel31415 closed this issue 5 months ago.

mpatel31415 commented 5 months ago

🐛 Bug

With the newest version of the Docker image (tested on 2024-04-05 and later), training with thunder.jit and the default executors now needs more GPU memory compared to the image with tag pjnl-pipeline-13351802-amd64.

To Reproduce

Before testing each compilation method I restarted the container:

mkdir -p output
docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864  -v $PWD/output:/output -it INTERNAL_ADDRESS:5005/dl/pytorch/update-scripts:pjnl-latest

Memory usage was then measured for eager, torch.compile (inductor), and thunder.jit runs.
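
A minimal sketch of such a comparison on a single GPU (the model, tensor sizes, and step count below are placeholder assumptions, not the benchmark from this report) could look like:

import torch
import thunder

def peak_memory_gib(forward, params, x, y, steps=5):
    """Run a few training steps and return peak allocated GPU memory in GiB."""
    opt = torch.optim.SGD(params, lr=1e-3)
    torch.cuda.reset_peak_memory_stats()
    for _ in range(steps):
        opt.zero_grad(set_to_none=True)
        loss = torch.nn.functional.mse_loss(forward(x), y)
        loss.backward()
        opt.step()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**30

device = "cuda"
x = torch.randn(64, 4096, device=device)
y = torch.randn(64, 4096, device=device)
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 4096)
).to(device)
params = list(model.parameters())

print(f"eager   peak: {peak_memory_gib(model, params, x, y):.2f} GiB")
print(f"thunder peak: {peak_memory_gib(thunder.jit(model), params, x, y):.2f} GiB")  # default executors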

Expected behavior

Memory used when compiling with Thunder should be similar to the memory used with inductor or eager.

Environment

As in the Docker image. These results come from an NVIDIA RTX A6000, but similar behaviour has been observed on H100.

Additional context

Similar behaviour is visible for other models, batch sizes, and multi-GPU training.

cc @apaz-cli

t-vi commented 5 months ago
diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py
index 6cf18a24..85ee0372 100644
--- a/thunder/core/interpreter.py
+++ b/thunder/core/interpreter.py
@@ -5979,11 +5979,11 @@ def _interpret_call(fn: Callable | WrappedValue, /, *args, **kwargs) -> Any | IN
     runtimectx: InterpreterRuntimeCtx = get_interpreterruntimectx()

     # TODO: Implement generics and fix WrappedValue[T] everywhere.
-    runtimectx.record_interpreter_call(fn)  # type: ignore
+    # runtimectx.record_interpreter_call(fn)  # type: ignore
     rval = _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)  # type: ignore
     if compilectx._with_provenance_tracking:
         assert isinstance(rval, (INTERPRETER_SIGNALS, WrappedValue)), f"return {rval} unexpected calling {unwrap(fn)}"
-    runtimectx.record_interpreter_return(fn, rval)  # type: ignore
+    #runtimectx.record_interpreter_return(fn, rval)  # type: ignore

     return rval

@@ -6140,7 +6140,7 @@ def _call_dispatch(
         # Happens with sharp edges, for example
         return lookaside_fn
     if lookaside_fn:
-        runtimectx.record_lookaside(lookaside_fn)
+        #runtimectx.record_lookaside(lookaside_fn)
         res = lookaside_fn(*args, **kwargs)
         return res

@@ -6160,7 +6160,7 @@ def _call_dispatch(

     # (4) Handles opaque functions
     if is_opaque(fn):
-        runtimectx.record_opaque_call(fn)
+        #runtimectx.record_opaque_call(fn)
         args_ = [unwrap(a) for a in args]
         kwargs_ = {unwrap(k): unwrap(v) for k, v in kwargs.items()}
         try:
@@ -6374,7 +6374,7 @@ def _run_frame(
             # Updates the stack frame to the current position
             # TODO maybe also have inst_ptr?
             frame.nexti(inst)
-            runtimectx.record_interpreted_instruction(inst)
+            #runtimectx.record_interpreted_instruction(inst)
             skip_stack_effect_check: bool = False  # the exception handling will change the stack wildly
             stack_size_before_handler: int = len(stack)

seems to give 5.3GB

apaz-cli commented 5 months ago

@mpatel31415 Looking into and fixing. I think I've identified the issue.

runtimectx.record_interpreter_call(fn)

This prevents the garbage collector from letting go of fn.__closure__, which may contain tensors, so those tensors never get deallocated and their VRAM is never released. Oops.
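
A toy illustration of that reference chain (hypothetical code, not the Thunder interpreter itself): as long as the recorded history holds the function, the closure-captured tensor cannot be freed.

import torch

call_history = []  # stands in for the interpreter's recorded history

def make_step(weight):
    def step(x):
        return x @ weight  # `weight` is captured in step.__closure__
    return step

weight = torch.randn(4096, 4096, device="cuda")  # ~64 MiB of VRAM
step = make_step(weight)
call_history.append(step)  # record_interpreter_call-style bookkeeping

del weight, step
print(torch.cuda.memory_allocated() / 2**20, "MiB still allocated")  # ~64, held via call_history

call_history.clear()  # dropping the history releases the closure and the tensor
print(torch.cuda.memory_allocated() / 2**20, "MiB after clearing the history")  # ~0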

After talking to @t-vi, I think we're also somewhat unhappy with the amount of CPU RAM used. So rather than fixing just this case, we've decided to turn off recording history by default. You'll be able to turn it back on with a flag. thunder.jit(fn, record_history=True) perhaps? Open to bikeshedding on the name. History being off by default also lets us capture the args, kwargs, and return values that we were afraid to capture before.
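
A sketch of how that opt-in would look (assuming the flag lands under the proposed name, which was still open for discussion at this point):

import torch
import thunder

def fn(x):
    return torch.nn.functional.gelu(x) * 2

jfn = thunder.jit(fn)                              # default: no interpreter history kept
jfn_debug = thunder.jit(fn, record_history=True)   # proposed debugging flag

x = torch.randn(8, 8)
print(jfn(x).shape, jfn_debug(x).shape)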

PR forthcoming, along with the PR for allowing duplicates in the executor list again.

parthmannan commented 5 months ago

Thanks @t-vi and @apaz-cli! I guess the issue has been identified, but I'll post my debugging analysis below anyway for reference:

Found the commit where the memory usage in Thunder goes substantially higher. https://github.com/Lightning-AI/lightning-thunder/commit/dba8ce7a9f417b9eb365a305d7611ab1618c2856

Before this commit

LLaMa2 7B FSDP ZeRO2 (8xH100) - 42 GB
Dolly V2 12B (num_layers=24) FSDP ZeRO2 (8xH100) - 34.77 GB

After this commit

LLaMa2 7B FSDP ZeRO2 (8xH100) - 78 GB
Dolly V2 12B (num_layers=24) FSDP ZeRO2 (8xH100) - 62.74 GB

apaz-cli commented 5 months ago

I didn't really know, it was just a guess. Thank you very much @parthmannan :heart:

apaz-cli commented 5 months ago

@mpatel31415 @parthmannan Should be resolved now, as of #239.

parthmannan commented 5 months ago

Yep, things are back to the expected numbers. Thanks a lot for the quick fix @apaz-cli! @mpatel31415 Closing this as it seems resolved. Let us know if the next benchmarking runs still show any issues across workloads.

t-vi commented 5 months ago

@mpatel31415 Thank you for reporting this with a clear repro and expectations; it helped a lot.

StephennFernandes commented 5 months ago

@mpatel31415 Hey, I am pretty new to Lightning Thunder, so apologies for asking noobish questions.

All I have seen is exceptionally increased acceleration with Thunder on H100 GPUs; are there considerable performance gains on A6000 GPUs as well?