AlessioSam / CHICO-PoseForecasting

Repository for "Pose Forecasting in Industrial Human-Robot Collaboration" (ECCV 2022)

Error while converting the model to TorchScript #2

Closed OmkarKabadagi5823 closed 1 year ago

OmkarKabadagi5823 commented 1 year ago

Issue

After exploring the repository a bit, I decided to convert the model to TorchScript so that I can load it in C++. There are two ways to convert a model to TorchScript:

  1. Tracing
  2. Scripting

The first issue is with tracing. torch.jit.trace takes the model and an example input, and records the operations executed on that input. Because the model's forward function also takes maskA and maskT as inputs, calling trace with only the input tensor fails with an error saying that maskA and maskT were not provided, and I have not found a way to pass the masks to torch.jit.trace. This is probably more relevant to the PyTorch community, but I am mentioning it here in case you have already solved it.

TypeError: forward() missing 2 required positional arguments: 'maskA' and 'maskT'

The second issue appears when converting the model by scripting (torch.jit.script), which takes only the model and compiles it directly to TorchScript. Running it produces the following error:

RuntimeError: 
Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals. Enumeration is supported, e.g. 'for index, v in enumerate(self): ...':
  File "/home/omkar/ws/cobot/CHICO-PoseForecasting/models/SeSGCN_student.py", line 339

        for i in range(1,self.n_txcnn_layers):
            x = self.prelus[i](self.txcnns[i](x)) +x # residual connection
                ~~~~~~~~~~~~~~ <--- HERE

        return x

Fix

I don't have a fix for the first issue yet, but I was able to fix the second one with the help of the error message. It says that indexing a ModuleList with a variable is not supported. This is probably because torch.jit.script analyzes the model statically and cannot verify that a variable index is valid. The alternative it suggests is enumeration. After applying the fix, the loop looks like this:

        for i, (prelu, txcnn) in enumerate(zip(self.prelus, self.txcnns)):
            if i == 0:
                x = prelu(txcnn(x))
            else:
                x = prelu(txcnn(x)) + x
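
To sanity-check this pattern outside the repository, here is a minimal sketch on a toy module. The class TinyResidualStack, its layer sizes, and the helper indexed_forward are hypothetical stand-ins, not the repository's SeSGCN_student code. It scripts the enumerate/zip loop and compares its output with the original variable-index loop run in eager mode.

```python
import torch
import torch.nn as nn


class TinyResidualStack(nn.Module):
    """Hypothetical stand-in for a stack of conv + PReLU layers (not the repo's model)."""

    def __init__(self, n_layers: int = 3, channels: int = 4):
        super().__init__()
        self.n_layers = n_layers
        self.convs = nn.ModuleList(
            [nn.Conv1d(channels, channels, kernel_size=3, padding=1) for _ in range(n_layers)]
        )
        self.prelus = nn.ModuleList([nn.PReLU() for _ in range(n_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Iterating with enumerate/zip avoids indexing a ModuleList with a variable.
        for i, (prelu, conv) in enumerate(zip(self.prelus, self.convs)):
            if i == 0:
                x = prelu(conv(x))
            else:
                x = prelu(conv(x)) + x  # residual connection
        return x


def indexed_forward(model: TinyResidualStack, x: torch.Tensor) -> torch.Tensor:
    """Eager-mode reference that uses the original variable-index loop."""
    x = model.prelus[0](model.convs[0](x))
    for i in range(1, model.n_layers):
        x = model.prelus[i](model.convs[i](x)) + x
    return x


torch.manual_seed(0)
model = TinyResidualStack().eval()
scripted = torch.jit.script(model)  # compiles because the loop no longer indexes by variable

x = torch.randn(2, 4, 10)
with torch.no_grad():
    outputs_match = torch.allclose(indexed_forward(model, x), scripted(x), atol=1e-6)
print("scripted output matches indexed loop:", outputs_match)
```

The comparison only shows that the two loop forms agree on this toy module; it says nothing about the real checkpoints.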

I have verified that the changes do not fundamentally alter the model and that it still works with the provided checkpoints. I am also opening a pull request to apply these changes to the code.
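
A possible direction for the tracing error, not yet tried on this repository's model: torch.jit.trace accepts a tuple as its example inputs, and the tuple elements are passed to forward as positional arguments. The sketch below uses a hypothetical toy module, MaskedToy, whose forward has the same (x, maskA, maskT) signature. It assumes the masks are plain tensors; if the real masks are lists or other containers, the call may need adjusting.

```python
import torch
import torch.nn as nn


class MaskedToy(nn.Module):
    """Hypothetical toy module whose forward takes two extra mask tensors."""

    def __init__(self, features: int = 4):
        super().__init__()
        self.linear = nn.Linear(features, features)

    def forward(self, x, maskA, maskT):
        return self.linear(x * maskA) * maskT


torch.manual_seed(0)
model = MaskedToy().eval()

# Example inputs; shapes are arbitrary and chosen only for illustration.
x = torch.randn(2, 4)
maskA = torch.ones(2, 4)
maskT = torch.ones(2, 4)

# All forward arguments are passed together as one tuple of example inputs.
traced = torch.jit.trace(model, (x, maskA, maskT))

with torch.no_grad():
    traced_matches_eager = torch.allclose(traced(x, maskA, maskT), model(x, maskA, maskT))
print("traced output matches eager output:", traced_matches_eager)
```

If this carries over to the real model, the traced module could then be saved with traced.save(...) and loaded from C++ with torch::jit::load.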