MusicLang / musiclang_predict

AI prediction API of the MusicLang package
GNU General Public License v3.0

[Bug] Use GPU in the predict methods #7

Closed: floriangardin closed this issue 4 months ago

floriangardin commented 5 months ago

Prediction currently fails on the GPU because the input tensors are never moved to the device the model is on, so they stay on the CPU.

This snippet:

```python
# Predict a song

from musiclang_predict import predict, MusicLangTokenizer
from transformers import GPT2LMHeadModel

# Load the model and move its weights to the GPU
model = GPT2LMHeadModel.from_pretrained('musiclang/musiclang-4k').to('cuda')
tokenizer = MusicLangTokenizer('musiclang/musiclang-4k')

# Predict an 8-bar song with this model; this call raises the error below
soundtrack = predict(model, tokenizer, chord_duration=4, nb_chords=8)
soundtrack.to_midi('song.mid', tempo=120, time_signature=(4, 4))
```

raises:

```
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
```
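A minimal sketch of the kind of fix this issue calls for: before the predict methods call the model, every input tensor has to be moved to the device the model lives on. The helper `move_to_device` is hypothetical (it is not part of musiclang_predict, and the package's actual fix may look different). It relies only on duck typing, so anything with a `.to(device)` method, such as a `torch.Tensor` or a Hugging Face `BatchEncoding`, gets moved, and plain dicts, lists and tuples are traversed.

```python
from collections.abc import Mapping
from typing import Any


def move_to_device(obj: Any, device: Any) -> Any:
    """Recursively move tensor-like values inside `obj` to `device`.

    Hypothetical helper, not part of musiclang_predict.
    - Objects exposing a callable `.to` (e.g. torch.Tensor) are moved with `.to(device)`.
    - Mappings are rebuilt as plain dicts with their values moved.
    - Plain lists and tuples are rebuilt with the same type.
    - Everything else (ints, strings, None, ...) is returned unchanged.
    """
    to = getattr(obj, "to", None)
    if callable(to):
        return to(device)
    if isinstance(obj, Mapping):
        return {key: move_to_device(value, device) for key, value in obj.items()}
    # Only exact list/tuple types; namedtuples would need positional construction
    if type(obj) in (list, tuple):
        return type(obj)(move_to_device(value, device) for value in obj)
    return obj
```

Inside a predict function this might be called as `input_ids = move_to_device(input_ids, model.device)`, where `model.device` is the device property that Hugging Face `PreTrainedModel` subclasses such as `GPT2LMHeadModel` expose. Reading the device from the model instead of hard-coding `'cuda'` keeps CPU-only runs working unchanged.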