NVIDIA / modulus-launch

Repo of optimized training recipes for accelerating PyTorch workflows of AI-driven surrogates for physical systems
Apache License 2.0

Fix batch size in inference #35

Closed BriacMB closed 1 year ago

BriacMB commented 1 year ago

The dataloader used for inference takes its batch size from the configuration. When a new graph is loaded at line 82, the whole batch is merged into a single graph, so any config with batch_size != 1 breaks inference. For example, with batch_size=3 the loaded graph is Graph(num_nodes=5769, num_edges=33210, ...), whereas a single sample should have num_nodes=1923 (5769 = 3 * 1923). This leads to an error at line 119, where mask and pred_i[:, 0:2] have incompatible shapes.
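
To illustrate the mechanism, here is a minimal sketch (not code from this repo) of how DGL merges a batch into one graph, assuming the dataset yields DGL graphs; the toy graphs g1 and g2 are hypothetical stand-ins:

```python
import dgl
import torch

# Two toy graphs with 3 nodes each. dgl.batch merges them into a single
# graph whose node and edge sets are concatenated.
g1 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
g2 = dgl.graph((torch.tensor([0, 2]), torch.tensor([1, 0])))

batched = dgl.batch([g1, g2])
print(batched.num_nodes())  # 6, not 3 -- analogous to 5769 = 3 * 1923
```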

This fix sets the batch size to one, ensuring that inference runs with only one graph loaded at a time so the shapes stay compatible.
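
A minimal sketch of the idea behind the fix, assuming the inference script builds its loader with dgl.dataloading.GraphDataLoader; the dataset below is a placeholder, not the repo's actual dataset:

```python
import dgl
import torch
from dgl.dataloading import GraphDataLoader

# Placeholder dataset: a list of small DGL graphs standing in for the
# real inference dataset (an assumption for illustration only).
dataset = [
    dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
    for _ in range(3)
]

# The fix: force batch_size=1 instead of the configured value, so each
# iteration yields exactly one graph and node-level mask/prediction
# shapes line up.
dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=False)

for graph in dataloader:
    print(graph.num_nodes())  # 3 on every iteration: one graph at a time
```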