facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Failed to convert the data2vec2 model with torch.jit.script #4967

Closed: oewi closed this 1 year ago

oewi commented 1 year ago

❓ Questions and Help

Before asking:

  1. search the issues.
  2. search the docs.

What is your question?

I tried to convert the data2vec2 model with torch.jit.script but got the following error:

  File "/opt/conda/lib/python3.8/site-packages/torch/_jit_internal.py", line 375, in get_type_hint_captures
    src = inspect.getsource(fn)
  File "/opt/conda/lib/python3.8/inspect.py", line 1012, in getsource
    lines, lnum = getsourcelines(object)
  File "/opt/conda/lib/python3.8/inspect.py", line 994, in getsourcelines
    lines, lnum = findsource(object)
  File "/opt/conda/lib/python3.8/inspect.py", line 813, in findsource
    raise OSError('could not get source code')

The error is raised because inspect fails to get the source file for the object (<function __create_fn__.<locals>.__init__ at 0x7fda58ad08b0>), which belongs to the module (<module 'examples.data2vec.models.modalities.audio' from '/fairseq/examples/data2vec/models/modalities/audio.py'>).
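A minimal sketch of why inspect.getsource cannot find this __init__, assuming standard CPython dataclasses behavior: the @dataclass decorator builds __init__ from a generated string with exec (the __create_fn__ wrapper visible in the function's repr), so the resulting code object records a pseudo-filename instead of audio.py and there are no source lines for inspect to read. ToyModalityConfig and try_get_source are hypothetical names for illustration, not fairseq code.

```python
import inspect
from dataclasses import dataclass


@dataclass
class ToyModalityConfig:
    # Hypothetical stand-in for a fairseq modality config dataclass.
    prenet_depth: int = 4
    mask_prob: float = 0.45


def try_get_source(fn):
    """Return fn's source text, or None if inspect cannot locate it."""
    try:
        return inspect.getsource(fn)
    except (OSError, TypeError):
        # OSError matches the traceback above: "could not get source code".
        return None


generated_init = ToyModalityConfig.__init__
# The generated method was compiled from a string, not from this file.
print(generated_init.__code__.co_filename)
# No file backs that code object, so the source lookup fails.
print(try_get_source(generated_init))
```

The same lookup on an ordinary function defined in a .py file would return its source; only the exec-generated dataclass methods have nothing on disk to point at.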

So I supplied the correct path directly, but got another error:

  File "/opt/conda/lib/python3.8/site-packages/torch/_jit_internal.py", line 455, in createResolutionCallbackForClassMethods
    captures.update(get_type_hint_captures(fn))
  File "/opt/conda/lib/python3.8/site-packages/torch/_jit_internal.py", line 396, in get_type_hint_captures
    raise RuntimeError(f"Expected {fn} to be a function")
RuntimeError: Expected <function __create_fn__.<locals>.__init__ at 0x7fea5ae888b0> to be a function

During handling of the above exception, another exception occurred:
...
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 230, in infer_concrete_type_builder
    sub_concrete_type = get_module_concrete_type(item, share_types)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 424, in get_module_concrete_type
    concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 365, in get_or_create_concrete_type
    concrete_type_builder = infer_concrete_type_builder(nn_module)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 329, in infer_concrete_type_builder
    attr_type, inferred = infer_type(name, value)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py", line 183, in infer_type
    raise RuntimeError(
RuntimeError: Error inferring type for modality_cfg: {'type': <Modality.AUDIO: 1>, 'prenet_depth': 4, 'prenet_layerdrop': 0.1, 'prenet_dropout': 0.0, 'start_drop_path_rate': 0.0, 'end_drop_path_rate': 0.0, 'num_extra_tokens': 0, 'init_extra_token_zero': True, 'mask_noise_std': 0.01, 'mask_prob_min': None, 'mask_prob': 0.45, 'inverse_mask': False, 'mask_prob_adjust': 0.05, 'keep_masked_pct': 0.0, 'mask_length': 10, 'add_masks': False, 'remove_masks': False, 'mask_dropout': 0.0, 'encoder_zero_mask': False, 'mask_channel_prob': 0.1, 'mask_channel_length': 64, 'ema_local_encoder': False, 'local_grad_mult': 0.0, 'use_alibi_encoder': True, 'alibi_scale': 1.0, 'learned_alibi': False, 'alibi_max_pos': None, 'learned_alibi_scale': True, 'learned_alibi_scale_per_head': True, 'learned_alibi_scale_per_layer': False, 'num_alibi_heads': 12, 'model_depth': 8, 'decoder': {'decoder_dim': 384, 'decoder_groups': 16, 'decoder_kernel': 7, 'decoder_layers': 4, 'input_dropout': 0.1, 'add_positions_masked': False, 'add_positions_all': False, 'decoder_residual': True, 'projection_layers': 1, 'projection_ratio': 2.0}, 'extractor_mode': 'layer_norm', 'feature_encoder_spec': '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]', 'conv_pos_width': 95, 'conv_pos_groups': 16, 'conv_pos_depth': 5, 'conv_pos_pre_ln': False}: Expected <function __create_fn__.<locals>.__init__ at 0x7fea5ae888b0> to be a function
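The second error appears to come from a sanity check after the source lookup: presumably TorchScript parses the recovered source and expects exactly one top-level function definition, and raises "Expected ... to be a function" otherwise. This is a reading of the error message, not a quote of torch's code. If the substituted path made inspect return something other than the single generated method (for example the whole audio.py file), a check like the hypothetical is_single_function_source below would fail in the same way. Only the standard-library ast and textwrap modules are used, and both source strings are made up.

```python
import ast
import textwrap


def is_single_function_source(src: str) -> bool:
    """Return True if src parses to exactly one top-level `def`.

    Hypothetical approximation of the check that seems to produce
    "Expected <fn> to be a function"; not torch's actual implementation.
    """
    tree = ast.parse(textwrap.dedent(src))
    return len(tree.body) == 1 and isinstance(tree.body[0], ast.FunctionDef)


# What a method's own source looks like once dedented: one def.
one_method = '''
    def __init__(self, prenet_depth=4):
        self.prenet_depth = prenet_depth
'''

# What you get if the "source" is a whole module: several top-level statements.
whole_file = '''
import dataclasses

class AudioEncoder:
    def __init__(self, cfg):
        self.cfg = cfg
'''

print(is_single_function_source(one_method))  # True
print(is_single_function_source(whole_file))  # False
```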

Any idea? Thank you for the help.

What's your environment?

oewi commented 1 year ago

Solved.

ArtyZe commented 1 year ago

> the error is raised because it fails to get source file from the object(<function __create_fn__.<locals>.__init__ at 0x7fda58ad08b0>), which is the instance of the module(<module 'examples.data2vec.models.modalities.audio' from '/fairseq/examples/data2vec/models/modalities/audio.py'>)

Hello, how could you be sure that fn comes from examples.data2vec.models.modalities.audio? And where did you add the path? Thanks.
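A sketch of one way to tell which module a generated method belongs to, assuming it was created by @dataclass in the usual CPython way: the generated function is normally built with the defining module's globals, so fn.__module__ names that module even though the code has no file of its own, and sys.modules maps that name to the imported module and its __file__. Whether this is how oewi identified audio.py, and where the path was then supplied, is not stated in the thread. ToyAudioConfig and owner_of are hypothetical names.

```python
import sys
from dataclasses import dataclass


@dataclass
class ToyAudioConfig:
    # Hypothetical stand-in for the config dataclass stored in `modality_cfg`.
    prenet_depth: int = 4
    mask_prob: float = 0.45


def owner_of(fn):
    """Return (module name, module file or None) for a function object."""
    name = fn.__module__               # module whose globals fn was created with
    module = sys.modules.get(name)     # the imported module object, if any
    return name, getattr(module, "__file__", None)


name, path = owner_of(ToyAudioConfig.__init__)
print(name, path)
```

For the object in this issue, the module name would be examples.data2vec.models.modalities.audio, matching the module shown in the first comment.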