Unbabel / COMET

A Neural Framework for MT Evaluation
https://unbabel.github.io/COMET/html/index.html
Apache License 2.0
493 stars, 76 forks

Specifying GPU ID for inference #120

Closed: jinulee-v closed this issue 1 year ago

jinulee-v commented 1 year ago

🚀 Feature

Add a parameter to specify GPU IDs when calling model.predict().

Motivation

Currently, the COMET API only takes the number of GPUs to use. I want to run COMET in a multi-GPU environment where each GPU is working on its own task, so it would be nice to be able to specify which GPU IDs the model may use. The predict() API would probably look like:

model.predict(data, batch_size=BATCH_SIZE, gpus=2, gpu_ids=[0, 2])

Alternatives

Additional context

Since prediction is run through ptl.Trainer, I presume the change would not be hard to make using the ptl API.

ricardorei commented 1 year ago

Yep, the modification should not be hard.

For now I suggest you launch your script with CUDA_VISIBLE_DEVICES=0,2 (a comma-separated list of GPU IDs, without brackets).

I'll leave this open for future reference.
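
The CUDA_VISIBLE_DEVICES workaround can also be applied from Python when the scoring script is launched as a subprocess. The following is a minimal sketch, not part of COMET: the helper name `run_with_visible_gpus` and the script name in the usage note are hypothetical. It relies only on the fact that CUDA reads CUDA_VISIBLE_DEVICES as a comma-separated list of GPU IDs, and that the variable must be set before the child process initializes CUDA.

```python
import os
import subprocess
import sys
from typing import List, Sequence


def run_with_visible_gpus(
    script_args: Sequence[str], gpu_ids: List[int]
) -> subprocess.CompletedProcess:
    """Run a Python script in a child process that only sees `gpu_ids`.

    Hypothetical helper for illustration; it is not part of the COMET API.
    """
    env = os.environ.copy()
    # CUDA expects a comma-separated list such as "0,2" (no brackets).
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return subprocess.run(
        [sys.executable, *script_args],
        env=env,
        check=True,           # raise if the child exits with an error
        capture_output=True,  # collect stdout/stderr for the caller
        text=True,
    )
```

For example, `run_with_visible_gpus(["score.py"], [0, 2])` would run a (hypothetical) `score.py` with only GPUs 0 and 2 visible. Inside the child process those two GPUs are renumbered as devices 0 and 1.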

jinulee-v commented 1 year ago

This feature is a must in some use cases, such as multi-GPU training with online feedback from COMET scores. Handling device IDs in such cases can be quite annoying... I am looking forward to an update! Thanks for the fast and helpful reply :)

ricardorei commented 1 year ago

@jinulee-v could you test the current master? I added the devices argument to predict. I tested it in the following way:

# Backwards compatibility
model_output = model.predict(data, batch_size=8, gpus=1)

# Specify which device for single GPU
model_output = model.predict(data, batch_size=8, gpus=1, devices=[7])

# Specify devices for DDP
model_output = model.predict(data, batch_size=8, gpus=2, devices=[5, 7])

ricardorei commented 1 year ago

Feel free to suggest edits.

The gpus argument could actually be replaced with devices alone, but I don't want to introduce breaking changes in the current predict interface.

PS: if this works you can close the issue and I'll release it in the next pip release. For now I'll wait a few weeks to see if more people report bugs/suggest features.
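
The way `gpus` and `devices` interact in the three calls above can be sketched as a small argument-resolution helper. This is an assumption about one reasonable design, not COMET's actual implementation: the name `resolve_devices` and the validation rules are hypothetical. The returned value has the shape that PyTorch Lightning's `Trainer(accelerator="gpu", devices=...)` accepts, either a device count or an explicit list of GPU IDs.

```python
from typing import List, Optional, Union


def resolve_devices(
    gpus: int = 1, devices: Optional[List[int]] = None
) -> Union[int, List[int]]:
    """Combine the legacy `gpus` count with an optional `devices` list.

    Hypothetical helper for illustration; not taken from the COMET source.
    """
    if devices is None:
        # Backwards-compatible path: only a GPU count was given.
        return gpus
    if len(set(devices)) != len(devices):
        raise ValueError(f"Duplicate GPU ids in devices={devices}")
    if len(devices) != gpus:
        raise ValueError(
            f"gpus={gpus} does not match the {len(devices)} ids in devices={devices}"
        )
    # Explicit ids take precedence once they are consistent with the count.
    return devices
```

Under these assumptions, `resolve_devices(1)` returns `1`, `resolve_devices(1, [7])` returns `[7]`, and `resolve_devices(2, [5, 7])` returns `[5, 7]`. A mismatch such as `resolve_devices(2, [7])` raises a `ValueError` rather than silently choosing one of the two arguments.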