lucidrains / musiclm-pytorch

Implementation of MusicLM, Google's new SOTA model for music generation using attention networks, in PyTorch
MIT License
3.15k stars 254 forks source link

beartype type-checking causes this error when importing #67

Open AnthonyYeh opened 7 months ago

AnthonyYeh commented 7 months ago

Code:

from musiclm_pytorch import MuLaN, MuLaNTrainer, AudioSpectrogramTransformer, TextTransformer, MuLaNEmbedQuantizer, MusicLM
from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM

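Error (traceback below). The import fails because beartype rejects the default value `None` of the parameter `data_max_length: int` in `audiolm_pytorch.trainer.SoundStreamTrainer.__init__`; `None` is not an instance of `int`: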

2024-04-03 02:05:51 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX
---------------------------------------------------------------------------
BeartypeDecorHintParamDefaultViolation    Traceback (most recent call last)
Cell In[43], line 1
----> 1 from musiclm_pytorch import MuLaN, MuLaNTrainer, AudioSpectrogramTransformer, TextTransformer, MuLaNEmbedQuantizer, MusicLM
      2 from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM

File /usr/local/lib/python3.10/dist-packages/musiclm_pytorch/__init__.py:1
----> 1 from musiclm_pytorch.musiclm_pytorch import (
      2     MuLaN,
      3     MuLaNEmbedQuantizer,
      4     MusicLM,
      5     AudioSpectrogramTransformer,
      6     TextTransformer,
      7     SigmoidContrastiveLearning,
      8     SoftmaxContrastiveLearning
      9 )
     11 from musiclm_pytorch.trainer import MuLaNTrainer

File /usr/local/lib/python3.10/dist-packages/musiclm_pytorch/musiclm_pytorch.py:10
      6 from torch import nn, einsum
      8 from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking
---> 10 from audiolm_pytorch import AudioLM
     11 from audiolm_pytorch.utils import AudioConditionerBase
     13 import torch.distributed as dist

File /usr/local/lib/python3.10/dist-packages/audiolm_pytorch/__init__.py:18
     15 from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
     16 from audiolm_pytorch.hubert_kmeans import HubertWithKmeans
---> 18 from audiolm_pytorch.trainer import SoundStreamTrainer, SemanticTransformerTrainer, FineTransformerTrainer, CoarseTransformerTrainer
     20 from audiolm_pytorch.audiolm_pytorch import get_embeds

File /usr/local/lib/python3.10/dist-packages/audiolm_pytorch/trainer.py:210
    206                 self.scheduler.step()
    208 # main trainer class
--> 210 class SoundStreamTrainer(nn.Module):
    211     @beartype
    212     def __init__(
    213         self,
   (...)
    254         force_clear_prev_results: bool = None  # set to True | False to skip the prompt
    255     ):
    256         """
    257         Initialize with a SoundStream instance and either a folder containing audio data or
    258         train/val DataLoader instances.
    259         """

File /usr/local/lib/python3.10/dist-packages/audiolm_pytorch/trainer.py:212, in SoundStreamTrainer()
    210 class SoundStreamTrainer(nn.Module):
    211     @beartype
--> 212     def __init__(
    213         self,
    214         soundstream: SoundStream,
    215         *,
    216         num_train_steps: int,
    217         batch_size: int,
    218         data_max_length: int = None,
    219         data_max_length_seconds: Union[int, float] = None,
    220         folder: str = None,
    221         dataset: Optional[Dataset] = None,
    222         val_dataset: Optional[Dataset] = None,
    223         train_dataloader: Optional[DataLoader] = None,
    224         val_dataloader: Optional[DataLoader] = None,
    225         lr: float = 2e-4,
    226         grad_accum_every: int = 4,
    227         wd: float = 0.,
    228         warmup_steps: int = 1000,
    229         scheduler: Optional[Type[_LRScheduler]] = None,
    230         scheduler_kwargs: dict = dict(),
    231         discr_warmup_steps: Optional[int] = None,
    232         discr_scheduler: Optional[Type[_LRScheduler]] = None,
    233         discr_scheduler_kwargs: dict = dict(),
    234         max_grad_norm: float = 0.5,
    235         discr_max_grad_norm: float = None,
    236         save_results_every: int = 100,
    237         save_model_every: int = 1000,
    238         log_losses_every: int = 1,
    239         results_folder: str = './results',
    240         valid_frac: float = 0.05,
    241         random_split_seed: int = 42,
    242         use_ema: bool = True,
    243         ema_beta: float = 0.995,
    244         ema_update_after_step: int = 500,
    245         ema_update_every: int = 10,
    246         apply_grad_penalty_every: int = 4,
    247         dl_num_workers: int = 0,
    248         accelerator: Optional[Accelerator] = None,
    249         accelerate_kwargs: dict = dict(),
    250         init_process_group_timeout_seconds = 1800,
    251         dataloader_drop_last = True,
    252         split_batches = False,
    253         use_wandb_tracking = False,
    254         force_clear_prev_results: bool = None  # set to True | False to skip the prompt
    255     ):
    256         """
    257         Initialize with a SoundStream instance and either a folder containing audio data or
    258         train/val DataLoader instances.
    259         """
    260         super().__init__()

File /usr/local/lib/python3.10/dist-packages/beartype/_decor/decorcache.py:77, in beartype(obj, conf)
     62 # Else, "conf" is a configuration.
     63 #
     64 # If passed an object to be decorated, this decorator is in decoration
   (...)
     74 # here; callers that are doing this are sufficiently intelligent to be
     75 # trusted to violate PEP 561-compliance if they so choose. So... *shrug*
     76 elif obj is not None:
---> 77     return beartype_object(obj, conf)
     78 # Else, we were passed *NO* object to be decorated. In this case, this
     79 # decorator is in configuration rather than decoration mode.
     80 
   (...)
     86 # "None" otherwise (i.e., if this is the first call to this public
     87 # decorator passed this configuration in configuration mode). Phew!
     88 beartype_confed_cached = _bear_conf_to_decor.get(conf)

File /usr/local/lib/python3.10/dist-packages/beartype/_decor/decorcore.py:87, in beartype_object(obj, conf, **kwargs)
     51 '''
     52 Decorate the passed **beartypeable** (i.e., caller-defined object that may
     53 be decorated by the :func:`beartype.beartype` decorator) with optimal
   (...)
     81     Memoized parent decorator wrapping this unmemoized child decorator.
     82 '''
     83 # print(f'Decorating object {repr(obj)}...')
     84 
     85 # Return either...
     86 return (
---> 87     _beartype_object_fatal(obj, conf=conf, **kwargs)
     88     # If this beartype configuration requests that this decorator raise
     89     # fatal exceptions at decoration time, defer to the lower-level
     90     # decorator doing so;
     91     if conf.warning_cls_on_decorator_exception is None else
     92     # Else, this beartype configuration requests that this decorator emit
     93     # fatal warnings at decoration time. In this case, defer to the
     94     # lower-level decorator doing so.
     95     _beartype_object_nonfatal(obj, conf=conf, **kwargs)
     96 )

File /usr/local/lib/python3.10/dist-packages/beartype/_decor/decorcore.py:136, in _beartype_object_fatal(obj, **kwargs)
    100 '''
    101 Decorate the passed **beartypeable** (i.e., caller-defined object that may
    102 be decorated by the :func:`beartype.beartype` decorator) with optimal
   (...)
    126     Memoized parent decorator wrapping this unmemoized child decorator.
    127 '''
    129 # Return either...
    130 return (
    131     # If this object is a class, this class decorated with type-checking.
    132     beartype_type(obj, **kwargs)  # type: ignore[return-value]
    133     if isinstance(obj, type) else
    134     # Else, this object is a non-class. In this case, this non-class
    135     # decorated with type-checking.
--> 136     beartype_nontype(obj, **kwargs)  # type: ignore[return-value]
    137 )

File /usr/local/lib/python3.10/dist-packages/beartype/_decor/_decornontype.py:174, in beartype_nontype(obj, **kwargs)
    169     return beartype_func_contextlib_contextmanager(obj, **kwargs)  # type: ignore[return-value]
    170 # Else, this function is *NOT* a @contextlib.contextmanager-based isomorphic
    171 # decorator closure.
    172 
    173 # Return a new callable decorating that callable with type-checking.
--> 174 return beartype_func(obj, **kwargs)

File /usr/local/lib/python3.10/dist-packages/beartype/_decor/_decornontype.py:239, in beartype_func(func, conf, **kwargs)
    236 bear_call = make_beartype_call(func, conf, **kwargs)  # pyright: ignore[reportGeneralTypeIssues]
    238 # Generate the raw string of Python statements implementing this wrapper.
--> 239 func_wrapper_code = generate_code(bear_call)
    241 # If that callable requires *NO* type-checking, silently reduce to a noop
    242 # and thus the identity decorator by returning that callable as is.
    243 if not func_wrapper_code:

File /usr/local/lib/python3.10/dist-packages/beartype/_decor/wrap/wrapmain.py:118, in generate_code(bear_call)
     38 '''
     39 Generate a Python code snippet dynamically defining the wrapper function
     40 type-checking the passed decorated callable.
   (...)
    112     happen, a private non-human-readable exception is raised in this case.
    113 '''
    115 # Python code snippet type-checking all callable parameters if one or more
    116 # such parameters are annotated with unignorable type hints *OR* the empty
    117 # string otherwise.
--> 118 code_check_params = _code_check_args(bear_call)
    120 # Python code snippet type-checking the callable return if this return is
    121 # annotated with an unignorable type hint *OR* the empty string otherwise.
    122 code_check_return = _code_check_return(bear_call)

File /usr/local/lib/python3.10/dist-packages/beartype/_decor/wrap/_wrapargs.py:309, in code_check_args(bear_call)
    303         # Else, *NO* warnings were issued.
    304     # If any exception was raised, reraise this exception with each
    305     # placeholder substring (i.e., "EXCEPTION_PLACEHOLDER" instance)
    306     # replaced by a human-readable description of this callable and
    307     # annotated parameter.
    308     except Exception as exception:
--> 309         reraise_exception_placeholder(
    310             exception=exception,
    311             #FIXME: Embed the kind of parameter both here and above as well
    312             #(e.g., "positional-only", "keyword-only", "variadic
    313             #positional"), ideally by improving the existing
    314             #prefix_callable_arg_name() function to introspect this kind from
    315             #the callable code object.
    316             target_str=prefix_callable_arg_name(
    317                 func=bear_call.func_wrappee,
    318                 arg_name=arg_name,
    319                 is_color=bear_call.conf.is_color,
    320             ),
    321         )
    323 # If this callable accepts one or more positional type-checked parameters,
    324 # prefix this code by a snippet localizing the number of these parameters.
    325 if is_args_positional:

File /usr/local/lib/python3.10/dist-packages/beartype/_util/error/utilerrraise.py:138, in reraise_exception_placeholder(exception, target_str, source_str)
    132         exception.args = (exception_message,) + exception.args[1:]
    133     # Else, this message remains preserved as is.
    134 # Else, this is an unconventional exception. In this case, preserve this
    135 # exception as is.
    136 
    137 # Re-raise this exception while preserving its original traceback.
--> 138 raise exception.with_traceback(exception.__traceback__)

File /usr/local/lib/python3.10/dist-packages/beartype/_decor/wrap/_wrapargs.py:205, in code_check_args(bear_call)
    200     continue
    201 # Else, this hint is unignorable.
    202 
    203 # If this parameter is optional *AND* the default value of this
    204 # optional parameter violates this hint, raise an exception.
--> 205 _die_if_arg_default_unbearable(
    206     bear_call=bear_call, arg_default=arg_default, hint=hint)
    207 # Else, this parameter is either optional *OR* the default value
    208 # of this optional parameter satisfies this hint.
    209 
   (...)
    216 # beartype would fail to reduce to a noop for otherwise
    217 # ignorable callables -- which would be rather bad, really.
    218 if arg_kind in _ARG_KINDS_POSITIONAL:

File /usr/local/lib/python3.10/dist-packages/beartype/_decor/wrap/_wrapargs.py:473, in _die_if_arg_default_unbearable(bear_call, arg_default, hint)
    470 conf = BeartypeConf(**conf_kwargs)
    472 # Raise this type of violation exception.
--> 473 die_if_unbearable(
    474     obj=arg_default,
    475     hint=hint,
    476     conf=conf,
    477     exception_prefix=EXCEPTION_PREFIX_DEFAULT,
    478 )

File /usr/local/lib/python3.10/dist-packages/beartype/door/_doorcheck.py:106, in die_if_unbearable(obj, hint, conf, exception_prefix)
    102 func_raiser = make_func_raiser(hint, conf, exception_prefix)
    104 # Either raise an exception or emit a warning only if the passed object
    105 # violates this hint.
--> 106 func_raiser(obj)

File <@beartype(__beartype_checker_38) at 0x56533f6f2730>:19, in __beartype_checker_38(__beartype_pith_0, __beartype_exception_prefix, __beartype_get_violation, __beartype_hint, __beartype_conf)

BeartypeDecorHintParamDefaultViolation: Method audiolm_pytorch.trainer.SoundStreamTrainer.__init__() parameter "data_max_length" default value "None" violates type hint <class 'int'>, as <class "builtins.NoneType"> "None" not instance of int.