pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

inductor::_reinterpret_tensor() Expected a value of type 'List[int]' for argument 'size' but instead found type 'tuple' #124498

Open fegin opened 2 months ago

fegin commented 2 months ago

🐛 Describe the bug

When running the Dynamo hf_Whisper benchmark with CompiledDDP (CompiledAutograd + DDP Python reducer), Inductor generates code that cannot be run. The error is inductor::_reinterpret_tensor() Expected a value of type 'List[int]' for argument 'size' but instead found type 'tuple' -- the generated code contains floats, which are not accepted by inductor::_reinterpret_tensor(). The gradients of this model have dynamic shapes, which may be related to the error. We have been unable to reproduce this error with a unit test or a smaller model.

To reproduce this error, check out https://github.com/pytorch/pytorch/pull/121315 and run the following command:

    python benchmarks/dynamo/torchbench.py --performance --cold-start-latency --training --backend inductor --disable-cudagraphs --device cuda --ddp --multiprocess --optimize-ddp-mode="python_reducer" --only hf_Whisper --compiled-autograd

Error logs

  File "/data/users/chienchin/mywork/pytorch/torch/_dynamo/external_utils.py", line 36, in inner                                  16:08:50 [33/29551]    return fn(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/_dynamo/utils.py", line 2689, in wrapper
    return compiled_fn(flat_args)
  File "/data/users/chienchin/mywork/pytorch/torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/data/users/chienchin/mywork/pytorch/torch/_functorch/aot_autograd.py", line 974, in boxed_forward
    return compiled_fn(flat_args)
  File "/data/users/chienchin/mywork/pytorch/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 130, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/data/users/chienchin/mywork/pytorch/torch/_functorch/_aot_autograd/utils.py", line 116, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/data/users/chienchin/mywork/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 181, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/data/users/chienchin/mywork/pytorch/torch/_inductor/codecache.py", line 954, in __call__
    return self.current_callable(inputs)
  File "/data/users/chienchin/mywork/pytorch/torch/_inductor/compile_fx.py", line 853, in run
    return model(new_inputs)
  File "/tmp/tmppakyeush/im/cimukl26obc7rzhlq57a7n4zu5vd3viooh7sdulbbbrvtqlskdhl.py", line 3342, in call                          16:08:50 [12/29551]    return (reinterpret_tensor(buf139, (384, ), (1, ), 0), reinterpret_tensor(buf139, (384, ), (1, ), 384), reinterpret_tensor(buf139, (384, ), (1, ), 768), reinterpret_tensor(buf139, (384, ), (1, ), 1152), reinterpret_tensor(buf139, (384, ), (1, ), 1536), reinterpret_tensor(buf139, (384, ), (1, ), 1920), reinterpret_tensor(buf139, (384, ), (1, ), 2304), reinterpret_tensor(buf139, (384, ), (1, ), 2688), reinterpret_tensor(buf139, (384, ), (1,
), 3072), reinterpret_tensor(buf139, (384, ), (1, ), 3456), reinterpret_tensor(buf139, (384, 384, 3), (1152, 3, 1), 3840), reinterpret_tensor(buf350, (384, ), (1, ), 0), reinterpret_tensor(buf350, (384, 80, 3), (240, 3, 1), 384), reinterpret_tensor(buf350, (384, ), (1, ), 92544), reinterpret_tensor(buf350, (384, ), (1, ), 92928), reinterpret_tensor(buf350, (384, 1536), (1536, 1), 93312), reinterpret_tensor(buf350, (1536, ), (1, ), 683136), reinterpret_tensor(buf350, (1536, 384), (384, 1), 684672), reinterpret_tensor(buf350, (384, ), (1, ), 1274496), reinterpret_tensor(buf350, (384, ), (1,
), 1274880), reinterpret_tensor(buf350, (384, 384), (384, 1), 1275264), reinterpret_tensor(buf350, (384, ), (1, ), 1422720), reinterpret_tensor(buf350, (384.000000000000, 384), (384, 1), 1423104.00000000), reinterpret_tensor(buf350, (384.000000000000, 384), (384, 1), 1570560.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 1718016.00000000), reinterpret_tensor(buf350, (384.000000000000, 384), (384, 1), 1718400.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 1865856.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 1866240.00000000), reinterpret_tensor(buf350, (384, 1536), (1536, 1), 1866624.00000000), reinterpret_tensor(buf350, (1536, ), (1, ), 2456448.00000000), reinterpret_tensor(buf350, (1536, 384), (384, 1), 2457984.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 3047808.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 3048192.00000000), reinterpret_tensor(buf350, (384, 384), (384, 1), 3048576.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 3196032.00000000), reinterpret_tensor(buf350, (384.000000000000, 384), (384, 1), 3196416.00000000), reinterpret_tensor(buf350, (384.000000000000, 384), (384, 1), 3343872.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 3491328.00000000), reinterpret_tensor(buf350, (384.000000000000, 384), (384, 1), 3491712.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 3639168.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 3639552.00000000), reinterpret_tensor(buf350, (384, 1536), (1536, 1),
3639936.00000000), reinterpret_tensor(buf350, (1536, ), (1, ), 4229760.00000000), reinterpret_tensor(buf350, (1536, 384), (384, 1), 4231296.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 4821120.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 4821504.00000000), reinterpret_tensor(buf350, (384, 384), (384, 1), 4821888.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 4969344.00000000), reinterpret_tensor(buf350, (384.000000000000, 384), (384, 1), 4969728.00000000), reinterpret_tensor(buf350, (384.000000000000, 384), (384, 1), 5117184.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 5264640.00000000), reinterpret_tensor(buf350, (384.000000000000, 384), (384, 1), 5265024.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 5412480.00000000), reinterpret_tensor(buf350, (384, ), (1, ), 5412864.00000000), reinterpret_tensor(buf350, (384, 1536), (1536, 1), 5413248.00000000), reinterpret_tensor(buf350, (1536, ), (1, ), 6003072.00000000), reinterpret_tensor(buf350, (1536, 384), (384, 1), 6004608.00000000), reinterpret_tensor(buf412, (384, ), (1, ), 0), reinterpret_tensor(buf412, (384, ), (1, ), 384), reinterpret_tensor(buf412, (384, 384), (384, 1), 768), reinterpret_tensor(buf412, (384, ), (1, ), 148224), reinterpret_tensor(buf412, (384.000000000000, 384), (384, 1), 148608.000000000), reinterpret_tensor(buf412, (384.000000000000, 384), (384, 1), 296064.000000000), reinterpret_tensor(buf412, (384, ), (1, ), 443520.000000000), reinterpret_tensor(buf412,
(384.000000000000, 384), (384, 1), 443904.000000000), reinterpret_tensor(buf412, (384, ), (1, ), 591360.000000000), reinterpret_tensor(buf412, (2, ), (1, ), 591744.000000000), reinterpret_tensor(buf412, (2, 256), (256, 1), 591746.000000000), reinterpret_tensor(buf412, (256, ), (1, ), 592258.000000000), reinterpret_tensor(buf412, (256, 384), (384, 1), 592514.000000000), )
  File "/data/users/chienchin/mywork/pytorch/torch/_ops.py", line 870, in __call__
    return self_._op(*args, **(kwargs or {}))
RuntimeError: inductor::_reinterpret_tensor() Expected a value of type 'List[int]' for argument 'size' but instead found type 'tuple'.
Position: 1
Value: (384.0, 384)
Declaration: inductor::_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor
Cast error details: Unable to cast Python instance of type <class 'tuple'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)
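
For reference, the schema-level failure can be reproduced in isolation by calling the custom op directly with a float in the size tuple. This is a minimal sketch, assuming that importing torch._inductor is enough to register the inductor::_reinterpret_tensor op shown in the declaration above; it only reproduces the type error, not the underlying bug:

    import torch
    import torch._inductor  # assumption: importing this registers inductor::_reinterpret_tensor

    buf = torch.empty(384 * 384)

    # Integer sizes match the declared schema (int[] size, int[] stride) and work.
    ok = torch.ops.inductor._reinterpret_tensor(buf, (384, 384), (384, 1), 0)

    # A float size, as in the generated wrapper code above, raises:
    # RuntimeError: inductor::_reinterpret_tensor() Expected a value of type
    # 'List[int]' for argument 'size' but instead found type 'tuple'.
    torch.ops.inductor._reinterpret_tensor(buf, (384.0, 384), (384, 1), 0)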

Minified repro

No response

Versions

https://github.com/pytorch/pytorch/pull/121315

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang

yf225 commented 1 month ago

We could do this as a stop-gap fix, since sizevars should always be integers:

    def codegen_python_sizevar(self, x: Expr) -> str:
        x_s = V.graph.sizevars.simplify(x)
        # After simplification, an expression with no free symbols is a constant;
        # coerce it to sympy.Integer so pexpr emits an int literal instead of a
        # float like 384.000000000000 in the generated size/stride tuples.
        if not x_s.free_symbols:
            x_s = sympy.Integer(x_s)
        return pexpr(x_s)
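
To see why coercing to sympy.Integer changes the emitted text, here is a standalone sympy sketch; plain str() stands in for Inductor's pexpr printer, which is an approximation:

    import sympy

    x = sympy.Float(384.0)

    print(x)                 # 384.000000000000 -- a float literal like the ones in the bad size tuples
    print(x.free_symbols)    # set() -- so the stop-gap branch above would apply
    print(sympy.Integer(x))  # 384 -- prints as a plain int, valid for the int[] schema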

But a proper fix is to find out where the floats are generated.