Stability-AI / stable-audio-tools

Generative models for conditional audio generation
MIT License

Pickling Error with custom_metadata_module in DataLoader #34

Status: Open. HawkSP opened this issue 8 months ago.

HawkSP commented 8 months ago

Description

When I run the train.py script to train an audio model, I get a pickling error related to the get_custom_metadata function specified in the local_training_example.json configuration. The error indicates that the module metadata_module cannot be imported when PyTorch's DataLoader starts its worker processes.

Steps to Reproduce

1. Run the train.py script with the dataset and model configuration JSON files listed below.
2. Observe the _pickle.PicklingError raised during execution.

Expected Behavior

The DataLoader should be able to use the get_custom_metadata function from the metadata_module without any pickling errors.

Actual Behavior

The script throws a _pickle.PicklingError, indicating the function get_custom_metadata cannot be pickled.

Environment

- Operating System: Windows 10
- Python Version: 3.9.13
- PyTorch Version: 2.1.1 (with CUDA 12.1 support)
- CUDA Version: 12.1
- cuDNN Version: 8801

Configuration Files

local_training_example.json:

{
  "dataset_type": "audio_dir",
  "datasets": [
    {
      "id": "nsynth_train",
      "path": "E:/nsynth/nsynth-train/audio/"
    },
    {
      "id": "nsynth_valid",
      "path": "E:/nsynth/nsynth-valid/audio/"
    },
    {
      "id": "nsynth_test",
      "path": "E:/nsynth/nsynth-test/audio/"
    }
  ],
  "random_crop": true,
  "custom_metadata_module": "./stable_audio_tools/configs/dataset_configs/custom_metadata/custom_md_example.py"
}

custom_md_example.py:

def get_custom_metadata(info, audio):
  # Use relative path as the prompt
  return {"prompt": info["relpath"]}
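The contract of get_custom_metadata can be checked in isolation: it receives the per-file info dict and the audio, and returns a dict that the dataset merges into info (the info.update(...) call quoted later in this thread). The info contents below are hypothetical; this example function only reads the relpath key.

```python
def get_custom_metadata(info, audio):
    # Use relative path as the prompt
    return {"prompt": info["relpath"]}


# Hypothetical per-file info dict; the real dataset may add more keys.
info = {
    "path": "E:/nsynth/nsynth-train/audio/guitar_acoustic_000-059-127.wav",
    "relpath": "guitar_acoustic_000-059-127.wav",
}
audio = None  # this example function ignores the audio argument

# Merge the returned metadata into info, as the dataset code does.
info.update(get_custom_metadata(info, audio))
print(info["prompt"])  # guitar_acoustic_000-059-127.wav
```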

console output:

(venv) PS C:...\stable-audio-tools-main> python ./train.py --dataset-config ./stable_audio_tools/configs/dataset_configs/local_training_example.json --model-config ./stable_audio_tools/configs/model_configs/autoencoders/stable_audio_1_0_vae.json --name dedai
Found 305979 files
A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'
C:...\venv\lib\site-packages\torch\nn\utils\weight_norm.py:30: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
...
EOFError: Ran out of input
_pickle.PicklingError: Can't pickle <function get_custom_metadata at 0x0000014C62C10820>: import of module 'metadata_module' failed
...
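The failure can be reproduced without PyTorch. This sketch assumes, based on the module name in the error message, that the dataset code loads custom_metadata_module with importlib under the name metadata_module without registering it in sys.modules; the temporary file is a stand-in for custom_md_example.py. The standard pickle module stores a function by reference (module name plus qualified name), so serialising it requires importing metadata_module by name, which fails.

```python
import importlib.util
import os
import pickle
import tempfile

# Stand-in for custom_md_example.py, written to a temporary directory.
SOURCE = (
    "def get_custom_metadata(info, audio):\n"
    "    # Use relative path as the prompt\n"
    "    return {'prompt': info['relpath']}\n"
)

tmp_dir = tempfile.mkdtemp()
module_path = os.path.join(tmp_dir, "custom_md_example.py")
with open(module_path, "w") as f:
    f.write(SOURCE)

# Assumption: the file is loaded roughly like this, under the name
# "metadata_module", without being added to sys.modules.
spec = importlib.util.spec_from_file_location("metadata_module", module_path)
metadata_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(metadata_module)
custom_metadata_fn = metadata_module.get_custom_metadata

# The function works fine in the process that loaded it...
print(custom_metadata_fn({"relpath": "example.wav"}, None))

# ...but pickling it makes pickle import "metadata_module" by name,
# which is not importable, so a PicklingError is raised.
error = None
try:
    pickle.dumps(custom_metadata_fn)
except pickle.PicklingError as exc:
    error = exc
print(repr(error))
```

Under the fork start method, DataLoader workers inherit the dataset from the parent process and the function is never pickled. Windows supports only spawn, which always pickles the dataset for each worker, so the error appears there.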
TheZaind commented 7 months ago

Hey, same problem here. Were you able to solve it?

fred-dev commented 7 months ago

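This issue seems to be related to the way Windows starts processes: it supports only the spawn start method, so each DataLoader worker receives a pickled copy of the dataset, including custom_metadata_fn. I have a potential fix, but I cannot fully verify it because I haven't been able to get other parts of the system working properly.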

If you're interested in trying this solution, you can follow the steps below:

Install dill in your environment.

Make the following edits in dataset.py:

  1. Add an import statement for dill at the top: import dill

  2. Use dill to serialise the function. Change line 146 from:

            self.custom_metadata_fn = custom_metadata_fn

    To:

            self.custom_metadata_fn = dill.dumps(custom_metadata_fn)

  3. Use dill to deserialise the function before it is called. Change lines 202 -> 205 from:

            if self.custom_metadata_fn is not None:
                custom_metadata = self.custom_metadata_fn(info, audio)
                info.update(custom_metadata)

    To:

            if self.custom_metadata_fn is not None:
                custom_metadata_fn_deserialised = dill.loads(self.custom_metadata_fn)
                custom_metadata = custom_metadata_fn_deserialised(info, audio)
                info.update(custom_metadata)
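An alternative that avoids the extra dependency (not proposed in this thread; only a sketch) is to pickle the file path instead of the function and re-import the module inside each worker the first time it is called. The class LazyMetadataFn is hypothetical; it would be used in place of the loaded custom_metadata_fn, and it loads the file with importlib under the same metadata_module name assumed above.

```python
import importlib.util
import pickle


class LazyMetadataFn:
    """Hypothetical picklable stand-in for custom_metadata_fn.

    Only the file path and function name are pickled; the module is
    re-imported lazily in whichever process calls the function.
    """

    def __init__(self, module_path, fn_name="get_custom_metadata"):
        self.module_path = module_path
        self.fn_name = fn_name
        self._fn = None  # cached function; never pickled

    def _load(self):
        # Load the metadata file by path, as in the original setup.
        spec = importlib.util.spec_from_file_location("metadata_module", self.module_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, self.fn_name)

    def __call__(self, info, audio):
        if self._fn is None:
            self._fn = self._load()
        return self._fn(info, audio)

    def __getstate__(self):
        # Drop the cached function so only plain strings are pickled.
        state = self.__dict__.copy()
        state["_fn"] = None
        return state
```

Because the instance pickles to a path and a name, each spawned worker reloads the module on its first call and then reuses the cached function, instead of deserialising it on every item.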

I am not pushing these changes or making a pull request to the repository as they have not been tested. However, this approach may allow you to proceed with some of your work. Please let me know if this solution resolves your issue.

0xdevalias commented 4 months ago

See also: