desertfire opened this issue 1 year ago
The first step in debugging this is to get a `TORCH_LOGS=dynamic` log for it.
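For reference, a minimal sketch of how to enable that logging (the script name `repro.py` is a placeholder; the actual repro command isn't included in this issue):

```python
# Option 1: set the environment variable when launching the repro:
#   TORCH_LOGS=dynamic python repro.py
#
# Option 2: configure the same logging programmatically via
# torch._logging (the public API for TORCH_LOGS settings):
import logging
import torch

torch._logging.set_logs(dynamic=logging.INFO)
```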
Relevant logs:
```
[2023-06-23 07:08:02,265] torch.fx.experimental.symbolic_shapes: [INFO] 20.0: create_symbol s0 = 4 for L['input_ids'].size()[0]
[2023-06-23 07:08:02,421] torch.fx.experimental.symbolic_shapes: [INFO] 20.0: eval Ne(s0, 512) [guard added] at transformers/src/transformers/models/t5/modeling_t5.py:259 in forward (_subclasses/fake_tensor.py:724 in fast_binary_impl)
[2023-06-23 07:08:02,582] torch.fx.experimental.symbolic_shapes: [INFO] 20.0: eval Eq(s0, 4) [guard added] at transformers/src/transformers/models/t5/modeling_t5.py:559 in forward (_refs/__init__.py:368 in _broadcast_shapes)
[2023-06-23 07:08:13,538] torch.fx.experimental.symbolic_shapes: [INFO] 20.0: produce_guards
ERROR:common:Constraints violated!
1. Could not validate constraint RelaxedUnspecConstraint(L['input_ids'].size()[0]) as L['input_ids'].size()[0] is actually a non-atomic symbolic expression 4. Did you really mean to mark this dimension as dynamic?
```
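The log shows the failure mode: `s0` is created for `input_ids.size(0)`, but the `Eq(s0, 4)` guard added at modeling_t5.py:559 specializes it back to a constant, which contradicts the dimension having been marked dynamic. As a hedged illustration (not the actual T5 repro), the same class of violation can be produced with a size-dependent branch:

```python
import torch

def f(x):
    # Branching on the size adds an Eq(s0, 4) guard, specializing the
    # dimension -- analogous to the broadcast at modeling_t5.py:559.
    if x.size(0) == 4:
        return x + 1
    return x - 1

x = torch.randn(4, 8)
torch._dynamo.mark_dynamic(x, 0)  # assert that dim 0 must stay dynamic
torch.compile(f)(x)               # raises a ConstraintViolationError
```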
Actually, this feels a bit familiar...
I think this is the same as https://github.com/pytorch/pytorch/issues/102814#issue-1737404026 (bfloat16 has probably perturbed the graph breaks, which is why it is breaking now).
Also, see this special case:

```python
if args.only in {"hf_T5_generate"}:
    torch._dynamo.config.automatic_dynamic_shapes = True
```
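For context, `automatic_dynamic_shapes` makes Dynamo promote a dimension to dynamic once it sees the size change across recompiles, instead of re-specializing on every new size. A rough sketch of the behavior:

```python
import torch

torch._dynamo.config.automatic_dynamic_shapes = True

@torch.compile
def g(x):
    return x * 2

g(torch.randn(4, 8))  # first call: compiled with dim 0 specialized to 4
g(torch.randn(6, 8))  # size changed: recompile treats dim 0 as dynamic
g(torch.randn(9, 8))  # further batch sizes reuse the dynamic graph
```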
https://github.com/pytorch/pytorch/pull/106808 tries to turn this back on
Repro:
Error:
cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305