BitCalSaul closed this issue 10 months ago.
This problem is caused by `windows.shape[0]`. The size of the tensor `windows` is only determined during inference, so from the graph's point of view it is not static. More precisely, during `torch.fx` symbolic tracing a tensor is not a concrete value. It is a Proxy standing for data that will "flow" through the graph at run time (yes, like "tensorflow"), so `windows.shape[0]` is itself just another Proxy rather than a real integer. This is a limitation of `torch.fx`. The solution is to replace every `tensor.shape` used in `forward()` with a pre-defined deterministic value, although this requires some manual effort.
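Here is a minimal sketch of this limitation, assuming a recent PyTorch with `torch.fx`. The module and helper names (`DynamicBatch`, `FixedBatch`, `try_trace`) are hypothetical and do not come from torch_flops or Swin. Converting a traced shape to `int` makes `fx.symbolic_trace` fail. A module whose size is a constant chosen ahead of time traces normally.

```python
import torch
import torch.fx as fx


class DynamicBatch(torch.nn.Module):
    def forward(self, x):
        # During tracing x is a Proxy, so x.shape[0] is also a Proxy;
        # int() cannot turn it into a real Python number.
        b = int(x.shape[0])
        return x.view(b, -1)


class FixedBatch(torch.nn.Module):
    # Hypothetical fix: the batch size is a constant fixed before tracing.
    def __init__(self, batch_size: int):
        super().__init__()
        self.batch_size = batch_size

    def forward(self, x):
        # A plain int is baked into the graph as a constant argument.
        return x.view(self.batch_size, -1)


def try_trace(module: torch.nn.Module) -> bool:
    """Return True if torch.fx can symbolically trace the module."""
    try:
        fx.symbolic_trace(module)
        return True
    except Exception as err:
        print(f"{type(module).__name__}: {type(err).__name__}: {err}")
        return False


dynamic_ok = try_trace(DynamicBatch())
fixed_ok = try_trace(FixedBatch(batch_size=4))
print(dynamic_ok, fixed_ok)
```

The fixed version only works for the batch size it was built with. That is the manual effort mentioned above.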
In my experiments, tracing succeeds when `tensor.shape` is passed directly as an argument to `reshape()`. In other cases (e.g. `int(x.shape[0])`), `tensor.shape` cannot be used as an argument to the function, because the Proxy cannot be converted into a concrete Python number at trace time.
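The `reshape()` case can be sketched as follows. This assumes `torch.fx`, and `FlattenByShape` is a hypothetical name. Passing `x.shape[0]` straight to `reshape()` only records it as an argument. The shape lookup becomes ordinary graph nodes that are evaluated whenever the traced module runs.

```python
import torch
import torch.fx as fx


class FlattenByShape(torch.nn.Module):
    def forward(self, x):
        # x.shape[0] is a Proxy here; reshape() simply records it as an
        # argument, so no concrete integer is needed while tracing.
        return x.reshape(x.shape[0], -1)


gm = fx.symbolic_trace(FlattenByShape())

# The shape lookup appears as regular nodes (getattr + getitem) in the graph.
for node in gm.graph.nodes:
    print(node.op, node.target)

# Because the shape is read at run time, the traced module still works
# for different batch sizes.
print(gm(torch.randn(2, 3, 4)).shape)  # torch.Size([2, 12])
print(gm(torch.randn(5, 3, 4)).shape)  # torch.Size([5, 12])
```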
Thanks, I fixed `B` to a constant number instead of deriving it from `shape[0]`, and now the profiler works well. Next time I will try to turn the model into a static-flow model.
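The following sketch shows this kind of change applied to Swin's `window_reverse()`. In the linked Swin code, `B` is computed as `int(windows.shape[0] / (H * W / window_size / window_size))`, which needs a concrete value from a Proxy. The hypothetical `window_reverse_fixed` and `WindowReverse` below take `B` as a plain integer instead. The window size, resolution, channel count and batch size are illustrative values that do not come from the thread.

```python
import torch
import torch.fx as fx


def window_reverse_fixed(windows, window_size, H, W, B):
    # Hypothetical variant of Swin's window_reverse(): B is a plain int
    # supplied by the caller instead of int(windows.shape[0] / ...).
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowReverse(torch.nn.Module):
    def __init__(self, window_size: int, H: int, W: int, batch_size: int):
        super().__init__()
        self.window_size, self.H, self.W = window_size, H, W
        self.batch_size = batch_size  # fixed ahead of time, not read from a tensor

    def forward(self, windows):
        return window_reverse_fixed(windows, self.window_size, self.H, self.W, self.batch_size)


gm = fx.symbolic_trace(WindowReverse(window_size=7, H=14, W=14, batch_size=2))

# 2 images of 14x14 split into 7x7 windows -> 4 windows per image, 96 channels.
windows = torch.randn(2 * 4, 7, 7, 96)
out = gm(windows)
print(out.shape)  # torch.Size([2, 14, 14, 96])
```

Tracing now succeeds because nothing in `forward()` asks for a concrete value from a Proxy. The trade-off is that the traced graph is tied to the chosen batch size.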
Hi, I tried to measure the Swin Transformer's performance using `torch_flops`, but I got an error like this:
I looked into it and found that the variable `windows` is a Proxy, as shown in the picture below.
The function comes from the official Swin repo: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L45.
Here is a minimal example that reproduces this error:
Thanks! If you need more information, please let me know.