Open SunMarc opened 1 month ago
I think the purpose here was to allow any class that has a `to` method, tensordict being just one example of that.
For some reason, this works on pytorch nightlies on my machine:
```python
import torch

@torch.compile(fullgraph=True)
def func(x):
    if hasattr(x, "to"):
        return x.to("cpu")
    return x

func(torch.randn(3))
```
EDIT: I can reproduce this with 2.3, so I think it will be solved in the next major release of PyTorch. Would that work for you @SunMarc?
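If callers want to rely on `hasattr` inside compiled code only once the fix ships, one option is a small version gate. A minimal sketch, assuming (based only on this thread, not a confirmed PyTorch milestone) that the fix lands in 2.5:

```python
# Hypothetical version gate; the 2.5 cutoff is an assumption from
# this discussion, not a confirmed PyTorch release note.
def supports_dynamo_hasattr(torch_version: str) -> bool:
    # Parse only the leading "major.minor" pair, so suffixes like
    # "2.5.0+cu121" or "2.5.0.dev20240801" are ignored.
    major, minor = (int(part) for part in torch_version.split(".")[:2])
    return (major, minor) >= (2, 5)
```

Such a gate would let the `hasattr` branch be enabled per installed torch version instead of removed outright.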
```
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: hasattr TensorVariable to
```
cc @anijain2305: dynamo doesn't support hasattr
In your case, you are testing with a tensor; I'm not sure this will work with another data type (in my example, it was failing with ConstDictVariable). Thanks for the ping!
> I think the purpose here was to allow any class that has a `to` method, tensordict being just one example of that.
Yes, I understand. If there is no good solution, I was just thinking of making sure that it at least works with tensordict!
Oh yeah, interesting, this fails:
```python
import torch
from tensordict import TensorDict

@torch.compile(fullgraph=True)
def func(x):
    if hasattr(x, "to"):
        return x.to("cpu")
    return x

func(dict(a=torch.randn(3)))
```
but this runs:

```python
func(TensorDict(a=torch.randn(3)))
```
so it seems that dynamo only likes `hasattr` when it's True, lol.
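In the meantime, one possible workaround is to hoist the duck-typing check out of the compiled region, so dynamo never has to trace `hasattr` at all. A minimal sketch (the helper name `move_to_device` is hypothetical, not an accelerate API):

```python
def move_to_device(x, device="cpu"):
    # Do the hasattr check eagerly, in plain Python, so that
    # torch.compile only ever traces the branch that calls .to().
    if not hasattr(x, "to"):
        return x

    import torch  # imported lazily; only needed on the compiled path

    @torch.compile(fullgraph=True)
    def _compiled_to(t):
        return t.to(device)

    return _compiled_to(x)
```

Wrapping `torch.compile` inside the call like this recompiles on every invocation, which is wasteful in practice; it only illustrates moving the check outside the traced function.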
Say this is patched in 2.5 for instance, would that solve the issue?
> Say this is patched in 2.5 for instance, would that solve the issue?
Yeah! This is the only issue I have right now in order to make big model inference (multi-GPU) + torch.compile work together. I tried removing that check and it worked fine on torch nightly. The goal for us would be to enable this starting from torch 2.5, since with torch 2.4 there were other errors that I didn't manage to debug. LMK if this is something that will be fixed in torch 2.5!
System Info
Information
Tasks
no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
Reproduction
Issue when compiling a model that has been dispatched to multiple GPUs.
Error encountered:
This was something we added to enable accelerate to work with tensordict. cc @vmoens Is there a way to make it work without having to use `hasattr`? One solution could be to make the tensordict library an optional dependency and check whether we indeed have a tensordict? Related issue: https://github.com/huggingface/accelerate/issues/2405 cc @muellerzr
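A rough sketch of that optional-dependency idea (the function name `is_tensordict` is hypothetical, not an existing accelerate helper):

```python
import importlib.util

# Detect tensordict once at import time instead of duck-typing with
# hasattr inside code that torch.compile will trace.
_TENSORDICT_AVAILABLE = importlib.util.find_spec("tensordict") is not None

def is_tensordict(obj) -> bool:
    # Safe to call whether or not tensordict is installed.
    if not _TENSORDICT_AVAILABLE:
        return False
    from tensordict import TensorDict
    return isinstance(obj, TensorDict)
```

An `isinstance` check like this is something dynamo handles today, unlike `hasattr` on arbitrary variable types.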
To reproduce:
Expected behavior
Compile as expected.