graphnet-team / graphnet

A Deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0

enforce gpu arg in .predict_as_dataframe #636

Closed. RasmusOrsoe closed this pull request 7 months ago.

RasmusOrsoe commented 7 months ago

Currently, the gpus argument from the CLI in our training examples is not passed to model.predict_as_dataframe, which effectively forces prediction to run on CPU.

This has confused some users who have tried to extend the training examples into training scripts for their own use cases.

This PR passes the argument on to model.predict_as_dataframe.
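The pattern behind the fix, parsing a device option from the CLI and forwarding it to the prediction call instead of dropping it, can be sketched as below. This is a minimal, self-contained illustration, not graphnet code: HypotheticalModel, its predict method (standing in for predict_as_dataframe), build_parser and run are all hypothetical names, and the assumption that a missing gpus value means "use CPU" mirrors the behaviour described above rather than graphnet's actual implementation.

```python
import argparse
from typing import List, Optional


class HypotheticalModel:
    """Stand-in for a model exposing a predict_as_dataframe-like method.

    NOT graphnet's Model class: it only reports which device would be used,
    so the flow of the gpus argument can be checked.
    """

    def predict(self, data: List[float], gpus: Optional[List[int]] = None) -> dict:
        # Assumption: with no gpus given, inference falls back to CPU.
        device = f"cuda:{gpus[0]}" if gpus else "cpu"
        return {"device": device, "predictions": [x * 2 for x in data]}


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Toy training/inference script")
    # Hypothetical flag: one or more GPU indices; omitted means None.
    parser.add_argument("--gpus", type=int, nargs="+", default=None)
    return parser


def run(argv: List[str]) -> dict:
    args = build_parser().parse_args(argv)
    model = HypotheticalModel()
    # The fix: forward args.gpus. Omitting it here would silently run on CPU
    # even when the user asked for a GPU on the command line.
    return model.predict([1.0, 2.0], gpus=args.gpus)
```

Calling run(["--gpus", "0"]) selects "cuda:0", while run([]) falls back to "cpu". Keeping the default as None rather than an empty list lets the prediction method distinguish "not specified" from an explicit choice.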