claritychallenge / clarity

Clarity Challenge toolkit - software for building Clarity Challenge systems
https://claritychallenge.github.io/clarity
MIT License

Tensor size mismatch when running CAD2 > Task2 baseline enhancement [BUG] #401

Closed: awagenknecht closed this issue 2 months ago

awagenknecht commented 3 months ago

Describe the bug
After downloading and processing the data for the CAD2 > Task2 challenge, I attempted to run enhance.py to explore the baseline enhancement system. This resulted in tensor size mismatch errors at the load_separation_model() step (line 279). The parameters defined in the ConvTasNetStereo class do not match the pre-trained model checkpoint being loaded. I worked around the problem by changing three parameters in the ConvTasNetStereo definition in ConvTasNet/local/tasnet.py. I changed the following:

To Reproduce

  1. Download the CAD2 > Task2 data.
  2. In config.yaml, set the zenodo_download_path and root path (see the sketch after this list).
  3. Process the data by running process_dataset/process_zenodo_download.py.
  4. Run enhance.py with default parameters.
  5. See error.
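
For context, this is roughly what my config.yaml edit looked like, as referenced in step 2. It is only a hypothetical sketch: the two key names come from the steps above, the exact layout follows the config.yaml shipped with the baseline, and the paths are placeholders for my machine:

# hypothetical config.yaml values, not the repo defaults
zenodo_download_path: /home/austin/cadenza_data/cad2_task2_zenodo
root: /home/austin/cadenza_data/cad2_task2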

Expected behavior
Based on the README in the baseline folder, I expect the baseline enhancement system to run with the default parameters and generate the enhanced .flac files.

Error Messages

File "/home/austin/clarity/recipes/cad2/task2/baseline/enhance.py", line 204, in load_separation_model models[instrument] = ConvTasNetStereo.from_pretrained( File "/home/austin/clarity/recipes/cad2/task2/baseline/enhance.py", line 279, in enhance separation_models = load_separation_model( File "/home/austin/clarity/recipes/cad2/task2/baseline/enhance.py", line 407, in enhance() RuntimeError: Error(s) in loading state_dict for ConvTasNetStereo: size mismatch for encoder.conv1d_U.weight: copying a param with shape torch.Size([256, 2, 20]) from checkpoint, the shape in current model is torch.Size([256, 1, 20]). size mismatch for separator.network.3.weight: copying a param with shape torch.Size([512, 256, 1]) from checkpoint, the shape in current model is torch.Size([1024, 256, 1]). size mismatch for decoder.basis_signals.weight: copying a param with shape torch.Size([40, 256]) from checkpoint, the shape in current model is torch.Size([20, 256]).

Environment

- OS: Ubuntu 22.04.2 LTS (GNU/Linux 5.15.153.1-microsoft-standard-WSL2 x86_64)
- Python version: 3.8.19
- clarity version: v0.6.0
- Installed package versions:

groadabike commented 2 months ago

Hi @awagenknecht, thank you for submitting this issue. I can run the baseline without getting this error. Can you please tell us whether you are still experiencing this problem? Which causality are you loading?

Thank you

awagenknecht commented 2 months ago

Hi @groadabike, thanks for looking into it. I'm new here, so I apologize if I'm missing something obvious.

I still get the error when running the baseline as provided. I am loading noncausal models.

I can avoid the error if I change the C and audio_channels parameters in the __init__ method of the ConvTasNetStereo class definition in the recipes/cad2/task2/ConvTasNet/local/tasnet.py file, but the baseline seems like it's set up so that I shouldn't have to touch this file.

The parameters in the ConvTasNetStereo class definition do not match the config being downloaded from HuggingFace.
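
For what it's worth, this is how I read the shapes in the error message, which is what pointed me at C and audio_channels. It is only a back-of-the-envelope sketch assuming the usual Conv-TasNet layout (separator output channels = C * N, decoder outputs = audio_channels * L); it is not code from tasnet.py:

# Working backwards from the size-mismatch message above
N, L = 256, 20                     # basis size and kernel length visible in the error shapes

ckpt_audio_channels = 2            # encoder.conv1d_U.weight [256, 2, 20] -> in_channels = 2
ckpt_C = 512 // N                  # separator.network.3.weight [512, 256, 1] -> 2 sources
ckpt_decoder_channels = 40 // L    # decoder.basis_signals.weight [40, 256] -> 2 channels

local_audio_channels = 1           # the locally built model has [256, 1, 20]
local_C = 1024 // N                # and [1024, 256, 1] -> 4 sources

print(ckpt_audio_channels, ckpt_C, ckpt_decoder_channels)  # 2 2 2
print(local_audio_channels, local_C)                       # 1 4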

groadabike commented 2 months ago

Hi @awagenknecht,

I am not sure why you are getting that error. As I understand it, when you use Hugging Face's from_pretrained method, the model is automatically initialised using the params saved on Hugging Face. This is particularly useful if you have a model that comes in different sizes, e.g., Whisper.

If you look at the Whisper documentation on Hugging Face (https://huggingface.co/docs/transformers/en/model_doc/whisper), you can load the openai/whisper-tiny.en version or any of the other sizes listed at https://huggingface.co/openai just by changing the model tag.
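
For example, something along these lines (a minimal sketch using the transformers classes documented on that page; the two tags are just different sizes of the same architecture):

from transformers import WhisperForConditionalGeneration

# Same call, different checkpoints: from_pretrained reads each repo's config.json
# and builds the model with the sizes stored there.
tiny = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base.en")
print(tiny.config.d_model, base.config.d_model)  # different hidden sizes, no code changes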

Can you confirm that your load_separation_model in recipes/cad2/task2/baseline/enhance.py looks like this:

def load_separation_model(
    causality: str, device: torch.device, force_redownload: bool = True
) -> dict[str, ConvTasNetStereo]:
    """
    Load the separation model.
    Args:
        causality (str): Causality of the model (causal or noncausal).
        device (torch.device): Device to load the model.
        force_redownload (bool): Whether to force redownload the model.

    Returns:
        model: Separation model.
    """
    models = {}
    causal = {"causal": "Causal", "noncausal": "NonCausal"}

    for instrument in [
        "Bassoon",
        "Cello",
        "Clarinet",
        "Flute",
        "Oboe",
        "Sax",
        "Viola",
        "Violin",
    ]:
        logger.info(
            "Loading model "
            f"cadenzachallenge/ConvTasNet_{instrument}_{causal[causality]}"
        )
        models[instrument] = ConvTasNetStereo.from_pretrained(
            f"cadenzachallenge/ConvTasNet_{instrument}_{causal[causality]}",
            force_download=force_redownload,
        ).to(device)
    return models

groadabike commented 2 months ago

Now, looking at the error message in more detail, it looks a bit odd.

  1. File "/home/austin/clarity/recipes/cad2/task2/baseline/enhance.py", line 204, in enhance
    separation_models = load_separation_model()

    That call is on line 279 and it should include three arguments:

    separation_models = load_separation_model(
        config.separator.causality, device, config.separator.force_redownload
    )
  2. The next error line says:

    File "/home/austin/clarity/recipes/cad2/task2/baseline/enhance.py", line 204, in load_separation_model
    models[instrument] = ConvTasNetStereo.from_pretrained()

    The line number is correct, 204. However, it should also include the arguments:

    models[instrument] = ConvTasNetStereo.from_pretrained(
            f"cadenzachallenge/ConvTasNet_{instrument}_{causal[causality]}",
            force_download=force_redownload,
        ).to(device)

Does your baseline code include these arguments?

awagenknecht commented 2 months ago

Yes, the load_separation_model function is identical. The causality, device, and force_redownload arguments are being passed as expected to load_separation_model, and the model string and force_download are also being passed to ConvTasNetStereo.from_pretrained.

You're right, there are some mistakes in the error message. I was using ChatGPT to help debug and must have gotten mixed up about which version of the error I was copying. facepalm Sorry for the confusion. I'll edit the original post with the correct error message.

The main content of the error is the same, though. It looks like all the arguments are being passed. I can see that the config.json is being downloaded from HuggingFace and cached with the correct parameters, but for some reason they are not being applied in the model initialization. This makes me think the issue is somewhere in the HuggingFace package on my end. So I don't think it's a Clarity bug anymore, and this can be closed if no one else is getting this error.
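
For reference, this is roughly how I checked the cached config (a quick sketch; hf_hub_download is the standard huggingface_hub call and returns the local path of the cached file, and I am assuming the relevant size fields live in config.json):

import json
from huggingface_hub import hf_hub_download

# Fetch (or reuse the cached copy of) the config for one of the separation models
# named in load_separation_model, then print it to compare against the built model.
cfg_path = hf_hub_download("cadenzachallenge/ConvTasNet_Bassoon_NonCausal", "config.json")
with open(cfg_path) as f:
    print(json.load(f))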

groadabike commented 2 months ago

Hi @awagenknecht, thank you for the clarification.

It would be very helpful to find the source of the error. Can you please confirm that you created a clean environment for the challenge? Also, can you share which huggingface-hub version you have? I am using 0.23.0.
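
A quick way to check it, e.g.:

import huggingface_hub

print(huggingface_hub.__version__)  # I get 0.23.0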

awagenknecht commented 2 months ago

Updating huggingface-hub has fixed the problem. I previously had 0.21.4 and updated to 0.24.5. One of these might be the relevant fix:

I started from a clean conda environment, but I believe I may have overlooked a permissions issue when installing clarity that caused it to use packages already installed in my base environment.

That said, does clarity pin a huggingface-hub version in its requirements? I just tried again in a clean environment, after fixing the permissions issue, and I needed to install huggingface-hub separately in order to run the task2 enhance.py.

Thanks so much for helping track this down!

groadabike commented 2 months ago

Hi @awagenknecht ,

Thank you for letting us know. Glad the issue was just the huggingface-hub version.

Regarding the requirement: you are right, we have an extra requirements.txt in task 1 but not in task 2.
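
For reference, a pin in a task2 requirements file could be as small as the line below (the bound is only a suggestion based on this thread, where 0.21.4 failed and 0.23.0 / 0.24.5 work):

# hypothetical recipes/cad2/task2/requirements.txt entry
huggingface-hub>=0.23.0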