SkyTNT / midi-model

Midi event transformer for symbolic music generation
Apache License 2.0

Train.py - ValueError: Transformers now supports natively BetterTransformer optimizations #14

Closed: Camille-Molinier closed this issue 3 weeks ago

Camille-Molinier commented 9 months ago

Hey, I'm working on fine-tuning your model to build a Pokémon Diamond music generator. For now, I'm trying to understand your code and get it running. When I run train.py (data is my folder of MIDI files):

python .\train.py --data ".\data\"

I get this error:

Traceback (most recent call last):
  File "D:\Fac\ESIR3\IA\TP\midi-model\train.py", line 336, in <module>
    model = TrainMIDIModel(tokenizer, flash=True, lr=opt.lr, weight_decay=opt.weight_decay,
  File "D:\Fac\ESIR3\IA\TP\midi-model\train.py", line 101, in __init__
    super(TrainMIDIModel, self).__init__(tokenizer=tokenizer, n_layer=n_layer, n_head=n_head, n_embd=n_embd,
  File "D:\Fac\ESIR3\IA\TP\midi-model\midi_model.py", line 30, in __init__
    self.net = self.net.to_bettertransformer()
  File "D:\Fac\ESIR3\IA\TP\midi-model\venv\lib\site-packages\transformers\modeling_utils.py", line 4314, in to_bettertransformer
    return BetterTransformer.transform(self)
  File "C:\Users\diman\AppData\Local\Programs\Python\Python39\lib\contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "D:\Fac\ESIR3\IA\TP\midi-model\venv\lib\site-packages\optimum\bettertransformer\transformation.py", line 211, in transform
    raise ValueError(
ValueError: Transformers now supports natively BetterTransformer optimizations (torch.nn.functional.scaled_dot_product_attention) for the model type llama. Please upgrade to transformers>=4.36 and torch>=2.1.1 to use it. Details: https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention

I'm using torch==2.1.2+cu121 and transformers==4.36.2. Any ideas?

beezisback commented 9 months ago

I also have this problem now. On my old operating system I managed to train, but I don't know which versions of Python and torch I was using back then (September 2023). If anyone manages to train, please post your setup. In the past I used this command:

python train.py --data midi_data/train --lr 0.001 --batch-size-train 16 --max-step 1000 --val-step 25
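Because this problem depends on which package versions are installed, a small script that prints them makes it easier to post a working setup and pin it for later. This is a minimal sketch that uses only the standard library (`importlib.metadata`, `platform`). The package list (torch, transformers, optimum, all taken from the traceback) is an assumption, and you can extend it.

```python
import platform
from importlib import metadata

# Assumption: these are the packages relevant to this error (all appear in the traceback).
DEFAULT_PACKAGES = ("torch", "transformers", "optimum")


def collect_versions(packages=DEFAULT_PACKAGES):
    """Return {'python': <version>, <package>: <version or None if not installed>}."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # package is not installed in this environment
    return versions


def as_pins(versions):
    """Format installed packages as pip-style pins, skipping Python itself and missing packages."""
    return [f"{name}=={ver}" for name, ver in versions.items()
            if name != "python" and ver is not None]


if __name__ == "__main__":
    found = collect_versions()
    print("python", found["python"])
    print("\n".join(as_pins(found)))
```

The pinned lines can go into a requirements file, so the same environment can be rebuilt later.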

bruzo commented 8 months ago

I also ran into this. I'm totally new to this, so what I did may well be wrong, but at least I got it running. I still get quite a few warnings from PyTorch, but it runs:

In train.py (line 335), the call to TrainMIDIModel passes flash=True. I changed it to flash=False, and that fixed it for me.
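Setting flash=False works, but it turns the optimization off everywhere. An alternative is to skip the optimum conversion only when transformers already handles attention natively. The sketch below is hypothetical and not the repo's code. It assumes, based only on the error message, that optimum refuses to convert llama from transformers 4.36 onward because transformers then uses torch's `scaled_dot_product_attention` itself. The same message says torch>=2.1.1 is needed for that native path. The helper names are made up for illustration.

```python
import re


def parse_version(version: str) -> tuple:
    """Keep the leading numeric release parts, e.g. '2.1.2+cu121' -> (2, 1, 2)."""
    parts = []
    for piece in version.split("+")[0].split("."):  # drop local tags like '+cu121'
        match = re.match(r"\d+", piece)
        if not match:
            break  # e.g. 'dev0' in '4.37.0.dev0'
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # e.g. '0rc1': keep the numeric prefix, then stop
    return tuple(parts)


def should_convert_to_bettertransformer(transformers_version: str) -> bool:
    """Hypothetical guard: only call optimum's conversion on transformers < 4.36.

    Assumption from the error message: on >= 4.36, transformers uses torch SDPA
    natively for llama and optimum raises ValueError instead of converting.
    """
    return parse_version(transformers_version) < (4, 36)


# Hypothetical use inside midi_model.py's __init__ (not the repo's actual code):
#   import transformers
#   if flash and should_convert_to_bettertransformer(transformers.__version__):
#       self.net = self.net.to_bettertransformer()
```

With transformers==4.36.2 (the version reported above), the guard returns False, so the conversion that raised the ValueError would be skipped.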

I also had to reconfigure the torch options quite a bit to make fine-tuning possible on my 8 GB GPU. It's painful, but at least it's possible (about half an hour per epoch), though a full validation run still fails because there isn't enough GPU memory.
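The thread doesn't say which options were changed. One common way to fit training into less GPU memory (an assumption, not necessarily what was done here) is to lower the per-step batch and accumulate gradients, so the effective batch stays the same as the earlier --batch-size-train 16. The arithmetic is effective_batch = micro_batch × accumulation_steps. The helper below is hypothetical. Whether train.py exposes accumulation isn't shown in the thread. If it uses PyTorch Lightning (also an assumption), the second value would go to `Trainer(accumulate_grad_batches=...)`.

```python
def plan_accumulation(target_batch: int, max_micro_batch: int) -> tuple:
    """Hypothetical helper: pick (micro_batch, accumulation_steps) whose product is target_batch.

    max_micro_batch is the largest per-step batch that fits in GPU memory
    (found by trial); the largest divisor of target_batch not above it is used.
    """
    if target_batch < 1 or max_micro_batch < 1:
        raise ValueError("batch sizes must be positive")
    for micro in range(min(target_batch, max_micro_batch), 0, -1):
        if target_batch % micro == 0:
            return micro, target_batch // micro
    raise AssertionError("unreachable: 1 always divides target_batch")


if __name__ == "__main__":
    # e.g. effective batch 16, but only 6 samples fit per step on an 8 GB card
    micro, steps = plan_accumulation(16, 6)
    print(f"micro batch {micro}, accumulate {steps} steps -> effective {micro * steps}")
```

Requiring an exact divisor keeps every optimizer step at the same effective batch size. The trade-off is more forward/backward passes per step, which is part of why an epoch takes longer on a small GPU.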