Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Revise memory clearing mechanism in the torch.autograd.Function integration #606

Open IvanYashchuk opened 4 months ago

IvanYashchuk commented 4 months ago

🚀 Memory clearing mechanism with torch.autograd.Function integration

Today we pass "saved for backward" tensors to the generated backward function inside a Python list, and the generated function clears the list right after unpacking it. This is required to remove references to the tensors so that their CUDA memory can be freed, through the del calls later in the function, as soon as each tensor is no longer needed. Python unavoidably holds a reference to every function argument until the function returns. As a result, any tensor passed directly as an argument, or inside an immutable container such as a tuple, cannot be freed: a reference to it still exists even after a del call. This memory problem was fixed in https://github.com/Lightning-AI/lightning-thunder/commit/afd69e9cf03ccc9e7290a8cb477828258c79103b.
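To make the list-versus-tuple difference concrete, here is a minimal CPython sketch. It assumes plain Python objects stand in for CUDA tensors (a weak reference reveals when each one is freed) and that the caller keeps its own reference to the container, as the autograd backward that invokes the generated function would. The names `FakeTensor`, `backward_with_list`, and `backward_with_tuple` are hypothetical and not part of Thunder.

```python
import weakref


class FakeTensor:
    """Hypothetical stand-in for a CUDA tensor; supports weak references."""


def backward_with_list(saved: list) -> bool:
    a, b = saved
    # Clearing the mutable list drops the container's references, so the
    # local name `a` becomes the only reference to the first object.
    saved.clear()
    probe = weakref.ref(a)
    del a  # CPython refcounting frees the object immediately
    return probe() is None  # True: the memory could be reclaimed here


def backward_with_tuple(saved: tuple) -> bool:
    a, b = saved
    probe = weakref.ref(a)
    del a  # the tuple, still referenced by the caller, keeps the object alive
    return probe() is None  # False: nothing is freed until the tuple goes away


saved_list = [FakeTensor(), FakeTensor()]
print(backward_with_list(saved_list))  # True
saved_tuple = (FakeTensor(), FakeTensor())
print(backward_with_tuple(saved_tuple))  # False
```

In the tuple case there is nothing the callee can clear, which is why the saved tensors are passed in a mutable list.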

Another way people sometimes achieve the same effect is by swapping the .data attribute; TransformerEngine has one example of this. However, this variant depends on internal PyTorch attributes.
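A minimal sketch of the .data-swapping pattern, assuming PyTorch is installed. `clear_tensor_data` is a hypothetical helper name used for illustration, not TransformerEngine's exact code. Replacing `.data` with an empty tensor lets the old storage be released (once no other tensor shares it) even while the Python object itself, for example a function argument, is still referenced.

```python
from typing import Optional

import torch


def clear_tensor_data(*tensors: Optional[torch.Tensor]) -> None:
    """Hypothetical helper: drop each tensor's storage but keep the Python object."""
    for t in tensors:
        if t is None:
            continue
        # Swap in an empty tensor with the same dtype and device. Every other
        # Python reference to `t` (e.g. a held function argument) now sees a
        # zero-element tensor, and the old storage can be released.
        t.data = torch.empty(0, dtype=t.dtype, device=t.device)


x = torch.randn(1024, 1024)
alias = x  # simulates the reference the call mechanism keeps alive
clear_tensor_data(x)
print(alias.numel())  # 0
```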

@nikitaved has an idea of implementing a special object that wraps PyTorch tensors and deletes the reference to the tensor when del is called, with the goal of removing clear_mutable_collection from the trace. See the comment here: https://github.com/Lightning-AI/lightning-thunder/pull/596#discussion_r1639853031.

Here is how the backward trace looks today. When printing the backward trace, you can notice `clear_mutable_collection` after a collection is unpacked:

```py
In [1]: import torch; import thunder;

In [2]: @thunder.jit
   ...: def func(a):
   ...:     for _ in range(2):
   ...:         a = a @ a
   ...:     return a
   ...:

In [3]: a = torch.randn(3, 3, device="cuda", requires_grad=True)

In [4]: func(a);

In [5]: thunder.last_backward_traces(func)[-1]
```

```py
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  a, t0, = C0
  clear_mutable_collection(C0)
  del C0
  t13 = torch.permute(a, (1, 0))  # t13: "cuda:0 f32[3, 3]"
    # t13 = ltorch.permute(a, (1, 0))  # t13: "cuda:0 f32[3, 3]"
      # t13 = prims.transpose(a, (1, 0))  # t13: "cuda:0 f32[3, 3]"
  del a
  t8 = torch.permute(t0, (1, 0))  # t8: "cuda:0 f32[3, 3]"
    # t8 = ltorch.permute(t0, (1, 0))  # t8: "cuda:0 f32[3, 3]"
      # t8 = prims.transpose(t0, (1, 0))  # t8: "cuda:0 f32[3, 3]"
  del t0
  t9 = torch.matmul(t2, t8)  # t9: "cuda:0 f32[3, 3]"
    # t9 = ltorch.matmul(t2, t8)  # t9: "cuda:0 f32[3, 3]"
      # t9 = prims.matmul(t2, t8)  # t9: "cuda:0 f32[3, 3]"
  t11 = torch.matmul(t8, t2)  # t11: "cuda:0 f32[3, 3]"
    # t11 = ltorch.matmul(t8, t2)  # t11: "cuda:0 f32[3, 3]"
      # t11 = prims.matmul(t8, t2)  # t11: "cuda:0 f32[3, 3]"
  del t8, t2
  [t12] = nvFusion0(t11, t9)
    # t12 = prims.add(t9, t11)  # t12: "cuda:0 f32[3, 3]"
  del t11, t9
  t14 = torch.matmul(t12, t13)  # t14: "cuda:0 f32[3, 3]"
    # t14 = ltorch.matmul(t12, t13)  # t14: "cuda:0 f32[3, 3]"
      # t14 = prims.matmul(t12, t13)  # t14: "cuda:0 f32[3, 3]"
  t16 = torch.matmul(t13, t12)  # t16: "cuda:0 f32[3, 3]"
    # t16 = ltorch.matmul(t13, t12)  # t16: "cuda:0 f32[3, 3]"
      # t16 = prims.matmul(t13, t12)  # t16: "cuda:0 f32[3, 3]"
  del t13, t12
  [t17] = nvFusion1(t14, t16)
    # t17 = prims.add(t14, t16)  # t17: "cuda:0 f32[3, 3]"
  del t14, t16
  return (t17,)
```

Here's the documentation about reference counts in Python:

> the call mechanism guarantees to hold a reference to every argument for the duration of the call.

(from https://docs.python.org/3/c-api/intro.html#reference-counts) This means that even if we delete, inside our Python function, a variable that was passed as an argument, a reference to the object still exists until the function returns. We can't free the memory of tensors passed as arguments until we exit the function.

IvanYashchuk commented 4 months ago

A prototype for this idea lives here: https://github.com/Lightning-AI/lightning-thunder/pull/607

nikitaved commented 4 months ago

__torch_dispatch__ is only needed for PyTorch to understand subclasses. NVFuser is irrelevant here because it works with the storage ptrs directly.

IvanYashchuk commented 4 months ago

> __torch_dispatch__ is only needed for PyTorch to understand subclasses. NVFuser is irrelevant here because it works with the storage ptrs directly.

PyTorch also doesn't seem to require it to work correctly, but the result is wrapped with the subclass:

```py
In [1]: import torch

In [2]: class BasicTensorSubclass(torch.Tensor):
   ...:     """
   ...:     A tensor subclass that shares storage with the wrapped tensor (created with _make_subclass),
   ...:     has a reference to a torch.Tensor instance, and dispatches to torch.Tensor operations.
   ...:     """
   ...: 
   ...:     @staticmethod
   ...:     def __new__(cls, t: torch.Tensor | None):
   ...:         # Allow to propagate None - this is useful in the context of torch.autograd
   ...:         if t is None:
   ...:             return None
   ...: 
   ...:         while isinstance(t, BasicTensorSubclass):
   ...:             t = t.tensor_obj
   ...: 
   ...:         res = torch.Tensor._make_subclass(cls, t, t.requires_grad)
   ...:         res.tensor_obj = t
   ...:         return res
   ...: 

In [3]: a = torch.randn(3, 3, device="cuda", requires_grad=False)

In [4]: aa = BasicTensorSubclass(a)

In [5]: aa + aa
Out[5]: 
BasicTensorSubclass([[-1.7147,  1.4739,  3.4224],
                     [ 1.7706,  0.3625,  0.0954],
                     [-1.0441, -2.4179, -1.9087]], device='cuda:0')

In [6]: torch.sum(aa)
Out[6]: BasicTensorSubclass(0.0197, device='cuda:0')
```

I think it's needed only for the updated computation to be seen by AOT Autograd, PyTorch's vmap, and all other dispatcher-based systems.

nikitaved commented 4 months ago

> Why do we need to remove clear_mutable_collection, which is also just a verbose way of calling del list[:]?

I would like to have a mechanism that gives some control over which tensors are alive and which are not. Packing everything into a list that is deleted in a symbol is a black hole: once we pass things into the callable, we lose them. Also, it will not work if a function is wrapped with some flatten/unflatten input strategy. We will surely clear the list, but we will not release the memory, because the outer function will still hold the references.
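The wrapper failure mode can be illustrated in plain CPython. In this sketch, `clear_mutable_collection` is assumed to behave like `del list[:]`, as stated in the question above, and `flatten_wrapper` is a hypothetical stand-in for a flatten/unflatten input strategy. Plain objects stand in for tensors so that a weak reference can show whether anything was actually freed.

```python
import weakref


class Saved:
    """Hypothetical stand-in for a saved-for-backward tensor."""


def clear_mutable_collection(collection: list) -> None:
    # Assumed behavior, per the discussion: equivalent to `del list[:]`.
    del collection[:]


def generated_backward(saved_for_backward: list) -> bool:
    (a,) = saved_for_backward
    clear_mutable_collection(saved_for_backward)
    del saved_for_backward
    probe = weakref.ref(a)
    del a
    return probe() is None  # True only if `a` held the last reference


def flatten_wrapper(fn, nested):
    # Builds a *new* flat list, so the caller's nested lists keep their own
    # references even after `fn` clears the list it received.
    flat = [x for group in nested for x in group]
    return fn(flat)


print(generated_backward([Saved()]))  # True: the object is freed early
nested = ([Saved()],)
print(flatten_wrapper(generated_backward, nested))  # False: `nested` still holds it
```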

> What will the trace look like after the idea is implemented?

We either remove the symbol altogether, or leave it in place so that it does nothing on saved_tensors but still hints to the user what is potentially being done.

> How should the input to the backward function be preprocessed?

Nothing should be done to the inputs.

> How can this custom PyTorch tensor wrapper work universally with all the different extensions that are not registered for __torch_function__ or __torch_dispatch__?

The underlying storage is the same; it is just a Python object that is slightly altered. I see no issues there.
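A quick way to check the "same storage" point, assuming PyTorch is installed and reusing the `_make_subclass` construction from the BasicTensorSubclass session above (the class name `Wrapped` is hypothetical): the wrapper's data pointer matches the original tensor's, and ordinary ops return the subclass through the default `__torch_function__`.

```python
import torch


class Wrapped(torch.Tensor):
    """Hypothetical wrapper built the same way as BasicTensorSubclass above."""

    def __new__(cls, t: torch.Tensor):
        # _make_subclass creates a new Python object that aliases t's storage.
        res = torch.Tensor._make_subclass(cls, t, t.requires_grad)
        res.tensor_obj = t  # keep a reference to the original tensor object
        return res


a = torch.randn(3, 3)
aa = Wrapped(a)

# Same underlying memory: only the Python object differs.
print(aa.data_ptr() == a.data_ptr())  # True
print(type(aa).__name__)  # Wrapped
# The default __torch_function__ re-wraps results in the subclass.
print(type(aa + aa).__name__)  # Wrapped
```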

IvanYashchuk commented 4 months ago

> Packing everything into a list that is deleted in a symbol is a black hole: once we pass things into the callable, we lose them. Also, it will not work if a function is wrapped with some flatten/unflatten input strategy. We will surely clear the list, but we will not release the memory, because the outer function will still hold the references.

I agree: the current list clearing won't work if the generated backward function is wrapped in something that changes the arguments. The generated backward trace is meant to be used only in ThunderFunction.backward; it shouldn't be used independently, even though users can currently query it easily with thunder.last_backward_traces and call the corresponding Python callable. In a normal Thunder trace, I don't think we would be clearing lists unless that is present in the user script itself; it would be quite unexpected to see your list contents disappear after calling a Thunder-generated function.

I'm supportive of using Tensor subclasses more (there is a private issue about the use of async tensors returned from backward).

> How should the input to the backward function be preprocessed?
>
> Nothing should be done to the inputs.

Well, I meant how the "saved for backward" tensors should be preprocessed. At the very least, they need to be wrapped in a custom class first.

> We either remove the symbol altogether, or leave it in place so that it does nothing on saved_tensors but still hints to the user what is potentially being done.

Sounds good. If you'd like to remove the "clear mutable collections" op from the backward trace, only this line needs to be modified: https://github.com/Lightning-AI/lightning-thunder/blob/8309fc015c1ea9d684337cdcf95138524cd8645c/thunder/executors/torch_autograd.py#L269

Thank you for taking the time and answering my questions!

mruberry commented 4 months ago

triage review —