Stability-AI / stable-audio-tools

Generative models for conditional audio generation
MIT License
2.47k stars 231 forks source link

PicklingError: Can't pickle <function get_custom_metadata at 0x00000216EE5ADDA0>: import of module 'metadata_module' failed #45

Open TheZaind opened 6 months ago

TheZaind commented 6 months ago

Hey, if I want to train my model with custom audio files and prompts via metadata, I just get this traceback:

PicklingError: Can't pickle <function get_custom_metadata at 0x00000216EE5ADDA0>: import of module 'metadata_module' failed Traceback (most recent call last): File "", line 1, in File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\multiprocessing\spawn.py", line 122, in spawn_main exitcode = _main(fd, parent_sentinel) ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\multiprocessing\spawn.py", line 132, in _main self = reduction.pickle.load(from_parent) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ EOFError: Ran out of input Traceback (most recent call last): File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\trainer\call.py", line 44, in _call_and_handle_interrupt return trainer_fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\trainer\trainer.py", line 581, in _fit_impl self._run(model, ckpt_path=ckpt_path) File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\trainer\trainer.py", line 990, in _run results = self._run_stage() ^^^^^^^^^^^^^^^^^ File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1036, in _run_stage self.fit_loop.run() File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\loops\fit_loop.py", line 194, in run self.setup_data() File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\loops\fit_loop.py", line 258, in setup_data iter(self._data_fetcher) # creates the iterator inside the fetcher ^^^^^^^^^^^^^^^^^^^^^^^^ File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\loops\fetchers.py", line 99, in iter super().iter() File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\loops\fetchers.py", line 48, in iter self.iterator = 
iter(self.combined_loader) ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\utilities\combined_loader.py", line 335, in iter iter(iterator) File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\utilities\combined_loader.py", line 87, in iter super().iter() File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\utilities\combined_loader.py", line 40, in iter self.iterators = [iter(iterable) for iterable in self.iterables] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\utilities\combined_loader.py", line 40, in self.iterators = [iter(iterable) for iterable in self.iterables] ^^^^^^^^^^^^^^ File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\torch\utils\data\dataloader.py", line 434, in iter self._iterator = self._get_iterator() ^^^^^^^^^^^^^^^^^^^^ File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\torch\utils\data\dataloader.py", line 387, in _get_iterator return _MultiProcessingDataLoaderIter(self) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\torch\utils\data\dataloader.py", line 1040, in init w.start() File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\multiprocessing\process.py", line 121, in start self._popen = self._Popen(self) ^^^^^^^^^^^^^^^^^ File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\multiprocessing\context.py", line 224, in _Popen return _default_context.get_context().Process._Popen(process_obj) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\multiprocessing\context.py", line 336, in _Popen return Popen(process_obj) ^^^^^^^^^^^^^^^^^^ File "g:\Programms\stable 
audio\stable-audio-tools.conda\Lib\multiprocessing\popen_spawn_win32.py", line 95, in init reduction.dump(process_obj, to_child) File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\multiprocessing\reduction.py", line 60, in dump ForkingPickler(file, protocol).dump(obj) _pickle.PicklingError: Can't pickle <function get_custom_metadata at 0x00000216EE5ADDA0>: import of module 'metadata_module' failed

During handling of the above exception, another exception occurred:

