Open IvanYashchuk opened 4 months ago
A prototype for this idea lives here https://github.com/Lightning-AI/lightning-thunder/pull/607
`__torch_dispatch__` is only needed for PyTorch to understand subclasses. NVFuser is irrelevant here because it works with the storage pointers directly.
> `__torch_dispatch__` is only needed for PyTorch to understand subclasses. NVFuser is irrelevant here because it works with the storage pointers directly.
PyTorch also doesn't seem to require it in order to work correctly, but the result is wrapped with the subclass:
```py
In [1]: import torch

In [2]: class BasicTensorSubclass(torch.Tensor):
   ...:     """
   ...:     A tensor subclass that owns real storage (created with _make_subclass,
   ...:     not _make_wrapper_subclass), holds a reference to a torch.Tensor instance,
   ...:     and dispatches to torch.Tensor operations.
   ...:     """
   ...:
   ...:     @staticmethod
   ...:     def __new__(cls, t: torch.Tensor | None):
   ...:         # Allow None to propagate - this is useful in the context of torch.autograd
   ...:         if t is None:
   ...:             return None
   ...:
   ...:         # Unwrap nested subclass instances so we always wrap a plain torch.Tensor
   ...:         while isinstance(t, BasicTensorSubclass):
   ...:             t = t.tensor_obj
   ...:
   ...:         res = torch.Tensor._make_subclass(cls, t, t.requires_grad)
   ...:         res.tensor_obj = t
   ...:         return res
   ...:

In [3]: a = torch.randn(3, 3, device="cuda", requires_grad=False)

In [4]: aa = BasicTensorSubclass(a)

In [5]: aa + aa
Out[5]:
BasicTensorSubclass([[-1.7147,  1.4739,  3.4224],
                     [ 1.7706,  0.3625,  0.0954],
                     [-1.0441, -2.4179, -1.9087]], device='cuda:0')

In [6]: torch.sum(aa)
Out[6]: BasicTensorSubclass(0.0197, device='cuda:0')
```
I think it's needed only for the updated computation to be seen by AOT Autograd, PyTorch's vmap, and all other dispatcher-based systems.
> Why do we need to remove `clear_mutable_collection`, which is also just a verbose way of calling `del list[:]`?
I would like to have a mechanism that gives some control over which tensors are alive and which are not. Packing everything into a list that is deleted inside a symbol is a black hole: once we pass things into the callable, we lose them. Also, it will not work if the function is wrapped with some flatten/unflatten input strategy. We would clear the list, but we would not release the memory because the outer function would hold the references.
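A minimal sketch of that flatten/unflatten problem; `clears_args` and `flatten_unflatten_wrapper` are hypothetical names used only for illustration:

```py
import torch

def clears_args(saved):
    # Mimics the generated backward: it clears the list it received.
    saved.clear()

def flatten_unflatten_wrapper(fn):
    # Hypothetical wrapper that repacks arguments before calling fn,
    # similar to a flatten/unflatten input strategy.
    def wrapped(saved):
        inner = list(saved)  # the wrapper now holds a second reference to every tensor
        fn(inner)            # fn clears `inner`, not the caller's `saved`
    return wrapped

t = torch.randn(3, 3)
saved = [t]
flatten_unflatten_wrapper(clears_args)(saved)
print(len(saved))  # 1: the outer list still references the tensor, so nothing is freed
```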
> What will the trace look like after the idea is implemented?
We either remove the symbol altogether, or leave it such that it does nothing on `saved_tensors` but at the same time hints to the user about what is potentially being done.
> How should the input to the backward function be preprocessed?
Nothing should be done to the inputs.
> How can this custom wrapper of a PyTorch tensor work universally with all the different extensions that are not registered for `__torch_function__` or `__torch_dispatch__`?
The underlying storage is the same; it is just a Python object that is slightly altered. I see no issues there.
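A quick check of that claim, assuming the `BasicTensorSubclass` from the snippet earlier in the thread is in scope (a sketch, not part of the proposal):

```py
import torch

a = torch.randn(3, 3)
aa = BasicTensorSubclass(a)  # defined in the earlier snippet

# The subclass created with _make_subclass shares the original tensor's memory.
assert aa.data_ptr() == a.data_ptr()
assert aa.untyped_storage().data_ptr() == a.untyped_storage().data_ptr()
```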
> Packing everything into a list that is deleted inside a symbol is a black hole: once we pass things into the callable, we lose them. Also, it will not work if the function is wrapped with some flatten/unflatten input strategy. We would clear the list, but we would not release the memory because the outer function would hold the references.
I agree, the current list clearing won't work if the generated backward function is wrapped into something that changes the arguments. The generated backward trace is meant to be used only in `ThunderFunction.backward`; it shouldn't be used independently, even though users can currently query it and call the corresponding Python callable with `thunder.last_backward_traces`. In a normal Thunder trace I don't think we would be clearing lists unless the clearing is present in the user script itself; it would be quite unexpected to see your list's contents disappear after calling a Thunder-generated function.
I'm supportive of making more use of tensor subclasses (there is a private issue about the use of async tensors returned from backward).
> How should the input to the backward function be preprocessed?
>
> Nothing should be done to the inputs.
Well, I meant how the "saved for backward" tensors should be preprocessed. At the very least, they need to be wrapped in the custom class first.
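For example, the preprocessing could be as small as the following sketch (a hypothetical `wrap_saved_tensors` helper reusing the `BasicTensorSubclass` from above; this is not code from the prototype PR):

```py
import torch

def wrap_saved_tensors(saved):
    # Hypothetical preprocessing step: wrap each saved-for-backward tensor in the
    # subclass so the generated backward can later drop the underlying reference.
    return tuple(
        BasicTensorSubclass(t) if isinstance(t, torch.Tensor) else t for t in saved
    )
```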
> We either remove the symbol altogether, or leave it such that it does nothing on `saved_tensors` but at the same time hints to the user about what is potentially being done.
Sounds good. If you'd like to remove the "clear mutable collections" op from the backward trace just this line needs to be modified: https://github.com/Lightning-AI/lightning-thunder/blob/8309fc015c1ea9d684337cdcf95138524cd8645c/thunder/executors/torch_autograd.py#L269
Thank you for taking the time and answering my questions!
🚀 Memory clearing mechanism with torch.autograd.Function integration
Today we pass the "saved for backward" tensors to the generated backward function inside a Python list, and within the generated function we clear the list after it's unpacked. This is required to remove the references to the tensors and let the CUDA memory be freed, via `del` calls, as soon as a tensor is no longer needed later in the function. The unavoidable Python behavior is to hold a reference to all function arguments until the end of the function: any tensor passed as an argument, or contained in an immutable container passed as an argument (like a tuple), cannot be freed, because a reference to it still exists even after a `del` call. This memory problem was fixed in https://github.com/Lightning-AI/lightning-thunder/commit/afd69e9cf03ccc9e7290a8cb477828258c79103b.

Another way people sometimes achieve the same effect is by swapping the `.data` attribute; here's one example from TransformerEngine. However, this variant depends on internal PyTorch attributes.

@nikitaved has an idea of implementing a special object wrapping PyTorch tensors that would delete the reference to the tensor when `del` is called, with the goal of removing `clear_mutable_collection` from the trace. See the comment here: https://github.com/Lightning-AI/lightning-thunder/pull/596#discussion_r1639853031.
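As an illustration of the `.data`-swapping trick mentioned above, here is a minimal sketch (this is not the TransformerEngine code; it only shows the mechanism and needs a CUDA device to observe the effect):

```py
import torch

t = torch.randn(1024, 1024, device="cuda")
print(torch.cuda.memory_allocated())  # ~4 MiB held by t

# Point the Python object at an empty tensor; the original CUDA storage can now be
# released even though other frames may still hold references to `t` itself.
t.data = torch.empty(0, device="cuda")
print(torch.cuda.memory_allocated())  # the ~4 MiB buffer has been returned to the allocator
```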
Questions about the idea:

- Why do we need to remove `clear_mutable_collection`, which is also just a verbose way of calling `del list[:]`?
- How can this custom wrapper of a PyTorch tensor work universally with all the different extensions that are not registered for `__torch_function__` or `__torch_dispatch__`?

Click here to see how the backward trace looks today.
When printing the backward trace you can notice `clear_mutable_collection` after a collection is unpacked:

```py
In [1]: import torch; import thunder;

In [2]: @thunder.jit
   ...: def func(a):
   ...:     for _ in range(2):
   ...:         a = a @ a
   ...:     return a
   ...:

In [3]: a = torch.randn(3, 3, device="cuda", requires_grad=True)

In [4]: func(a);

In [5]: thunder.last_backward_traces(func)[-1]
```

```py
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  a, t0, = C0
  clear_mutable_collection(C0)
  del C0
  t13 = torch.permute(a, (1, 0))  # t13: "cuda:0 f32[3, 3]"
    # t13 = ltorch.permute(a, (1, 0))  # t13: "cuda:0 f32[3, 3]"
      # t13 = prims.transpose(a, (1, 0))  # t13: "cuda:0 f32[3, 3]"
  del a
  t8 = torch.permute(t0, (1, 0))  # t8: "cuda:0 f32[3, 3]"
    # t8 = ltorch.permute(t0, (1, 0))  # t8: "cuda:0 f32[3, 3]"
      # t8 = prims.transpose(t0, (1, 0))  # t8: "cuda:0 f32[3, 3]"
  del t0
  t9 = torch.matmul(t2, t8)  # t9: "cuda:0 f32[3, 3]"
    # t9 = ltorch.matmul(t2, t8)  # t9: "cuda:0 f32[3, 3]"
      # t9 = prims.matmul(t2, t8)  # t9: "cuda:0 f32[3, 3]"
  t11 = torch.matmul(t8, t2)  # t11: "cuda:0 f32[3, 3]"
    # t11 = ltorch.matmul(t8, t2)  # t11: "cuda:0 f32[3, 3]"
      # t11 = prims.matmul(t8, t2)  # t11: "cuda:0 f32[3, 3]"
  del t8, t2
  [t12] = nvFusion0(t11, t9)
    # t12 = prims.add(t9, t11)  # t12: "cuda:0 f32[3, 3]"
  del t11, t9
  t14 = torch.matmul(t12, t13)  # t14: "cuda:0 f32[3, 3]"
    # t14 = ltorch.matmul(t12, t13)  # t14: "cuda:0 f32[3, 3]"
      # t14 = prims.matmul(t12, t13)  # t14: "cuda:0 f32[3, 3]"
  t16 = torch.matmul(t13, t12)  # t16: "cuda:0 f32[3, 3]"
    # t16 = ltorch.matmul(t13, t12)  # t16: "cuda:0 f32[3, 3]"
      # t16 = prims.matmul(t13, t12)  # t16: "cuda:0 f32[3, 3]"
  del t13, t12
  [t17] = nvFusion1(t14, t16)
    # t17 = prims.add(t14, t16)  # t17: "cuda:0 f32[3, 3]"
  del t14, t16
  return (t17,)
```

Here's the documentation about reference counts in Python:
(from https://docs.python.org/3/c-api/intro.html#reference-counts) This means that even if we delete a variable that was passed as an argument inside our Python function, a reference to the object still exists until the return statement. We can't free the memory of tensors passed as arguments until we exit the function.
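A small demonstration of this behavior (the function and variable names below are made up for illustration): clearing a mutable list argument releases its tensor immediately, while a tensor held by an immutable tuple argument stays alive until the caller drops the tuple.

```py
import weakref
import torch

def backward_like(saved_list, saved_tuple, list_ref, tuple_ref):
    (a,) = saved_list
    (b,) = saved_tuple
    saved_list.clear()  # drop the only remaining non-local reference to a's tensor
    del a, b            # remove the local names
    print("tensor from list: ", "alive" if list_ref() is not None else "freed")
    print("tensor from tuple:", "alive" if tuple_ref() is not None else "freed")

t1, t2 = torch.randn(3), torch.randn(3)
r1, r2 = weakref.ref(t1), weakref.ref(t2)
lst, tup = [t1], (t2,)
del t1, t2  # the caller now references the tensors only via lst/tup
backward_like(lst, tup, r1, r2)
# -> tensor from list:  freed
# -> tensor from tuple: alive
```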