pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Support python attribute mutations of tensor subclass graph inputs from within __torch_dispatch__ #130028

Open bdhirsh opened 4 months ago

bdhirsh commented 4 months ago

Filing this issue for tracking; whether we act on it depends on user requests.

Normally, setattr on graph inputs is handled by dynamo: it tracks the attribute mutations as side effects, and replays the bytecode for them later.

However - if a tensor subclass setattrs some of its attributes from inside __torch_dispatch__, dynamo will never see it and cannot preserve the side effect. This behavior is banned today under the "traceability" requirement for tensor subclasses that want to work with compile.
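To make the contrast concrete, here's a minimal sketch of the two cases. The CachingTensor subclass, its _inner/_cache attributes, and the mutation it performs are made up for illustration (a real traceable subclass would also need __tensor_flatten__/__tensor_unflatten__):

```python
import torch

# Case 1: attribute mutation dynamo can see. Dynamo records `x.foo = 1` as a
# side effect on the graph input and replays that bytecode after the compiled
# graph runs.
@torch.compile
def visible_mutation(x):
    x.foo = 1
    return x.sin()

# Case 2: attribute mutation hidden inside __torch_dispatch__.
class CachingTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, inner):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner):
        self._inner = inner
        self._cache = None  # mutated from inside dispatch, below

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrapped = [a._inner if isinstance(a, CachingTensor) else a for a in args]
        out = func(*unwrapped, **kwargs)
        # This setattr happens while tracing through dispatch; dynamo never sees
        # it, so the mutation on a graph input cannot be replayed afterwards.
        args[0]._cache = out
        return CachingTensor(out)
```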

This is a placeholder issue to explore if / to what extent we could lift this restriction, for attributes that are part of the __tensor_flatten__ contract.

One idea, in AOTAutograd: we can stash the python id of each attribute on all subclass inputs, and check whether any of them changed after tracing during the analysis pass. This isn't quite enough, for two reasons:

(1) if running the analysis pass tweaked the attributes of one of our example inputs, then those example inputs will now be incorrect when we run the actual tracing pass with them. We would need to unwind the attribute mutation.

(2) if the new attribute is a tensor that came from a graph intermediate, we will also need to properly make it a graph output, so that we can grab it at runtime after executing the compiled graph and perform the attribute mutation in an epilogue.
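A rough sketch of the id-stashing check described above; the helper names are made up and this is not the actual AOTAutograd code:

```python
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def snapshot_subclass_attrs(example_inputs):
    # Record id() of every inner-tensor attribute declared by __tensor_flatten__.
    snapshot = {}
    for i, inp in enumerate(example_inputs):
        if is_traceable_wrapper_subclass(inp):
            attrs, _ctx = inp.__tensor_flatten__()
            snapshot[i] = {name: id(getattr(inp, name)) for name in attrs}
    return snapshot

def detect_attr_mutations(example_inputs, snapshot):
    # After the analysis pass, see which attributes now point at a different object.
    mutated = []
    for i, saved in snapshot.items():
        attrs, _ctx = example_inputs[i].__tensor_flatten__()
        for name in attrs:
            if name not in saved or id(getattr(example_inputs[i], name)) != saved[name]:
                mutated.append((i, name))
    return mutated
```

Even with a check like this, we'd still need to unwind the mutations on the example inputs and turn any newly-attached graph intermediates into graph outputs, per (1) and (2) above.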

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @ezyang @albanD @chauhang @penguinwu @zou3519 @yf225 @anijain2305

bdhirsh commented 4 months ago

FWIW - we found out that DTensor mutates attributes directly in its embedding lowering (@wanchaol can explain this better, but it's something about DTensor wanting to temporarily cache a computed mask so it can pass it into a later op rather than having to recompute it, which is expensive).

Not majorly blocking today, because this only surfaces for DTensor if there is a graph break between the region where DTensor computes the mask and the region where it releases it, but it's a good datapoint.

wanchaol commented 4 months ago

I think one way to make this work for the DTensor embedding case is that, when we do tensor_flatten, we explicitly check for those placements that could encode the data and flatten it out; then unflatten would encode that flattened tensor data back into the placement. This would at least resolve the dynamo graph break issue, I guess.
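For illustration, here's a hedged sketch of what that could look like; the MaskedTensor/FakePartialPlacement names and the mask_buffer attribute are simplified stand-ins for the real DTensor/_MaskPartial code, not the actual implementation:

```python
import torch

class FakePartialPlacement:
    # Stand-in for a placement (like _MaskPartial) that can cache tensor data.
    def __init__(self):
        self.mask_buffer = None

class MaskedTensor(torch.Tensor):
    # DTensor-like wrapper whose __tensor_flatten__ also exposes the cached mask
    # (when present) as an inner tensor, rather than hiding it on the placement
    # where the tracing machinery cannot see it.
    @staticmethod
    def __new__(cls, local, placements):
        return torch.Tensor._make_wrapper_subclass(
            cls, local.shape, dtype=local.dtype, device=local.device
        )

    def __init__(self, local, placements):
        self._local_tensor = local
        self._placements = placements
        self._cached_mask = placements[0].mask_buffer

    def __tensor_flatten__(self):
        attrs = ["_local_tensor"]
        if self._cached_mask is not None:
            attrs.append("_cached_mask")
        return attrs, (self._placements, self.requires_grad)

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        placements, _requires_grad = ctx
        if "_cached_mask" in inner_tensors:
            # Re-attach the flattened mask to the placement on the way back in.
            placements[0].mask_buffer = inner_tensors["_cached_mask"]
        return MaskedTensor(inner_tensors["_local_tensor"], placements)
```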

bdhirsh commented 4 months ago

> I think one way to make this work for the DTensor embedding case is that, when we do tensor_flatten, we explicitly check for those placements that could encode the data and flatten it out; then unflatten would encode that flattened tensor data back into the placement.

I think that would be strictly more correct on the DTensor side - but it would still have the problem that if you have a compiled region with an input dtensor, where:

(1) on graph entry, dtensor._masked_data = None and type(dtensor.placements[0]) != _MaskPartial

(2) on graph exit, dtensor._masked_data = some_data_tensor and type(dtensor.placements[0]) == _MaskPartial

at runtime, we have no way today of properly mutating the attributes on the input dtensor (this is essentially a type of "metadata mutation" of a graph input), since we will have inlined through the DTensor code that sets that mask attribute at compile time, and at runtime this code will not run.
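To make that concrete, here's a rough sketch of the kind of runtime epilogue that would be needed; the wrapper, its arguments, and the attr_mutations bookkeeping are all hypothetical, not the actual AOTAutograd machinery:

```python
def run_with_attr_mutation_epilogue(compiled_fn, args, num_user_outputs, attr_mutations):
    # Hypothetical runtime wrapper. `attr_mutations` maps
    # (input_index, attr_name) -> index of an extra graph output holding the
    # new attribute value that was computed inside the traced graph.
    outs = compiled_fn(*args)
    user_outs = outs[:num_user_outputs]
    extra_outs = outs[num_user_outputs:]
    for (inp_idx, attr_name), out_idx in attr_mutations.items():
        # Replay the attribute mutation on the real graph input at runtime, since
        # the DTensor code that performed it was inlined away at compile time.
        setattr(args[inp_idx], attr_name, extra_outs[out_idx])
        # A full fix would also have to update non-tensor metadata here, e.g.
        # swapping placements[0] to _MaskPartial for the case described above.
    return user_outs
```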