bdhirsh opened 4 months ago
FWIW - we found out that DTensor mutates attributes directly in its embedding lowering (@wanchaol can explain this better, but roughly: DTensor wants to temporarily cache a computed mask value so it can pass it into a later op rather than having to recompute it, which is expensive).

Not majorly blocking today, because this only surfaces for DTensor if there is a graph break between the region where DTensor computes the mask and the region where it releases it - but it's a good datapoint.
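A minimal sketch of the caching pattern being described, assuming a placement object that stashes the mask on itself (the class, attribute, and method names here are illustrative, not DTensor's exact internals):

```python
import torch

class MaskPartialSketch:
    """Toy stand-in for DTensor's _MaskPartial placement (names illustrative)."""

    def __init__(self):
        self.mask_buffer = None  # cached mask, populated during the embedding op

    def cache_mask(self, indices: torch.Tensor, offset: int, size: int) -> None:
        # The mask is expensive to compute, so it is stashed directly on the
        # placement object (a plain attribute mutation) for a later op to reuse.
        self.mask_buffer = (indices < offset) | (indices >= offset + size)

    def release_mask(self) -> torch.Tensor:
        # A later op consumes the cached mask instead of recomputing it, then
        # clears the cache - this release is the second attribute mutation.
        mask, self.mask_buffer = self.mask_buffer, None
        return mask
```

The problem above arises when a graph break lands between `cache_mask` and `release_mask`, so the mutation has to survive across compiled regions.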
I think one way to make this work for the DTensor embedding case is that, when we do `__tensor_flatten__`, we explicitly check for those placements that could encode the data and flatten it out; unflatten would then encode that flattened tensor data back into the placement. This would at least resolve the dynamo graph break issue, I guess.
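A hedged sketch of that idea, using a toy wrapper subclass (the `_masked_data` attribute comes from the discussion below; everything else - class name, constructor, spec handling - is illustrative rather than DTensor's real implementation):

```python
import torch

class DTensorSketch(torch.Tensor):
    """Toy stand-in for DTensor; only the flatten/unflatten logic matters here."""

    @staticmethod
    def __new__(cls, local_tensor, spec):
        return torch.Tensor._make_wrapper_subclass(
            cls, local_tensor.shape, dtype=local_tensor.dtype
        )

    def __init__(self, local_tensor, spec):
        self._local_tensor = local_tensor
        self._spec = spec
        self._masked_data = None  # cached mask, per the discussion above

    def __tensor_flatten__(self):
        inner_tensors = ["_local_tensor"]
        # Explicitly check whether a placement cached tensor data, and flatten
        # it out as an extra inner tensor so dynamo/AOTAutograd can see it.
        if self._masked_data is not None:
            inner_tensors.append("_masked_data")
        return inner_tensors, (self._spec,)

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        (spec,) = ctx
        out = DTensorSketch(inner_tensors["_local_tensor"], spec)
        # Encode the flattened mask data back onto the reconstructed tensor.
        out._masked_data = inner_tensors.get("_masked_data")
        return out
```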
> I think one way to make this work for the DTensor embedding case is that, when we do `__tensor_flatten__`, we explicitly check for those placements that could encode the data and flatten it out; unflatten would then encode that flattened tensor data back into the placement.
I think that would be strictly more correct on the DTensor side - but it would still have the problem that if you have a compiled region with an input dtensor, where:

(1) on graph entry, `dtensor._masked_data = None` and `type(dtensor.placements[0]) != _MaskPartial`

(2) on graph exit, `dtensor._masked_data = some_data_tensor` and `type(dtensor.placements[0]) == _MaskPartial`

then at runtime we have no way today of properly mutating the attributes on the input dtensor (this is essentially a type of "metadata mutation" of a graph input): we will have inlined through the DTensor code that sets that mask attribute at compile time, so at runtime this code will not run. The sketch below makes the two states concrete.
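A hypothetical illustration of that entry/exit mismatch (`embedding_fn`, `dtensor`, and `indices` are placeholders; `_masked_data` as above):

```python
# All names here are illustrative - this shows the state eager DTensor would
# leave behind, which the compiled artifact fails to replay.
compiled_fn = torch.compile(embedding_fn)

assert dtensor._masked_data is None                         # (1) graph entry
assert not isinstance(dtensor.placements[0], _MaskPartial)

out = compiled_fn(dtensor, indices)

# Eager mode would leave the input in this state, but the setattr was inlined
# away at trace time, so at runtime neither mutation actually happens:
assert dtensor._masked_data is not None                     # (2) graph exit
assert isinstance(dtensor.placements[0], _MaskPartial)
```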
Filing this issue for tracking; prioritization depends on user requests.
Normally, `setattr` on graph inputs is handled by dynamo tracking attribute mutations on graph inputs as side effects, and replaying the bytecode for them later. However, if a tensor subclass setattr's some of its attributes from inside of `__torch_dispatch__`, dynamo will never see the mutation and cannot preserve the side effect. This behavior is banned today under the "traceability" requirement for tensor subclasses that want to work with compile (a toy example of the invisible mutation is below).

This is a placeholder issue to explore if / to what extent we could lift this restriction, for attributes that are part of the `__tensor_flatten__` contract.
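For concreteness, a minimal toy subclass whose dispatch mutates an attribute - the kind of mutation dynamo cannot see (the subclass and its attribute names are made up for illustration):

```python
import torch

class CachingTensor(torch.Tensor):
    """Toy wrapper subclass that caches its last result from inside dispatch."""

    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype)

    def __init__(self, elem):
        self._elem = elem
        self._cached = None

    def __tensor_flatten__(self):
        return ["_elem"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        return CachingTensor(inner_tensors["_elem"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrapped = [a._elem if isinstance(a, CachingTensor) else a for a in args]
        out = func(*unwrapped, **kwargs)
        for a in args:
            if isinstance(a, CachingTensor):
                # setattr inside __torch_dispatch__: dynamo only traces the
                # dispatched op, never this line, so it cannot record the
                # mutation as a side effect or replay it in bytecode.
                a._cached = out
        return CachingTensor(out) if isinstance(out, torch.Tensor) else out

x = CachingTensor(torch.ones(4))
y = x + 1                       # eager: x._cached gets set as a side effect
assert x._cached is not None    # a compiled region would not replay this
```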
One idea: in AOTAutograd, we can stash the python id of each attribute on all subclass inputs, and check whether any of them changed after tracing during the analysis pass. This isn't quite enough, for two reasons (see the sketch after this list):

(1) if running the analysis pass tweaked the attributes of one of our example inputs, then those example inputs will be incorrect when we run the actual tracing pass with them, so we would need to unwind the attribute mutation;

(2) if the new attribute is a tensor that came from a graph intermediate, we will also need to properly make it a graph output, so that we can grab it at runtime after executing the compiled graph and perform the attribute mutation in an epilogue.
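A rough sketch of the id-stashing piece (`is_traceable_wrapper_subclass` is real, from `torch.utils._python_dispatch`; `run_analysis_pass`, `example_inputs`, and `handle_mutated_attr` are hypothetical placeholders):

```python
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def snapshot_subclass_attrs(inputs):
    # Record id() of every __tensor_flatten__-visible attribute, per input.
    snap = {}
    for i, inp in enumerate(inputs):
        if is_traceable_wrapper_subclass(inp):
            attr_names, _ctx = inp.__tensor_flatten__()
            snap[i] = {name: id(getattr(inp, name)) for name in attr_names}
    return snap

before = snapshot_subclass_attrs(example_inputs)
run_analysis_pass(fn, example_inputs)  # hypothetical: AOTAutograd's analysis trace
after = snapshot_subclass_attrs(example_inputs)

for i, attrs in before.items():
    for name, old_id in attrs.items():
        if after.get(i, {}).get(name) != old_id:
            # The attribute changed during analysis. Per (1) above, unwind the
            # mutation before the real tracing pass; per (2), if the new value
            # is a graph intermediate, promote it to a graph output and replay
            # the setattr in a runtime epilogue.
            handle_mutated_attr(i, name)  # hypothetical
```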
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @ezyang @albanD @chauhang @penguinwu @zou3519 @yf225 @anijain2305