facebookresearch / audiocraft

Audiocraft is a library for audio processing and generation with deep learning. It features the state-of-the-art EnCodec audio compressor / tokenizer, along with MusicGen, a simple and controllable music generation LM with textual and melodic conditioning.
MIT License

Should 'decay_rates' and 'frequencies' in XPos and ROPE be non-persistent? #417

Open dramaticmeow opened 4 months ago


Hello audiocraft Team,

I've encountered two issues, one with how the positional-embedding buffers are registered and one with how the Exponential Moving Average (EMA) interacts with the state manager. I believe they affect checkpoint size, saving/loading efficiency, and the validation phase, where they cause errors.

Issue 1: Duplicate Buffers in the State Dict

The RoPE and XPos tensors (`frequencies` and `decay_rates`) are registered as persistent buffers in a way that gives each StreamingMultiheadAttention instance its own entry for them in the state dict. Consequently, the saved state dict contains many redundant copies of these buffers, even though they can be recomputed from the model's hyperparameters. This unnecessarily increases the size of the state dict and makes saving and loading less efficient.
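A minimal sketch of how this duplication arises, assuming a single rotary-embedding module is attached to every attention layer. The classes `ToyRope`, `ToyAttention` and `ToyTransformer` are hypothetical stand-ins, not audiocraft's modules, and `frequencies` uses the common RoPE inverse-frequency formula only as an example of a buffer computed from hyperparameters.

```python
import torch
from torch import nn


class ToyRope(nn.Module):
    """Hypothetical rotary-embedding stand-in with a *persistent* buffer."""

    def __init__(self, dim: int, max_period: float = 10000.0):
        super().__init__()
        adim = torch.arange(0, dim, 2).float()
        # Default register_buffer: persistent=True, so it goes into state_dict().
        self.register_buffer("frequencies", 1.0 / (max_period ** (adim / dim)))


class ToyAttention(nn.Module):
    def __init__(self, rope: ToyRope):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.rope = rope  # the same module object is registered under every layer


class ToyTransformer(nn.Module):
    def __init__(self, num_layers: int = 3):
        super().__init__()
        self.rope = ToyRope(dim=8)
        self.layers = nn.ModuleList(ToyAttention(self.rope) for _ in range(num_layers))


model = ToyTransformer()
state = model.state_dict()
rope_keys = [k for k in state if k.endswith("frequencies")]
buffer_names = [name for name, _ in model.named_buffers()]
print(rope_keys)     # one key per module path: rope.frequencies, layers.0.rope.frequencies, ...
print(buffer_names)  # ['rope.frequencies']
```

`state_dict()` walks every submodule path without de-duplicating, so the single buffer is listed under four keys (the tensors share storage; only the entries are repeated), while `named_buffers()` reports it once.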

Expected Behavior:

Buffers such as `frequencies` and `decay_rates` should be registered so that they appear at most once in the model's state dict, for instance by sharing them across model components. This would keep checkpoints smaller and streamline saving and loading.

Possible Solution:

One possible solution is to register these tensors with `register_buffer(..., persistent=False)`. Non-persistent buffers are left out of the state dict entirely, yet they remain attributes of the module and are still moved to the target device by `model.to(device)`. Since both tensors are computed from hyperparameters in the constructor, nothing is lost by not saving them.
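A sketch of the proposed change on the same kind of hypothetical toy model (not audiocraft's code): the only difference from a persistent buffer is the `persistent=False` argument. The sketch converts the dtype instead of the device so it runs on a CPU-only machine; `.to(device)` follows the same code path for buffers.

```python
import torch
from torch import nn


class ToyRope(nn.Module):
    """Hypothetical rotary-embedding stand-in (not audiocraft's RotaryEmbedding)."""

    def __init__(self, dim: int, max_period: float = 10000.0):
        super().__init__()
        adim = torch.arange(0, dim, 2).float()
        # persistent=False: the buffer stays on the module and follows .to(),
        # but it is excluded from state_dict() and therefore from checkpoints.
        self.register_buffer(
            "frequencies", 1.0 / (max_period ** (adim / dim)), persistent=False
        )


class ToyAttention(nn.Module):
    def __init__(self, rope: ToyRope):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.rope = rope  # one module shared by every layer


class ToyTransformer(nn.Module):
    def __init__(self, num_layers: int = 3):
        super().__init__()
        self.rope = ToyRope(dim=8)
        self.layers = nn.ModuleList(ToyAttention(self.rope) for _ in range(num_layers))


model = ToyTransformer()
state = model.state_dict()
print([k for k in state if k.endswith("frequencies")])  # []
print([name for name, _ in model.named_buffers()])      # ['rope.frequencies']

model.to(torch.float64)  # the non-persistent buffer is converted like any other
print(model.rope.frequencies.dtype)                      # torch.float64
model.load_state_dict(model.state_dict())  # strict round trip still succeeds
```

One trade-off: checkpoints saved before such a change still contain the old keys, so loading them with `strict=True` would report unexpected keys; they would need to be filtered out or loaded with `strict=False`.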

Issue 2: EMA Compatibility with State Manager

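In addition, the project's EMA implementation collects the tensors it tracks with `named_buffers()` rather than from `state_dict()`, which makes it incompatible with the state manager. Presumably this is because `named_buffers()` lists a shared buffer only once, whereas the state dict lists it under every module path, so the EMA state and the model's state dict end up with different keys. This mismatch causes errors during validation, specifically when the EMA state is applied or reverted.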
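The key mismatch can be reproduced with hypothetical toy classes (not audiocraft's code) and a hypothetical `named_tensor_state` helper that snapshots a model from `named_parameters()` and `named_buffers()`, as the report describes the EMA doing. It assumes one rotary-embedding module with a persistent buffer is attached to every attention layer.

```python
import torch
from torch import nn


class ToyRope(nn.Module):
    """Hypothetical rotary-embedding stand-in with a persistent buffer."""

    def __init__(self, dim: int, max_period: float = 10000.0):
        super().__init__()
        adim = torch.arange(0, dim, 2).float()
        self.register_buffer("frequencies", 1.0 / (max_period ** (adim / dim)))


class ToyAttention(nn.Module):
    def __init__(self, rope: ToyRope):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.rope = rope  # shared across layers


class ToyTransformer(nn.Module):
    def __init__(self, num_layers: int = 3):
        super().__init__()
        self.rope = ToyRope(dim=8)
        self.layers = nn.ModuleList(ToyAttention(self.rope) for _ in range(num_layers))


def named_tensor_state(model: nn.Module) -> dict:
    """Hypothetical EMA-style snapshot keyed by named_parameters()/named_buffers().

    Both iterators de-duplicate by default, so a shared buffer is recorded
    only under the first module path that reaches it.
    """
    tensors = list(model.named_parameters()) + list(model.named_buffers())
    return {name: t.detach().clone() for name, t in tensors}


model = ToyTransformer()
ema_state = named_tensor_state(model)
result = model.load_state_dict(ema_state, strict=False)
print(result.missing_keys)  # the aliased paths layers.{0,1,2}.rope.frequencies

try:
    model.load_state_dict(ema_state)  # strict=True is the default
except RuntimeError as err:
    print("strict load failed:", type(err).__name__)
```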

Expected Behavior:

The EMA state should be compatible with the state manager, so that the model can switch between its training weights and its EMA weights around the validation phase without errors.

Possible Solution:

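Registering these buffers with `persistent=False` might resolve the incompatibility with the state manager, since they would no longer be included (redundantly) in the state dict. Note that `named_buffers()` still yields non-persistent buffers, so the EMA would probably also need to skip them, or build its state from `state_dict()` instead.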
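A minimal sketch of the combined fix, assuming the EMA tracks exactly the entries of `state_dict()` and the RoPE buffer is non-persistent. `StateDictEMA` and the toy modules are hypothetical names for illustration, not audiocraft's API, and the explicit backup/restore only stands in for whatever the state manager actually does when it applies and reverts the EMA state.

```python
import torch
from torch import nn


class ToyRope(nn.Module):
    """Hypothetical rotary-embedding stand-in with a non-persistent buffer."""

    def __init__(self, dim: int, max_period: float = 10000.0):
        super().__init__()
        adim = torch.arange(0, dim, 2).float()
        # Non-persistent, as the issue proposes: never part of state_dict().
        self.register_buffer(
            "frequencies", 1.0 / (max_period ** (adim / dim)), persistent=False
        )


class ToyAttention(nn.Module):
    def __init__(self, rope: ToyRope):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.rope = rope


class ToyTransformer(nn.Module):
    def __init__(self, num_layers: int = 3):
        super().__init__()
        self.rope = ToyRope(dim=8)
        self.layers = nn.ModuleList(ToyAttention(self.rope) for _ in range(num_layers))


class StateDictEMA:
    """Hypothetical EMA whose keys are exactly those of model.state_dict()."""

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        for key, value in model.state_dict().items():
            if value.dtype.is_floating_point:
                # shadow <- decay * shadow + (1 - decay) * value
                self.shadow[key].mul_(self.decay).add_(value, alpha=1.0 - self.decay)
            else:
                self.shadow[key].copy_(value)  # e.g. integer counters are copied

    def state_dict(self) -> dict:
        return {k: v.clone() for k, v in self.shadow.items()}


model = ToyTransformer()
ema = StateDictEMA(model, decay=0.9)
with torch.no_grad():
    model.layers[0].proj.weight.add_(1.0)  # stand-in for a training step
ema.update(model)

# Apply the EMA weights for validation, then revert to the training weights.
backup = {k: v.detach().clone() for k, v in model.state_dict().items()}
model.load_state_dict(ema.state_dict())  # strict=True: the keys match exactly
# ... validation would run here ...
model.load_state_dict(backup)
```

Because the EMA keys come from `state_dict()`, a strict `load_state_dict` accepts them, and with a non-persistent buffer the deterministic RoPE tensor is neither averaged nor stored.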

I hope this description helps in understanding and addressing these issues. Thank you for your attention, and I look forward to any suggestions, fixes, or insights you can provide.

Best regards