tomasJwYU opened 1 year ago
I'm also eager to see the code. It would be game-changing.
+1, waiting for the code. If it's released, please reply and let me know QAQ
@tomasJwYU @haveyouwantto @shuchenweng Thanks for your interest in this project. FYI, you can try the pre-release version used in the demo! Assuming you have an environment with Python>=3.9 and PyTorch>=2.2 installed:
pip install awscli
mkdir amt
aws s3 cp s3://amt-deploy-public/amt/ amt --no-sign-request --recursive
cd amt/src
pip install -r requirements.txt
apt-get install sox # only required for GuitarSet preprocessing...
Dataset download:
python install_dataset.py
Please refer to the README.md (a bit outdated) or the Colab demo code for train.py and test.py command usage. Model checkpoints are available in amt/logs.
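As a hedged illustration of how those checkpoints might be inspected, here is a minimal sketch. It assumes only that the files under amt/logs are ordinary PyTorch checkpoints; the directory layout below is my guess, and the repo's own test.py is the supported path.

```python
from pathlib import Path
import torch

# Hypothetical layout: locate any .ckpt file under amt/logs
# (the actual experiment directory names will differ).
ckpt_path = next(Path("amt/logs").rglob("*.ckpt"))

# PyTorch checkpoints are plain dictionaries, so the contents can be
# inspected without instantiating the model first.
ckpt = torch.load(ckpt_path, map_location="cpu")
print(ckpt_path)
print(list(ckpt))  # top-level keys, e.g. 'state_dict', hyperparameters, ...
```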
Your code looks quite complex, but it was written in a way that was easier to understand than I expected. As a university student, I was able to train and test this model in a short period of time.
It's a truly impressive paper with a model that delivers outstanding performance!
Hi! Did you manage to train the MoE model on all datasets? Might I ask how long it took you and on what hardware?
FYI, the final model was trained with these options:
python train.py mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2 -p slakh2024 -d all_cross_final -it 320000 -vit 20000 -epe rope -rp 1 -enc perceiver-tf -sqr 1 -ff moe -wf 4 -nmoe 8 -kmoe 2 -act silu -ac spec -hop 300 -bsz 10 10 -xk 5 -tk mc13_full_plus_256 -dec multi-t5 -nl 26 -edr 0.05 -ddr 0.05 -atc 1 -sb 1 -ps -2 2 -st ddp -wb online
- The `-bsz` numbers set the per-GPU batching: the first is for the CPU data-loader workers, and the second is the local batch size. This is suitable for GPUs with 24-40GB of memory, such as the RTX4090 or A100 (40GB).
- With `-bsz 10 10` on 8 GPUs, the global batch size is 80 (see the arithmetic sketch after this list).
- For 80GB GPUs like the H100 or A100 (80GB), use `-bsz 11 22`. This creates 2 data-loaders (bsz=11 each) per GPU.
- `-it 320000` and `-vit 20000` mean 320K max iterations with validation every 20K iterations (16 validations in total). Each validation takes 0.5~1 hour, so avoid frequent validation: auto-regressive inference and the evaluation metrics are time-consuming.
- For quicker training, try `-it 100000 -vit 10000`. It takes about 1.5 days on a single H100 80GB.
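As a quick sanity check of that arithmetic, here is a minimal sketch; the variable names are mine, not the repo's:

```python
# Batch-size and validation arithmetic for the options above.
num_gpus = 8

# -bsz 10 10: the second number is the local batch size per GPU.
local_bsz = 10
global_bsz = local_bsz * num_gpus
assert global_bsz == 80

# -bsz 11 22 on an 80GB GPU: two data-loaders of 11 each -> local batch 22.
sub_bsz, local_bsz_80gb = 11, 22
assert local_bsz_80gb // sub_bsz == 2   # 2 data-loaders per GPU

# -it 320000 -vit 20000: validate every 20K of 320K iterations.
max_iters, val_interval = 320_000, 20_000
assert max_iters // val_interval == 16  # 16 validation runs
```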
Hi, thank you for sharing your great work. I'm currently trying to train it on an RTX4090 (24GB), but it fails with a GPU out-of-memory (OOM) error. Does training this model require 30GB or more of GPU memory?
Here is my training script. If you have any tips for reducing memory usage, please let me know.
python train.py mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2 -p partial_ymt3 -d maestro_final -it 320000 -vit 20000 -epe rope -rp 1 -enc perceiver-tf -sqr 1 -ff moe -wf 4 -nmoe 8 -kmoe 2 -act silu -ac spec -hop 128 -bsz 1 1 -xk 5 -tk mt3_midi -dec multi-t5 -nl 26 -edr 0.05 -ddr 0.05 -atc 1 -sb 1 -ps -2 2 -st ddp -wb online
@noirmist Hi, your command uses `-dec multi-t5`, which should be paired with the multi-channel task `-tk mc13_full_plus_256`. This corresponds to 13-channel decoding with the FULL_PLUS vocabulary and a max sequence length of 256. For the MIDI_PLUS vocab within the multi-channel setup, use `-tk mc13_256`. Either vocab (with or without `PLUS`) is fine since you're training a piano model; it won't make much difference. `-bsz 1 1` is too small, so no augmentation happens within the batch. From what I remember, `-bsz 11 11` worked well, but if you run into OOM errors, try `-bsz 9 9`.
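To make the naming concrete, here is my own illustrative decomposition of those task flags; this is a sketch, not the repo's actual data structure, whose real definitions live in its task/tokenizer configs:

```python
# Hypothetical decomposition of the -tk task names discussed above.
TASKS = {
    # task name:          (decoding channels, vocabulary, max sequence length)
    "mc13_full_plus_256": (13, "FULL_PLUS", 256),
    "mc13_256":           (13, "MIDI_PLUS", 256),
}

for name, (channels, vocab, max_len) in TASKS.items():
    print(f"{name}: {channels}-channel decoding, {vocab} vocab, max len {max_len}")
```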
Firstly, I'd like to take a moment to appreciate the work done by @mimbres and co-authors; the work is pretty extensive and showcases how powerful YourMT3 is at AMT. The last section of the paper mentions the "YOURMT3 TOOLKIT". I presume this is a dataset preparation pipeline. Is it available, or does the source code itself encompass it?
Regards, Cliff
@cliffordkleinsr Thanks for your interest in this project. Yes, it includes everything needed for training—defining tasks with tokens, managing data, scheduling, and evaluation metrics for different instruments. It's all in the pre-release code, but refactoring it takes time, so I'll release it with some compromises. The most reusable parts are data loading, evaluation metrics, and augmentation, though the lack of documentation may make it tricky.
For data preparation, check the code in utils/preprocess/. It integrates around 10 datasets in different formats. For custom datasets, just prepare MIDI and audio files. The Maestro dataset is a good reference.
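As a hedged illustration of that custom-dataset step, here is a minimal sketch under my own assumptions about file layout (paired audio/MIDI files sharing a filename stem); the repo's actual preprocessing code in utils/preprocess/ is the authoritative reference:

```python
from pathlib import Path

# Hypothetical layout:
#   my_dataset/audio/track01.wav  <->  my_dataset/midi/track01.mid
DATA_ROOT = Path("my_dataset")

def pair_audio_midi(root: Path):
    """Collect (audio, midi) pairs by matching filename stems."""
    midis = {p.stem: p for p in (root / "midi").glob("*.mid")}
    pairs = []
    for wav in sorted((root / "audio").glob("*.wav")):
        if wav.stem in midis:
            pairs.append((wav, midis[wav.stem]))
        else:
            print(f"warning: no MIDI for {wav.name}")
    return pairs

for wav, mid in pair_audio_midi(DATA_ROOT):
    print(wav, "<->", mid)
```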
The more I delve into this project, the more mind-blown I am. Truly incredible work. As part of a study project at my university, we replicated the training of the MoE model without much trouble, and we are now preparing new models and tokenization schemes using the framework; even in this pre-release, it is an amazing toolkit.
I was wondering if it would be possible to request access to the restricted datasets, to ensure our replication is faithful.
@karioth You can request the access token: https://zenodo.org/records/10016397 I'm sorry for the lack of documentation!!
Thank you so much! I just sent the request :D
@karioth I missed checking the message that came 27 days ago! (Sorry about that) It should work now.
Hi, thanks for sharing all this information in this repo. I can't wait to see your code~