pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Fix ONNX export of RAFT optical flow model #8116

Open farresti opened 12 months ago

farresti commented 12 months ago

🚀 The feature

Two proposed changes to fix the ONNX export of the RAFT model with a dynamic batch_size and a dynamic num_flow_updates.

Change needed for a dynamic batch_size: in `CorrBlock._compute_corr_volume`, change `corr / torch.sqrt(torch.tensor(num_channels))` to `corr / torch.sqrt(torch.tensor(num_channels).float())`.
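To make the first change concrete, here is a standalone sketch of a correlation-volume computation with the proposed cast. It assumes `CorrBlock._compute_corr_volume` builds RAFT's all-pairs correlation volume scaled by 1/sqrt(num_channels); the helper name `compute_corr_volume` is hypothetical and this is not the torchvision source. In eager mode `torch.sqrt` already promotes an integer tensor to floating point, so the extra `.float()` changes nothing there; the cast matters for export because ONNX's `Sqrt` operator is only defined for floating-point inputs.

```python
import torch


def compute_corr_volume(fmap1: torch.Tensor, fmap2: torch.Tensor) -> torch.Tensor:
    # Hypothetical standalone helper: all-pairs correlation between two feature maps
    batch_size, num_channels, h, w = fmap1.shape
    fmap1 = fmap1.view(batch_size, num_channels, h * w)
    fmap2 = fmap2.view(batch_size, num_channels, h * w)

    # (B, H*W, C) @ (B, C, H*W) -> (B, H*W, H*W): dot product of every pixel pair
    corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
    corr = corr.view(batch_size, h, w, 1, h, w)

    # torch.tensor(num_channels) is int64; cast it to float before the square root
    # (ONNX's Sqrt operator is only defined for floating-point inputs)
    return corr / torch.sqrt(torch.tensor(num_channels).float())


if __name__ == "__main__":
    print(torch.tensor(8).dtype, torch.tensor(8).float().dtype)  # torch.int64 torch.float32
    corr = compute_corr_volume(torch.randn(2, 8, 4, 5), torch.randn(2, 8, 4, 5))
    print(corr.shape)  # torch.Size([2, 4, 5, 1, 4, 5])
```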

Change needed for a dynamic num_flow_updates: in `RAFT.forward`, change:

```python
flow_predictions = []
for _ in range(num_flow_updates):
    coords1 = coords1.detach()  # Don't backpropagate gradients through this branch, see paper
    corr_features = self.corr_block.index_pyramid(centroids_coords=coords1)

    flow = coords1 - coords0
    hidden_state, delta_flow = self.update_block(hidden_state, context, corr_features, flow)

    coords1 = coords1 + delta_flow

    up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
    upsampled_flow = upsample_flow(flow=(coords1 - coords0), up_mask=up_mask)
    flow_predictions.append(upsampled_flow)
```

to

```python
flow_predictions = torch.zeros((num_flow_updates, batch_size, 2, h, w))
for i in range(num_flow_updates):
    coords1 = coords1.detach()  # Don't backpropagate gradients through this branch, see paper
    corr_features = self.corr_block.index_pyramid(centroids_coords=coords1)

    flow = coords1 - coords0
    hidden_state, delta_flow = self.update_block(hidden_state, context, corr_features, flow)

    coords1 = coords1 + delta_flow

    up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
    upsampled_flow = upsample_flow(flow=(coords1 - coords0), up_mask=up_mask)
    flow_predictions[i] = upsampled_flow
```
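A toy, standalone comparison of the two loop patterns (append then `torch.stack`, versus preallocate and assign by index) may help when reviewing the second change. The `toy_step` function is a hypothetical stand-in for RAFT's update block plus upsampling, and the sketch only checks that both patterns give identical values in eager mode; it does not by itself demonstrate how either pattern behaves under ONNX export.

```python
import torch


def toy_step(flow: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for one refinement iteration (update block + upsampling)
    return 0.5 * flow + 1.0


def predictions_append(flow0: torch.Tensor, num_flow_updates: int) -> torch.Tensor:
    # Original pattern: collect per-iteration predictions in a Python list
    preds = []
    flow = flow0
    for _ in range(num_flow_updates):
        flow = toy_step(flow)
        preds.append(flow)
    return torch.stack(preds)  # shape: (num_flow_updates, *flow0.shape)


def predictions_preallocated(flow0: torch.Tensor, num_flow_updates: int) -> torch.Tensor:
    # Proposed pattern: preallocate the output and write each iteration by index.
    # dtype/device are passed explicitly so the buffer matches the values written into it.
    preds = torch.zeros(
        (num_flow_updates, *flow0.shape), dtype=flow0.dtype, device=flow0.device
    )
    flow = flow0
    for i in range(num_flow_updates):
        flow = toy_step(flow)
        preds[i] = flow
    return preds


if __name__ == "__main__":
    flow0 = torch.randn(2, 2, 16, 24)  # (batch_size, 2, h, w)
    a = predictions_append(flow0, num_flow_updates=4)
    b = predictions_preallocated(flow0, num_flow_updates=4)
    print(a.shape)  # torch.Size([4, 2, 2, 16, 24])
    print(torch.equal(a, b))  # True
```

`torch.zeros` allocates on the default device (normally the CPU) and in the default dtype unless told otherwise, so the sketch passes `device=` and `dtype=`; the proposed line above omits them, which may be worth checking with GPU or half-precision inputs. Note also that `forward` would then return a single tensor instead of a list, although `flow_predictions[-1]` works on both.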

Thanks for all your work :)

Motivation, pitch

Being able to use the model in onnxruntime with dynamic inputs.

Alternatives

No response

Additional context

No response

NicolasHug commented 12 months ago

Thanks for the report/request, @farresti. By any chance, is there a more Pythonic way to still enable ONNX export here? We try to be reasonably conservative about changes that make the code less readable, as we already have to balance a lot of targets in our code (torchscript + mypy + onnx + torch.compile + [whatever comes next]), all of which require ugly workarounds that quickly add up and lead to a massive maintenance burden.

farresti commented 12 months ago

Could you specify which part you would like to be more Pythonic, please, @NicolasHug? Is it the num_flow_updates change, where the append is replaced by a preallocated tensor? Unfortunately, without this preallocation the ONNX tracer does not register the num_flow_updates parameter: it does not end up as an input of the exported graph and stays constant. I tried keeping the original implementation (with append) and stacking the list into a tensor afterwards, but the tracer still does not see the link with the parameter and still exports it as a constant.

I based my proposal on this Stack Overflow answer: https://stackoverflow.com/a/76134353

NicolasHug commented 12 months ago

> Could you specify which part you would like to be more Pythonic, please, @NicolasHug?

I was referring to

```python
flow_predictions = torch.zeros((num_flow_updates, batch_size, 2, h, w))
flow_predictions[i] = upsampled_flow
```

But if there's no decent alternative, then that's OK. Thanks for trying.

Would you like to submit a PR with those changes? Ideally with a short non-regression test, so we can be sure not to inadvertently "clean" that part later.