pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

dlrm and hf_T5_generate fails aot_eager with bfloat16+dynamic_shapes #103760

Open desertfire opened 1 year ago

desertfire commented 1 year ago

Repro:

benchmarks/dynamo/torchbench.py --inductor --bfloat16 --accuracy --inference --device cuda --dynamic-shapes --dynamic-batch-only --only hf_T5_generate

Error:

2023-06-16T13:53:37.3160247Z cuda eval  hf_T5_generate                     
2023-06-16T13:54:07.7583460Z ERROR:common:Constraints violated!
2023-06-16T13:54:07.7584228Z   1. Could not validate constraint RelaxedUnspecConstraint(L['input_ids'].size()[0]) as L['input_ids'].size()[0] is actually a non-atomic symbolic expression 4. Did you really mean to mark this dimension as dynamic?
2023-06-16T13:54:07.7586656Z 
2023-06-16T13:54:07.7587045Z 
2023-06-16T13:54:07.7587387Z You can suppress this exception and fall back to eager by setting:
2023-06-16T13:54:07.7587758Z     import torch._dynamo
2023-06-16T13:54:07.7589120Z     torch._dynamo.config.suppress_errors = True
2023-06-16T13:54:07.7589393Z Traceback (most recent call last):
2023-06-16T13:54:07.7589768Z   File "/var/lib/jenkins/workspace/benchmarks/dynamo/common.py", line 1531, in check_accuracy
2023-06-16T13:54:07.7590502Z     new_result = optimized_model_iter_fn(model_copy, example_inputs)
2023-06-16T13:54:07.7591023Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 295, in _fn
2023-06-16T13:54:07.7591317Z     return fn(*args, **kwargs)
2023-06-16T13:54:07.7591712Z   File "/var/lib/jenkins/workspace/benchmarks/dynamo/common.py", line 1356, in run_n_iterations
2023-06-16T13:54:07.7592145Z     self.model_iter_fn(mod, inputs, collect_outputs=False)
2023-06-16T13:54:07.7592558Z   File "/var/lib/jenkins/workspace/benchmarks/dynamo/torchbench.py", line 436, in forward_pass
2023-06-16T13:54:07.7592841Z     return mod(*inputs)
2023-06-16T13:54:07.7595175Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
2023-06-16T13:54:07.7595553Z     return self._call_impl(*args, **kwargs)
2023-06-16T13:54:07.7596028Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
2023-06-16T13:54:07.7596419Z     return forward_call(*args, **kwargs)
2023-06-16T13:54:07.7596873Z   File "/var/lib/jenkins/workspace/torchbench/torchbenchmark/util/framework/huggingface/model_factory.py", line 204, in forward
2023-06-16T13:54:07.7597338Z     return self.model.generate(inputs, self.generation_config)
2023-06-16T13:54:07.7597986Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
2023-06-16T13:54:07.7598376Z     return func(*args, **kwargs)
2023-06-16T13:54:07.7598969Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/generation/utils.py", line 1192, in generate
2023-06-16T13:54:07.7599424Z     self._validate_model_class()
2023-06-16T13:54:07.7599961Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/generation/utils.py", line 1210, in <resume in generate>
2023-06-16T13:54:07.7600323Z     generation_config = copy.deepcopy(generation_config)
2023-06-16T13:54:07.7600778Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/generation/utils.py", line 1213, in <resume in generate>
2023-06-16T13:54:07.7601120Z     self._validate_model_kwargs(model_kwargs.copy())
2023-06-16T13:54:07.7601708Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/generation/utils.py", line 1213, in <resume in generate>
2023-06-16T13:54:07.7602049Z     self._validate_model_kwargs(model_kwargs.copy())
2023-06-16T13:54:07.7602477Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/generation/utils.py", line 1216, in <resume in generate>
2023-06-16T13:54:07.7602883Z     logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
2023-06-16T13:54:07.7603441Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/generation/utils.py", line 1217, in <resume in generate>
2023-06-16T13:54:07.7603871Z     stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
2023-06-16T13:54:07.7604372Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/generation/utils.py", line 1268, in <resume in generate>
2023-06-16T13:54:07.7604734Z     model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
2023-06-16T13:54:07.7605245Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/generation/utils.py", line 634, in _prepare_encoder_decoder_kwargs_for_generation
2023-06-16T13:54:07.7605640Z     model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
2023-06-16T13:54:07.7606168Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
2023-06-16T13:54:07.7606496Z     return self._call_impl(*args, **kwargs)
2023-06-16T13:54:07.7607029Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
2023-06-16T13:54:07.7607433Z     return forward_call(*args, **kwargs)
2023-06-16T13:54:07.7607935Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 448, in catch_errors
2023-06-16T13:54:07.7608276Z     return callback(frame, cache_size, hooks, frame_state)
2023-06-16T13:54:07.7608708Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in _convert_frame
2023-06-16T13:54:07.7609070Z     result = inner_convert(frame, cache_size, hooks, frame_state)
2023-06-16T13:54:07.7609494Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 127, in _fn
2023-06-16T13:54:07.7609790Z     return fn(*args, **kwargs)
2023-06-16T13:54:07.7610196Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
2023-06-16T13:54:07.7610500Z     return _compile(
2023-06-16T13:54:07.7610880Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
2023-06-16T13:54:07.7611178Z     r = func(*args, **kwargs)
2023-06-16T13:54:07.7611568Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 478, in _compile
2023-06-16T13:54:07.7611884Z     check_fn = CheckFunctionManager(
2023-06-16T13:54:07.7612289Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 863, in __init__
2023-06-16T13:54:07.7612603Z     guard.create(local_builder, global_builder)
2023-06-16T13:54:07.7612993Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_guards.py", line 208, in create
2023-06-16T13:54:07.7613349Z     return self.create_fn(self.source.select(local_builder, global_builder), self)
2023-06-16T13:54:07.7613787Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 540, in SHAPE_ENV
2023-06-16T13:54:07.7614122Z     guards = output_graph.shape_env.produce_guards(
2023-06-16T13:54:07.7614575Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 2596, in produce_guards
2023-06-16T13:54:07.7614967Z     raise ConstraintViolationError(f"Constraints violated!\n{err}")
2023-06-16T13:54:07.7615336Z torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated!
2023-06-16T13:54:07.7616030Z   1. Could not validate constraint RelaxedUnspecConstraint(L['input_ids'].size()[0]) as L['input_ids'].size()[0] is actually a non-atomic symbolic expression 4. Did you really mean to mark this dimension as dynamic?

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305

ezyang commented 1 year ago

The first step in debugging this is to get a TORCH_LOGS=dynamic log for it.
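A minimal sketch of one way to do that, assuming the repro command from the top of this issue is run from a PyTorch checkout. The helper name `repro_with_dynamic_logs` is hypothetical and only builds the command line and environment; the actual launch is left commented out because it needs CUDA and torchbench.

```python
import os
import shlex
import sys

# Repro command copied from the issue description.
REPRO = (
    "benchmarks/dynamo/torchbench.py --inductor --bfloat16 --accuracy --inference "
    "--device cuda --dynamic-shapes --dynamic-batch-only --only hf_T5_generate"
)


def repro_with_dynamic_logs(command=REPRO, base_env=None):
    """Return (argv, env) for running the repro with TORCH_LOGS=dynamic set.

    Hypothetical helper for illustration; it does not launch anything.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["TORCH_LOGS"] = "dynamic"  # enables the symbolic-shapes log output
    argv = [sys.executable, *shlex.split(command)]
    return argv, env


argv, env = repro_with_dynamic_logs()
# To actually run it (requires a CUDA machine with torchbench installed):
#   import subprocess
#   subprocess.run(argv, env=env, check=True)
```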

ezyang commented 1 year ago

Relevant logs:

[2023-06-23 07:08:02,265] torch.fx.experimental.symbolic_shapes: [INFO] 20.0: create_symbol s0 = 4 for L['input_ids'].size()[0]       
[2023-06-23 07:08:02,421] torch.fx.experimental.symbolic_shapes: [INFO] 20.0: eval Ne(s0, 512) [guard added] at transformers/src/transformers/models/t5/modeling_t5.py:259 in forward (_subclasses/fake_tensor.py:724 in fast_binary_impl)                                  
[2023-06-23 07:08:02,582] torch.fx.experimental.symbolic_shapes: [INFO] 20.0: eval Eq(s0, 4) [guard added] at transformers/src/transformers/models/t5/modeling_t5.py:559 in forward (_refs/__init__.py:368 in _broadcast_shapes)        
[2023-06-23 07:08:13,538] torch.fx.experimental.symbolic_shapes: [INFO] 20.0: produce_guards                       
ERROR:common:Constraints violated!                                                                                                    
  1. Could not validate constraint RelaxedUnspecConstraint(L['input_ids'].size()[0]) as L['input_ids'].size()[0] is actually a non-atomic symbolic expression 4. Did you really mean to mark this dimension as dynamic?   

Actually, this feels a bit familiar...
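For reading logs like the one above, here is a rough sketch that scans TORCH_LOGS=dynamic output for guards of the form `Eq(symbol, constant)`. In this log that is the `Eq(s0, 4)` guard, which pins `s0` to 4 and matches the "non-atomic symbolic expression 4" in the error. The regular expressions and the function name `find_specializations` are assumptions fitted to the three log lines quoted in this thread, not a general parser for PyTorch's log format.

```python
import re

# Log lines copied from the comment above, with timestamps and logger prefixes trimmed.
LOG = """\
create_symbol s0 = 4 for L['input_ids'].size()[0]
eval Ne(s0, 512) [guard added] at transformers/src/transformers/models/t5/modeling_t5.py:259 in forward (_subclasses/fake_tensor.py:724 in fast_binary_impl)
eval Eq(s0, 4) [guard added] at transformers/src/transformers/models/t5/modeling_t5.py:559 in forward (_refs/__init__.py:368 in _broadcast_shapes)
"""

SYMBOL_RE = re.compile(r"create_symbol (\w+) = (\d+) for (.+?)\s*$")
GUARD_RE = re.compile(r"eval (.+?) \[guard added\] at (\S+)")
EQ_CONST_RE = re.compile(r"^Eq\((\w+), (\d+)\)$")


def find_specializations(log_text):
    """Return (source, value, location) for each symbol pinned by an Eq(sym, const) guard."""
    sources = {}      # symbol name -> the input expression it was created for
    specialized = []
    for line in log_text.splitlines():
        m = SYMBOL_RE.search(line)
        if m:
            sources[m.group(1)] = m.group(3)
            continue
        m = GUARD_RE.search(line)
        if m:
            expr, location = m.groups()
            eq = EQ_CONST_RE.match(expr)
            if eq and eq.group(1) in sources:
                specialized.append((sources[eq.group(1)], int(eq.group(2)), location))
    return specialized


for source, value, location in find_specializations(LOG):
    print(f"{source} specialized to {value} at {location}")
```

On the lines quoted here, the only hit is the guard added at modeling_t5.py:559, which comes from `_broadcast_shapes`; the `Ne(s0, 512)` guard is ignored because it does not fix the symbol to a single value.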

ezyang commented 1 year ago

I think this is the same as https://github.com/pytorch/pytorch/issues/102814#issue-1737404026 (bfloat16 has probably perturbed the graph breaks, which is why it is failing now).

Also, see this special case:

        if args.only in {"hf_T5_generate"}:
            torch._dynamo.config.automatic_dynamic_shapes = True

ezyang commented 1 year ago

https://github.com/pytorch/pytorch/pull/106808 tries to turn this back on