nateraw / audiocraft

Audiocraft is a library for audio processing and generation with deep learning. It features the state-of-the-art EnCodec audio compressor / tokenizer, along with MusicGen, a simple and controllable music generation LM with textual and melodic conditioning.
MIT License

export checkpoints to audiocraft format #1

Open nateraw opened 2 months ago

nateraw commented 2 months ago

Moving discussion from https://github.com/facebookresearch/audiocraft/issues/480 to here, as it's specific to this fork.

adenta commented 2 months ago

Let me know how I can help tackle this

nateraw commented 2 months ago

If you trained on top of one of FB's base models, the easiest way I've found to convert the checkpoint is to load their model's checkpoint, then overwrite the base model's LM weights with your fine-tune's weights. I have code to do it from scratch, but I always found something wrong with xp.cfg when I did that.

import torch
from huggingface_hub import hf_hub_download

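# load your fine-tune's Lightning checkpoint
# (weights_only=False may be needed on PyTorch >= 2.6, where it defaults to True;
# only do this for checkpoints you trust)
ckpt_ft = torch.load("/path/to/last.ckpt", map_location="cpu", weights_only=False)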

# load the base model's ckpt
base_ckpt_path = hf_hub_download(
    repo_id="facebook/musicgen-stereo-small",
    filename="state_dict.bin",
)
ckpt_base = torch.load(base_ckpt_path, map_location="cpu")

# overwrite model weights of base model w/ fine tune's
# keys look like "model.<name>" in the Lightning ckpt; strip the "model." prefix
ckpt_base["best_state"] = {
    k.split(".", 1)[1]: v for k, v in ckpt_ft["state_dict"].items() if k.startswith("model.")
}

# save it and pray it works! 🙏
torch.save(ckpt_base, "/path/to/state_dict.bin")

Let me know if this works for you. I should add a script for this to the repo; a PR would be much appreciated :)
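Rather than just praying, a quick sanity check before saving can help. The sketch below compares the fine-tune's keys and tensor shapes against the base model's `best_state`. The helper names `strip_prefix` and `check_compatible` are hypothetical (not part of audiocraft), and like the snippet above it assumes the Lightning checkpoint stores the LM's weights under a `model.` prefix.

```python
def strip_prefix(state_dict, prefix="model."):
    """Keep only keys under `prefix` and drop it ('model.emb.0.weight' -> 'emb.0.weight')."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


def check_compatible(base_state, ft_state):
    """Compare two state dicts whose values expose a `.shape` (e.g. torch tensors).

    Returns (missing, unexpected, shape_mismatch), each a sorted list of keys.
    All three empty means the fine-tune can replace the base weights one-for-one.
    """
    missing = sorted(set(base_state) - set(ft_state))      # in base, not in fine-tune
    unexpected = sorted(set(ft_state) - set(base_state))   # in fine-tune, not in base
    shape_mismatch = sorted(
        k for k in set(base_state) & set(ft_state)
        if tuple(base_state[k].shape) != tuple(ft_state[k].shape)
    )
    return missing, unexpected, shape_mismatch
```

With the variables from the snippet above: `missing, unexpected, mismatched = check_compatible(ckpt_base["best_state"], strip_prefix(ckpt_ft["state_dict"]))`. Anything in `unexpected` or `mismatched` (for example, weights from custom conditioning) suggests the base model's config does not describe your fine-tune.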

nateraw commented 2 months ago

From scratch, it's something like this:

from pathlib import Path

import torch

from audiocraft import __version__

def main(ckpt_path, out_file):
    ckpt_path = Path(ckpt_path)
    out_path = Path(out_file)

    print(f"Loading checkpoint from {ckpt_path}")
    print(f"Will save exported state_dict to {out_path}")

    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint file not found: {ckpt_path}")

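    # weights_only=False may be needed on PyTorch >= 2.6 (Lightning checkpoints
    # can store non-tensor objects); only do this for checkpoints you trust
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    print("loaded ckpt, now saving...")
    new_pkg = {
        # strip the "model." prefix so keys match audiocraft's LM
        "best_state": {k.split(".", 1)[1]: v for k, v in ckpt["state_dict"].items() if k.startswith("model.")},
        # Lightning's saved hparams; may not match the full config audiocraft expects
        "xp.cfg": ckpt["hyper_parameters"],
        "version": __version__,
        "exported": True,
    }
    out_path.parent.mkdir(exist_ok=True, parents=True)
    torch.save(new_pkg, out_path)


if __name__ == "__main__":
    import sys

    # usage (script name is hypothetical): python export_ckpt.py /path/to/last.ckpt /path/to/state_dict.bin
    main(sys.argv[1], sys.argv[2])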

But again, beware of issues with xp.cfg. I think they come from the config being hard-coded: something is missing or not being updated properly in it, though I don't recall exactly what.
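One way to narrow down what is missing is to diff the exported config against a known-good one. The sketch below is a hypothetical helper (not an audiocraft API) that lists dotted key paths present in a reference config, such as the base checkpoint's `xp.cfg`, but absent from the Lightning `hyper_parameters`. It assumes both have already been converted to plain nested dicts.

```python
def missing_keys(reference, candidate, prefix=""):
    """Return dotted paths that exist in `reference` but not in `candidate`.

    Both arguments are nested dicts; recursion only descends where both sides
    hold a dict at the same key.
    """
    missing = []
    for key, ref_value in reference.items():
        path = f"{prefix}{key}"
        if key not in candidate:
            missing.append(path)
        elif isinstance(ref_value, dict) and isinstance(candidate[key], dict):
            missing.extend(missing_keys(ref_value, candidate[key], prefix=path + "."))
    return missing
```

`OmegaConf.to_container(OmegaConf.create(ckpt_base["xp.cfg"]))` should give such a dict for the base checkpoint, since `OmegaConf.create` accepts either a dict or a YAML string; convert `hyper_parameters` the same way if it holds an OmegaConf object.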

nateraw commented 2 months ago

@adenta please do let me know if something here works for you. I didn't run either of the above before sending; they're just what I had lying around 😅

adenta commented 2 months ago

Nate you are fantastic. Don’t want to leave you hanging.

I had to train on top of regular Audiocraft for a presentation happening Monday. I'll be able to dive in next week to train on top of songstarter; this thing is amazing.

Will report back in this thread with my findings!!

(Email reply to Nathan Raw's comment of Fri, Jul 19, 2024, 12:25 AM.)

adenta commented 2 months ago

> # save it and pray it works! 🙏
> torch.save(ckpt_base, "/path/to/state_dict.bin")

So this will overwrite the fine-tuned checkpoint (e.g., last.ckpt), and then when I load it for inference I won't get the weird issues around xp.cfg?

nateraw commented 2 months ago

It overwrites the base model's weights with the weights from last.ckpt, then writes the result to a new file, state_dict.bin, which you can load with audiocraft (last.ckpt itself is left untouched). Apart from the weights, the result is the same as the base model's checkpoint, including its xp.cfg, so it avoids the xp.cfg issues.

This won't work if you added custom conditioning or otherwise changed the model relative to the base, so be aware of that.

adenta commented 2 months ago

Can I pass the state_dict.bin directly to MusicGen.get_pretrained? Or do I have to go through the whole song and dance of get_pretrained'ing the base model and then somehow overwriting its state dict with the new training?

Again, thank you so much.

(Email reply to Nathan Raw's comment of Mon, Jul 22, 2024, 6:50 PM.)

nateraw commented 2 months ago

You'll save it to some folder, then also place the compression model ckpt in that dir:

your_save_dir/
  - state_dict.bin
  - compression_state_dict.bin

You can just download compression_state_dict.bin from the base model's repo and save it there, since the compression model didn't get trained when you fine-tuned the base model.

Then just call MusicGen.get_pretrained("./your_save_dir").
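The folder layout above can be scripted. This is a minimal sketch: `assemble_model_dir` is a hypothetical helper, and the only thing it assumes about audiocraft is what the comment above says, namely that `MusicGen.get_pretrained` looks for `state_dict.bin` and `compression_state_dict.bin` inside the directory you pass it.

```python
import shutil
from pathlib import Path


def _copy(src, dst):
    # skip the copy if the file is already in place (copyfile would raise SameFileError)
    if Path(src).resolve() != Path(dst).resolve():
        shutil.copyfile(src, dst)


def assemble_model_dir(lm_state_path, compression_state_path, out_dir):
    """Place the LM and compression weights under the file names get_pretrained expects."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    _copy(lm_state_path, out_dir / "state_dict.bin")
    _copy(compression_state_path, out_dir / "compression_state_dict.bin")
    return out_dir
```

For example, fetch the compression weights with `hf_hub_download(repo_id="facebook/musicgen-stereo-small", filename="compression_state_dict.bin")` (using the same base repo you fine-tuned from), pass that path and your exported `state_dict.bin` to `assemble_model_dir(..., "./your_save_dir")`, then call `MusicGen.get_pretrained("./your_save_dir")`.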

adenta commented 2 months ago

@nateraw The above worked (😄), but I ran into other problems (😞). I've been doing everything on a single Lambda Labs A10. The checkpoints I was generating didn't work for inference because they didn't support passing in the 'melody' parameter. I didn't save the stack trace, but the specific message was: RuntimeError: This model doesn't support melody conditioning. Use the `melody` model.

It seems like a weird error message, because MusicGen seems to require this melody parameter.

Tried to upgrade to the same cluster you used (8xA100), and got a cryptic error message:

Traceback (most recent call last):
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/lit_train.py", line 139, in <module>
    fire.Fire(main)
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/lit_train.py", line 133, in main
    trainer.fit(model, ckpt_path=resume_from_ckpt)
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1033, in _run_stage
    self.fit_loop.run()
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 212, in advance
    batch, _, __ = next(data_fetcher)
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 133, in __next__
    batch = super().__next__()
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 60, in __next__
    batch = next(self.iterator)
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 341, in __next__
    out = next(self._iterator)
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 78, in __next__
    out[i] = next(self.iterators[i])
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1328, in _next_data
    idx, data = self._get_data()
  File "/home/ubuntu/waitwhat-east/nateraw/audiocraft/venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1289, in _get_data
    raise RuntimeError('Pin memory thread exited unexpectedly')
RuntimeError: Pin memory thread exited unexpectedly
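This is not a diagnosis, but one way to isolate this error is to rule out the DataLoader's pin-memory thread and worker processes. The sketch below is a hypothetical helper, not taken from lit_train.py (whose DataLoader settings aren't shown in this thread): it builds a loader with `pin_memory` and `num_workers` off by default so each can be re-enabled one at a time.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_debug_loader(dataset, batch_size, pin_memory=False, num_workers=0):
    """DataLoader with pinning and worker processes disabled unless requested.

    If the 'Pin memory thread exited unexpectedly' error goes away with the
    defaults, turn pin_memory and num_workers back on separately to see which
    one triggers it.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # persistent_workers is only valid when worker processes exist
        persistent_workers=num_workers > 0,
    )
```

A toy check: `make_debug_loader(TensorDataset(torch.arange(10).float()), batch_size=4)` yields three batches (4, 4 and 2 items) in the main process without pinning.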
wandb: 🚀 View run train-musicgen-32k-stereo-small-ft-facebook/musicgen-melody at: https://wandb.ai/andrew-denta-Industrial%20Allusions/audiocraft-48k-stereo-lightning/runs/loey5a68
wandb: ⭐️ View project at: https://wandb.ai/andrew-denta-Industrial%20Allusions/audiocraft-48k-stereo-lightning
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240723_201500-loey5a68/logs
wandb: WARNING The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require("core")`! See https://wandb.me/wandb-core for more information.

Unsure what I will focus on tomorrow but will report back.

adenta commented 2 months ago

Also, I'm still down to talk shop about what we're working on internally with you/Splice! I think there could be a potential collab. Feel free to follow up on LinkedIn if you want to set up some time.

nateraw commented 2 months ago

@adenta Do you recall the base model you trained on? Was it facebook/musicgen-melody? If so, are you sure that's the model whose weights you overwrote?

nateraw commented 2 months ago

It should work fine for melody conditioning if the base model supported melody conditioning.
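To check whether a checkpoint supports melody conditioning before running inference, you can look for the melody conditioner's weights. In audiocraft the chroma-based melody conditioner is registered under the name `self_wav`; the sketch below additionally assumes its weights show up in `best_state` under keys prefixed `condition_provider.conditioners.self_wav.`. The helper names are hypothetical.

```python
def conditioner_names(best_state, prefix="condition_provider.conditioners."):
    """Return the names of conditioners that have weights in an LM state dict.

    Assumes audiocraft-style keys such as
    'condition_provider.conditioners.description.output_proj.weight'.
    """
    names = set()
    for key in best_state:
        if key.startswith(prefix):
            # the conditioner name is the first path component after the prefix
            names.add(key[len(prefix):].split(".", 1)[0])
    return sorted(names)


def supports_melody(best_state):
    # melody (chroma) conditioning uses the conditioner named 'self_wav'
    return "self_wav" in conditioner_names(best_state)
```

Running `supports_melody(ckpt_base["best_state"])` on the base checkpoint you overwrote (and on the exported file) should return `True` for facebook/musicgen-melody and `False` for the text-only models.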

adenta commented 2 months ago

The first training run I did was not on musicgen-melody. The second run, which I'm trying to do on musicgen-melody, hit a memory limit on an A10. From my googling, I think this is an infrastructure problem (perhaps some quirk of the A100 cluster).

Will continue to post updates as I have them for posterity