Traceback (most recent call last): File "G:\Programms\stable audio\stable-audio-tools\train.py", line 125, in main() File "G:\Programms\stable audio\stable-audio-tools\train.py", line 120, in main trainer.fit(training_wrapper, train_dl, File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\trainer\trainer.py", line 545, in fit call._call_and_handle_interrupt( File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\trainer\call.py", line 68, in _call_and_handle_interrupt trainer._teardown() File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1017, in _teardown loop.teardown() File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\loops\fit_loop.py", line 407, in teardown self._data_fetcher.teardown() File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\loops\fetchers.py", line 75, in teardown self.reset() File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\loops\fetchers.py", line 134, in reset super().reset() File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\loops\fetchers.py", line 71, in reset self.length = sized_len(self.combined_loader) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\lightning_fabric\utilities\data.py", line 51, in sized_len length = len(dataloader) # type: ignore [arg-type] ^^^^^^^^^^^^^^^ File "g:\Programms\stable audio\stable-audio-tools.conda\Lib\site-packages\pytorch_lightning\utilities\combined_loader.py", line 342, in len raise RuntimeError("Please call iter(combined_loader) first.") RuntimeError: Please call iter(combined_loader) first.

Any idea how to fix it?

My custom_metadata.py looks like this:

`import os

def get_custom_metadata(info, audio):
    """Return the text prompt for an audio file.

    Looks for a ``.txt`` file next to the audio file (same base name,
    different extension) and returns its contents as the ``"prompt"``
    metadata field used for conditioning.

    Args:
        info: Metadata dict for the sample; must contain ``"relpath"``,
            the path to the audio file.
        audio: The loaded audio data — unused here, but required by the
            custom-metadata hook signature.

    Returns:
        Dict with a single ``"prompt"`` key holding the caption text.

    Raises:
        FileNotFoundError: If no matching ``.txt`` file exists.
    """
    # Path to the audio file
    audio_path = info["relpath"]

    # Path to the matching caption/text file
    text_path = os.path.splitext(audio_path)[0] + '.txt'

    # Read the prompt; explicit utf-8 so non-ASCII captions also work on
    # Windows, where the locale default encoding may differ
    with open(text_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Return the caption as the "prompt" conditioning field
    return {"prompt": text}

`

my dataset config: { "dataset_type": "audio_dir", "datasets": [ { "id": "my_audio", "path": "./trainingdata", "random_crop": true } ], "custom_metadata_module": "./custom_metadata.py" }

and my model config:

{ "model_type": "diffusion_cond", "sample_size": 4194304, "sample_rate": 44100, "audio_channels": 2, "model": { "pretransform": { "type": "autoencoder", "iterate_batch": true, "config": { "encoder": { "type": "dac", "config": { "in_channels": 2, "latent_dim": 128, "d_model": 128, "strides": [4, 4, 8, 8] } }, "decoder": { "type": "dac", "config": { "out_channels": 2, "latent_dim": 64, "channels": 1536, "rates": [8, 8, 4, 4] } }, "bottleneck": { "type": "vae" }, "latent_dim": 64, "downsampling_ratio": 1024, "io_channels": 2 } }, "conditioning": { "configs": [ { "id": "prompt", "type": "clap_text", "config": { "audio_model_type": "HTSAT-base", "enable_fusion": true, "clap_ckpt_path": "./clapmodel/music_audioset_epoch_15_esc_90.14.pt", "use_text_features": true, "feature_layer_ix": -2 } }, { "id": "seconds_start", "type": "int", "config": { "min_val": 0, "max_val": 512 } }, { "id": "seconds_total", "type": "int", "config": { "min_val": 0, "max_val": 512 } } ], "cond_dim": 768 }, "diffusion": { "type": "adp_cfg_1d", "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"], "config": { "in_channels": 64, "context_embedding_features": 768, "context_embedding_max_length": 79, "channels": 256, "resnet_groups": 16, "kernel_multiplier_downsample": 2, "multipliers": [4, 4, 4, 5, 5], "factors": [1, 2, 2, 4], "num_blocks": [2, 2, 2, 2], "attentions": [1, 3, 3, 3, 3], "attention_heads": 16, "attention_multiplier": 4, "use_nearest_upsample": false, "use_skip_scale": true, "use_context_time": true } }, "io_channels": 64 }, "training": { "learning_rate": 4e-5, "demo": { "demo_every": 2000, "demo_steps": 250, "num_demos": 4, "demo_cond": [ {"prompt": "80s style Whipcrack snare drum loop, 120BPM, retro, funk, energetic, nostalgia, dance, disco", "seconds_start": 0, "seconds_total": 30}, {"prompt": "Dubstep style drum loop 5, 110BPM, fast, energy, meditative, spiritual, zen, calming, focus, introspection", "seconds_start": 0, "seconds_total": 30}, {"prompt": "Dance Pop 
style Synth Chorus loop, fast, energy, fun, lively, upbeat, catchy, indie, fresh, vibrant", "seconds_start": 0, "seconds_total": 30}, {"prompt": "Hip hop style piano loop, 150BPM, Key A, joyful, playful, funny, upbeat, lively, cheerful, entertaining", "seconds_start": 0, "seconds_total": 30} ], "demo_cfg_scales": [3, 6, 9] } } }

Using Win 11.

Thanks! :)

fred-dev commented 6 months ago

I'm experiencing a similar issue. It appears that the function's serialization method is incompatible with Windows, though it works on Linux. Unfortunately, I don't have a solution to offer at the moment, as I'm facing the same challenge.

TheZaind commented 6 months ago

ahhhh... damn it. But thanks for the info, Fred. It's great to see you are very active here and helpful :)

0xdevalias commented 2 months ago

See also: