SkyTNT / midi-model

Midi event transformer for symbolic music generation
Apache License 2.0

Error while training #13

Closed: Gymothy closed this issue 10 months ago

Gymothy commented 10 months ago

Hi! Thank you for making this. I've been looking for something exactly like it for a while, and I think it will suit me perfectly. I had no problems installing and running inference with the pretrained model, but when I try to train, I get a strange error:

 seed=0, lr=2e-05, weight_decay=0.01, warmup_step=1000.0, max_step=60000, grad_clip=1.0, batch_size_train=2, batch_size_val=2, workers_train=8, workers_val=8, acc_grad=2, accelerator='gpu', devices=-1, fp32=False, disable_benchmark=False, log_step=1, val_step=3200)
Seed set to 0
---load dataset---
train: 41681  val: 2560
Traceback (most recent call last):
  File "H:\AI\MIDI\midi-model\train.py", line 335, in <module>
    model = TrainMIDIModel(tokenizer, flash=True, lr=opt.lr, weight_decay=opt.weight_decay,
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "H:\AI\MIDI\midi-model\train.py", line 101, in __init__
    super(TrainMIDIModel, self).__init__(tokenizer=tokenizer, n_layer=n_layer, n_head=n_head, n_embd=n_embd,
  File "H:\AI\MIDI\midi-model\midi_model.py", line 26, in __init__
    self.net = self.net.to_bettertransformer()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\ProgramData\Miniconda3\envs\midi\Lib\site-packages\transformers\modeling_utils.py", line 4314, in to_bettertransformer
    return BetterTransformer.transform(self)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\ProgramData\Miniconda3\envs\midi\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\ProgramData\Miniconda3\envs\midi\Lib\site-packages\optimum\bettertransformer\transformation.py", line 211, in transform
    raise ValueError(
ValueError: Transformers now supports natively BetterTransformer optimizations (torch.nn.functional.scaled_dot_product_attention) for the model type llama. Please upgrade to transformers>=4.36 and torch>=2.1.1 to use it. Details: https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention

I'm certain that I'm running these versions of both transformers and torch, and I'm unsure why the model type is shown as llama. I couldn't find anyone else encountering this and would be grateful for any assistance. Thanks! I'm using a conda environment, and my pip list is below.

aiofiles                  23.2.1
aiohttp                   3.9.1
aiosignal                 1.3.1
altair                    5.2.0
annotated-types           0.6.0
anyio                     4.2.0
attrs                     23.2.0
Brotli                    1.0.9
certifi                   2023.11.17
cffi                      1.16.0
charset-normalizer        2.0.4
click                     8.1.7
colorama                  0.4.6
coloredlogs               15.0.1
contourpy                 1.2.0
cryptography              41.0.7
cycler                    0.12.1
datasets                  2.14.4
dill                      0.3.7
fastapi                   0.108.0
ffmpy                     0.3.1
filelock                  3.13.1
fonttools                 4.47.0
frozenlist                1.4.1
fsspec                    2023.12.2
gmpy2                     2.1.2
gradio                    3.41.2
gradio_client             0.5.0
h11                       0.14.0
httpcore                  1.0.2
httpx                     0.26.0
huggingface-hub           0.20.1
humanfriendly             10.0
idna                      3.4
importlib-resources       6.1.1
Jinja2                    3.1.2
jsonschema                4.20.0
jsonschema-specifications 2023.12.1
kiwisolver                1.4.5
lightning-utilities       0.10.0
MarkupSafe                2.1.1
matplotlib                3.8.2
mkl-fft                   1.3.8
mkl-random                1.2.4
mkl-service               2.4.0
mpmath                    1.3.0
multidict                 6.0.4
multiprocess              0.70.15
networkx                  3.1
numpy                     1.26.2
optimum                   1.16.1
orjson                    3.9.10
packaging                 23.2
pandas                    2.1.4
Pillow                    10.0.1
pip                       23.3.1
protobuf                  4.25.1
pyarrow                   14.0.2
pycparser                 2.21
pydantic                  2.5.3
pydantic_core             2.14.6
pydub                     0.25.1
pyFluidSynth              1.3.2
pyOpenSSL                 23.2.0
pyparsing                 3.1.1
pyreadline3               3.4.1
PySocks                   1.7.1
python-dateutil           2.8.2
python-multipart          0.0.6
pytorch-lightning         2.1.3
pytz                      2023.3.post1
PyYAML                    6.0.1
referencing               0.32.0
regex                     2023.12.25
requests                  2.31.0
rpds-py                   0.16.2
safetensors               0.4.1
semantic-version          2.10.0
sentencepiece             0.1.99
setuptools                68.2.2
six                       1.16.0
sniffio                   1.3.0
starlette                 0.32.0.post1
sympy                     1.12
tokenizers                0.15.0
toolz                     0.12.0
torch                     2.1.2+cu118
torchaudio                2.1.2+cu118
torchmetrics              1.2.1
torchvision               0.16.2+cu118
tqdm                      4.66.1
transformers              4.36.2
typing_extensions         4.9.0
tzdata                    2023.4
urllib3                   1.26.18
uvicorn                   0.25.0
websockets                11.0.3
wheel                     0.41.2
win-inet-pton             1.1.0
xxhash                    3.4.1
yarl                      1.9.4
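The ValueError reads like a "please upgrade" request, but comparing the minimums it quotes (transformers>=4.36, torch>=2.1.1) against the pip list shows both are already met. A minimal sketch of that comparison, assuming plain "X.Y[.Z][+local]" version strings; `release_tuple` and `check_minimums` are hypothetical helpers, and real code would more robustly use `packaging.version.Version`:

```python
# Simplified check of the installed versions (from the pip list) against the
# minimums quoted in the ValueError. Assumption: versions look like
# "X.Y[.Z][+local]"; pre-releases and other PEP 440 forms are not handled.

INSTALLED = {"transformers": "4.36.2", "torch": "2.1.2+cu118"}  # from the pip list
REQUIRED = {"transformers": "4.36", "torch": "2.1.1"}  # from the error message


def release_tuple(version: str, width: int = 3) -> tuple[int, ...]:
    """Hypothetical helper: '2.1.2+cu118' -> (2, 1, 2), '4.36' -> (4, 36, 0)."""
    public = version.split("+", 1)[0]  # drop a local label such as +cu118
    parts = [int(piece) for piece in public.split(".")]
    parts += [0] * (width - len(parts))  # pad short versions with zeros
    return tuple(parts[:width])


def check_minimums(installed: dict[str, str], required: dict[str, str]) -> dict[str, bool]:
    """Map each package name to True if its installed version meets the minimum."""
    return {
        name: release_tuple(installed[name]) >= release_tuple(minimum)
        for name, minimum in required.items()
    }


if __name__ == "__main__":
    print(check_minimums(INSTALLED, REQUIRED))  # {'transformers': True, 'torch': True}
```

Since both checks pass, the error is not about outdated packages: optimum is declining to convert a model type that, per its own message, transformers already optimizes natively with scaled_dot_product_attention.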
Gymothy commented 10 months ago

It looks like this specific error was fixed by downgrading optimum to 1.12.0, so the requirement may need to be adjusted.
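Besides pinning optimum, another option (an assumption, not something confirmed in this thread) is to skip the BetterTransformer conversion when optimum refuses it, since the error message says transformers already applies the equivalent optimization for this model type. A minimal sketch with a hypothetical helper `maybe_to_bettertransformer` that could wrap the `self.net = self.net.to_bettertransformer()` call seen at midi_model.py line 26 in the traceback:

```python
import warnings


def maybe_to_bettertransformer(net):
    """Hypothetical helper (not part of midi-model): try the BetterTransformer
    conversion, but keep the original model if optimum raises the
    'supports natively' ValueError shown in the traceback."""
    try:
        return net.to_bettertransformer()
    except ValueError as err:
        # Only swallow this specific refusal; re-raise any other ValueError.
        if "supports natively" not in str(err):
            raise
        warnings.warn(f"Skipping BetterTransformer conversion: {err}")
        return net
```

Usage would be `self.net = maybe_to_bettertransformer(self.net)`. Matching on the message text is brittle, so pinning optimum to 1.12.0 in the requirements remains the simpler fix.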