facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

A layer_idx problem in wav2vec2 conformer, need help!!! #5386

Open jasonking1234 opened 9 months ago

jasonking1234 commented 9 months ago

My fairseq version is 0.12.2. Pretraining the base model with fairseq-hydra-train works fine, but when I run fairseq-hydra-train for the base model with a conformer backbone it fails like this:

fairseq-hydra-train task.data=/path/to/LibriSpeech/ distributed_training.distributed_world_size=1 --config-dir /root/fairseq/examples/wav2vec/config/pretraining/ --config-name wav2vec2_conformer_base_librispeech

[2023-11-23 19:57:11,628] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
2023-11-23 19:57:13 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX
/root/fairseq/fairseq_cli/hydra_train.py:25: UserWarning: The version_base parameter is not specified. Please specify a compatability version level, or None. Will assume defaults for version 1.1
  @hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
/root/miniconda/envs/asr/lib/python3.8/site-packages/hydra/core/default_element.py:124: UserWarning: In 'wav2vec2_conformer_base_librispeech': Usage of deprecated keyword in package header '# @package group'. See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_package_header for more information
  deprecation_warning(
sys:1: UserWarning: 'wav2vec2_conformer_base_librispeech' is validated against ConfigStore schema with the same name. This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2. See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/automatic_schema_matching for migration instructions.
/root/miniconda/envs/asr/lib/python3.8/site-packages/hydra/main.py:94: UserWarning: 'wav2vec2_conformer_base_librispeech' is validated against ConfigStore schema with the same name. This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2. See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/automatic_schema_matching for migration instructions.
  _run_hydra(
/root/miniconda/envs/asr/lib/python3.8/site-packages/hydra/_internal/hydra.py:119: UserWarning: Future Hydra versions will no longer change working directory at job runtime by default. See https://hydra.cc/docs/1.2/upgrades/1.1_to_1.2/changes_to_job_working_dir/ for more information.
  ret = run_job(
[2023-11-23 19:57:14,500][fairseq_cli.train][INFO] - {'_name': None, 'common': {'_name': None, 'no_progress_bar': False, 'log_interval': 200, 'log_format': 'json', 'log_file': None, 'aim_repo': None, 'aim_run_hash': None, 'tensorboard_logdir': None, 'wandb_project': None, 'azureml_logging': False, 'seed': 1, 'cpu': False, 'tpu': False, 'bf16': False, 'memory_efficient_bf16': False, 'fp16': True, 'memory_efficient_fp16': False, 'fp16_no_flatten_grads': False, 'fp16_init_scale': 128, 'fp16_scale_window': None, 'fp16_scale_tolerance': 0.0, 'on_cpu_convert_precision': False, 'min_loss_scale': 0.0001, 'threshold_loss_scale': None, 'amp': False, 'amp_batch_retries': 2, 'amp_init_scale': 128, 'amp_scale_window': None, 'user_dir': None, 'empty_cache_freq': 0, 'all_gather_list_size': 16384, 'model_parallel_size': 1, 'quantization_config_path': None, 'profile': False, 'reset_logging': False, 'suppress_crashes': False, 'use_plasma_view': False, 'plasma_path': '/tmp/plasma'}, 'common_eval': {'_name': None, 'path': None, 'post_process': None, 'quiet': False, 'model_overrides': '{}', 'results_path': None}, 'distributed_training': {'_name': None, 'distributed_world_size': 1, 'distributed_num_procs': 1, 'distributed_rank': 0, 'distributed_backend': 'nccl', 'distributed_init_method': None, 'distributed_port': -1, 'device_id': 0, 'distributed_no_spawn': False, 'ddp_backend': 'legacy_ddp', 'ddp_comm_hook': 'none', 'bucket_cap_mb': 25, 'fix_batches_to_gpus': False, 'find_unused_parameters': False, 'gradient_as_bucket_view': False, 'fast_stat_sync': False, 'heartbeat_timeout': -1, 'broadcast_buffers': False, 'slowmo_momentum': None, 'slowmo_base_algorithm': 'localsgd', 'localsgd_frequency': 3, 'nprocs_per_node': 1, 'pipeline_model_parallel': False, 'pipeline_balance': None, 'pipeline_devices': None, 'pipeline_chunks': 0, 'pipeline_encoder_balance': None, 'pipeline_encoder_devices': None, 'pipeline_decoder_balance': None, 'pipeline_decoder_devices': None, 'pipeline_checkpoint': 'never', 'zero_sharding': 'none', 'fp16': True, 'memory_efficient_fp16': False, 'tpu': False, 'no_reshard_after_forward': False, 'fp32_reduce_scatter': False, 'cpu_offload': False, 'use_sharded_state': False, 'not_fsdp_flatten_parameters': False}, 'dataset': {'_name': None, 'num_workers': 6, 'skip_invalid_size_inputs_valid_test': True, 'max_tokens': 1400000, 'batch_size': None, 'required_batch_size_multiple': 8, 'required_seq_len_multiple': 1, 'dataset_impl': None, 'data_buffer_size': 10, 'train_subset': 'train', 'valid_subset': 'valid', 'combine_valid_subsets': None, 'ignore_unused_valid_subsets': False, 'validate_interval': 1, 'validate_interval_updates': 0, 'validate_after_updates': 0, 'fixed_validation_seed': None, 'disable_validation': False, 'max_tokens_valid': 1400000, 'batch_size_valid': None, 'max_valid_steps': None, 'curriculum': 0, 'gen_subset': 'test', 'num_shards': 1, 'shard_id': 0, 'grouped_shuffling': False, 'update_epoch_batch_itr': False, 'update_ordered_indices_seed': False}, 'optimization': {'_name': None, 'max_epoch': 0, 'max_update': 400000, 'stop_time_hours': 0.0, 'clip_norm': 0.0, 'sentence_avg': False, 'update_freq': [1], 'lr': [0.0005], 'stop_min_lr': -1.0, 'use_bmuf': False, 'skip_remainder_batch': False, 'debug_param_names': False}, 'checkpoint': {'_name': None, 'save_dir': 'checkpoints', 'restore_file': 'checkpoint_last.pt', 'continue_once': None, 'finetune_from_model': None, 'reset_dataloader': False, 'reset_lr_scheduler': False, 'reset_meters': False, 'reset_optimizer': False,
'optimizer_overrides': '{}', 'save_interval': 1, 'save_interval_updates': 25000, 'keep_interval_updates': 1, 'keep_interval_updates_pattern': -1, 'keep_last_epochs': -1, 'keep_best_checkpoints': -1, 'no_save': False, 'no_epoch_checkpoints': True, 'no_last_checkpoints': False, 'no_save_optimizer_state': False, 'best_checkpoint_metric': 'loss', 'maximize_best_checkpoint_metric': False, 'patience': -1, 'checkpoint_suffix': '', 'checkpoint_shard_count': 1, 'load_checkpoint_on_all_dp_ranks': False, 'write_checkpoints_asynchronously': False, 'model_parallel_size': 1}, 'bmuf': {'_name': None, 'block_lr': 1.0, 'block_momentum': 0.875, 'global_sync_iter': 50, 'warmup_iterations': 500, 'use_nbm': False, 'average_sync': False, 'distributed_world_size': 1}, 'generation': {'_name': None, 'beam': 5, 'beam_mt': 0, 'nbest': 1, 'max_len_a': 0.0, 'max_len_b': 200, 'max_len_a_mt': 0.0, 'max_len_b_mt': 200, 'min_len': 1, 'match_source_len': False, 'unnormalized': False, 'no_early_stop': False, 'no_beamable_mm': False, 'lenpen': 1.0, 'lenpen_mt': 1.0, 'unkpen': 0.0, 'replace_unk': None, 'sacrebleu': False, 'score_reference': False, 'prefix_size': 0, 'no_repeat_ngram_size': 0, 'sampling': False, 'sampling_topk': -1, 'sampling_topp': -1.0, 'constraints': None, 'temperature': 1.0, 'diverse_beam_groups': -1, 'diverse_beam_strength': 0.5, 'diversity_rate': -1.0, 'print_alignment': None, 'print_step': False, 'lm_path': None, 'lm_weight': 0.0, 'iter_decode_eos_penalty': 0.0, 'iter_decode_max_iter': 10, 'iter_decode_force_max_iter': False, 'iter_decode_with_beam': 1, 'iter_decode_with_external_reranker': False, 'retain_iter_history': False, 'retain_dropout': False, 'retain_dropout_modules': None, 'decoding_format': None, 'no_seed_provided': False, 'eos_token': None}, 'eval_lm': {'_name': None, 'output_word_probs': False, 'output_word_stats': False, 'context_window': 0, 'softmax_batch': 9223372036854775807}, 'interactive': {'_name': None, 'buffer_size': 0, 'input': '-'}, 'model': {'_name': 'wav2vec2', 'extractor_mode': default, 'encoder_layers': 12, 'encoder_embed_dim': 768, 'encoder_ffn_embed_dim': 3072, 'encoder_attention_heads': 12, 'activation_fn': gelu, 'layer_type': conformer, 'dropout': 0.1, 'attention_dropout': 0.1, 'activation_dropout': 0.0, 'encoder_layerdrop': 0.05, 'dropout_input': 0.1, 'dropout_features': 0.1, 'final_dim': 256, 'layer_norm_first': False, 'conv_feature_layers': '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]', 'conv_bias': False, 'logit_temp': 0.1, 'quantize_targets': True, 'quantize_input': False, 'same_quantizer': False, 'target_glu': False, 'feature_grad_mult': 0.1, 'quantizer_depth': 1, 'quantizer_factor': 3, 'latent_vars': 320, 'latent_groups': 2, 'latent_dim': 0, 'mask_length': 10, 'mask_prob': 0.65, 'mask_selection': static, 'mask_other': 0.0, 'no_mask_overlap': False, 'mask_min_space': 1, 'require_same_masks': True, 'mask_dropout': 0.0, 'mask_channel_length': 10, 'mask_channel_prob': 0.0, 'mask_channel_before': False, 'mask_channel_selection': static, 'mask_channel_other': 0.0, 'no_mask_channel_overlap': False, 'mask_channel_min_space': 1, 'num_negatives': 100, 'negatives_from_everywhere': False, 'cross_sample_negatives': 0, 'codebook_negatives': 0, 'conv_pos': 128, 'conv_pos_groups': 16, 'pos_conv_depth': 1, 'latent_temp': [2.0, 0.5, 0.999995], 'max_positions': 100000, 'checkpoint_activations': False, 'required_seq_len_multiple': 2, 'crop_seq_to_multiple': 1, 'depthwise_conv_kernel_size': 31, 'attn_type': 'espnet', 'pos_enc_type': 'rel_pos', 'fp16': False, 
'adp_num': -1, 'adp_dim': 64, 'adp_act_fn': 'relu', 'adp_trf_idx': 'all'}, 'task': {'_name': 'audio_pretraining', 'data': '/path/to/LibriSpeech/', 'labels': None, 'multi_corpus_keys': None, 'multi_corpus_sampling_weights': None, 'binarized_dataset': False, 'sample_rate': 16000, 'normalize': False, 'enable_padding': False, 'max_sample_size': 250000, 'min_sample_size': 32000, 'num_batch_buckets': 0, 'tpu': False, 'text_compression_level': none, 'rebuild_batches': True, 'precompute_mask_config': None, 'post_save_script': None, 'subsample': 1.0, 'seed': 1}, 'criterion': {'_name': 'wav2vec', 'infonce': True, 'loss_weights': [0.1, 10.0], 'log_keys': ['prob_perplexity', 'code_perplexity', 'temp']}, 'optimizer': {'_name': 'adam', 'adam_betas': '(0.9,0.98)', 'adam_eps': 1e-06, 'weight_decay': 0.01, 'use_old_adam': False, 'fp16_adam_stats': False, 'tpu': False, 'lr': [0.0005]}, 'lr_scheduler': {'_name': 'polynomial_decay', 'warmup_updates': 32000, 'force_anneal': None, 'end_learning_rate': 0.0, 'power': 1.0, 'total_num_update': 400000.0, 'lr': [0.0005]}, 'scoring': None, 'bpe': None, 'tokenizer': None, 'ema': {'_name': None, 'store_ema': False, 'ema_decay': 0.9999, 'ema_start_update': 0, 'ema_seed_model': None, 'ema_update_freq': 1, 'ema_fp32': False}, 'job_logging_cfg': {'version': 1, 'formatters': {'simple': {'format': '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'}}, 'handlers': {'console': {'class': 'logging.StreamHandler', 'formatter': 'simple', 'stream': 'ext://sys.stdout'}, 'file': {'class': 'logging.FileHandler', 'formatter': 'simple', 'filename': '/root/fairseq/examples/wav2vec/outputs/2023-11-23/19-57-14/hydra_train.log'}}, 'root': {'level': 'INFO', 'handlers': ['console', 'file']}, 'disable_existing_loggers': False}}

Error executing job with overrides: ['task.data=/path/to/LibriSpeech/', 'distributed_training.distributed_world_size=1']
Traceback (most recent call last):
  File "/root/fairseq/fairseq_cli/hydra_train.py", line 27, in hydra_main
    _hydra_main(cfg)
  File "/root/fairseq/fairseq_cli/hydra_train.py", line 56, in _hydra_main
    distributed_utils.call_main(cfg, pre_main, **kwargs)
  File "/root/fairseq/fairseq/distributed/utils.py", line 404, in call_main
    main(cfg, **kwargs)
  File "/root/fairseq/fairseq_cli/train.py", line 96, in main
    model = task.build_model(cfg.model)
  File "/root/fairseq/fairseq/tasks/audio_pretraining.py", line 224, in build_model
    model = super().build_model(model_cfg, from_checkpoint)
  File "/root/fairseq/fairseq/tasks/fairseq_task.py", line 355, in build_model
    model = models.build_model(cfg, self, from_checkpoint)
  File "/root/fairseq/fairseq/models/__init__.py", line 106, in build_model
    return model.build_model(cfg, task)
  File "/root/fairseq/fairseq/models/wav2vec/wav2vec2.py", line 427, in build_model
    return cls(cfg)
  File "/root/fairseq/fairseq/models/wav2vec/wav2vec2.py", line 407, in __init__
    self.encoder = encoder_cls(cfg)
  File "/root/fairseq/fairseq/models/wav2vec/wav2vec2.py", line 1193, in __init__
    super().__init__(args)
  File "/root/fairseq/fairseq/models/wav2vec/wav2vec2.py", line 1060, in __init__
    [self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)]
  File "/root/fairseq/fairseq/models/wav2vec/wav2vec2.py", line 1060, in <listcomp>
    [self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)]
TypeError: build_encoder_layer() got an unexpected keyword argument 'layer_idx'

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

Saravinnai commented 8 months ago

@jasonking1234 Have you fixed this issue? Could you please help?

Lauler commented 8 months ago

You need to specify model.layer_type, model.attn_type, and model.pos_enc_type in the launch command. Otherwise fairseq doesn't understand that you are trying to train a conformer and builds regular TransformerEncoder layers instead of a ConformerEncoder.

The instructions in the wav2vec2 README are wrong or outdated: you cannot prepend -- to the overrides the way the README does. This is the command I use to launch Conformer training (note how layer_type, attn_type, and pos_enc_type are specified):

torchrun \
    --nproc_per_node=$NPROC_PER_NODE \
    --nnodes=$SLURM_JOB_NUM_NODES \
    --node_rank=$SLURM_NODEID \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    $(/usr/bin/which fairseq-hydra-train) \
    task.data=/leonardo_work/EUHPC_D01_040/data/p4/manifest/ \
    checkpoint.save_dir=$CHECKPOINT_DIR \
    distributed_training.distributed_world_size=$WORLD_SIZE \
    model.layer_type=conformer \
    model.attn_type=espnet \
    model.pos_enc_type=rel_pos \
    --config-dir $CONFIG_DIR \
    --config-name wav2vec2_large_conformer

If you get an error that ConformerEncoder doesn't recognize certain args you are passing to it, you also need to add an extra **kwargs argument on this line:

https://github.com/facebookresearch/fairseq/blob/fad2c4d1ebe14d974876de52dcb06db6d99b0b4a/fairseq/models/wav2vec/wav2vec2.py#L1175C5-L1175C41
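
For context, here is a minimal self-contained sketch of the failure mode (the class names below are hypothetical; only the layer_idx call mirrors the traceback above):

from types import SimpleNamespace

class BaseEncoder:  # stands in for fairseq's TransformerEncoder
    def __init__(self, args):
        # The parent builds each layer with an explicit layer_idx keyword,
        # matching the call shown in the traceback (wav2vec2.py line 1060).
        self.layers = [
            self.build_encoder_layer(args, layer_idx=ii)
            for ii in range(args.encoder_layers)
        ]

    def build_encoder_layer(self, args, **kwargs):
        return ("transformer", kwargs.get("layer_idx"))

class BrokenConformerEncoder(BaseEncoder):
    def build_encoder_layer(self, args):  # no **kwargs, so layer_idx is rejected
        return ("conformer", None)

class FixedConformerEncoder(BaseEncoder):
    def build_encoder_layer(self, args, **kwargs):  # silently absorbs layer_idx
        return ("conformer", None)

args = SimpleNamespace(encoder_layers=2)
FixedConformerEncoder(args)  # builds fine
try:
    BrokenConformerEncoder(args)
except TypeError as e:
    print(e)  # build_encoder_layer() got an unexpected keyword argument 'layer_idx'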

mahendranitttrc commented 6 months ago

I also have the same problem while building a model using wav2vec2. Please help!

[self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)]
TypeError: ConformerEncoder.build_encoder_layer() got an unexpected keyword argument 'layer_idx'

zw76859420 commented 2 months ago

I also have the same problem while building a model using wav2vec2. Please help!

Lauler commented 2 months ago

Read the comment above.

Change build_encoder_layer to accept a **kwargs argument. From this:

class ConformerEncoder(TransformerEncoder):
    def build_encoder_layer(self, args):
        layer = ConformerWav2Vec2EncoderLayer(
            embed_dim=self.embedding_dim,
            ffn_embed_dim=args.encoder_ffn_embed_dim,
            attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            depthwise_conv_kernel_size=args.depthwise_conv_kernel_size,
            activation_fn="swish",
            attn_type=args.attn_type,
            pos_enc_type=args.pos_enc_type,
            use_fp16=args.fp16,  # only used for rope
        )
        layer = fsdp_wrap(layer)
        if args.checkpoint_activations:
            layer = checkpoint_wrapper(layer)
        return layer

change to this:

class ConformerEncoder(TransformerEncoder):
    def build_encoder_layer(self, args, **kwargs):
        layer = ConformerWav2Vec2EncoderLayer(
            embed_dim=self.embedding_dim,
            ffn_embed_dim=args.encoder_ffn_embed_dim,
            attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            depthwise_conv_kernel_size=args.depthwise_conv_kernel_size,
            activation_fn="swish",
            attn_type=args.attn_type,
            pos_enc_type=args.pos_enc_type,
            use_fp16=args.fp16,  # only used for rope
        )
        layer = fsdp_wrap(layer)
        if args.checkpoint_activations:
            layer = checkpoint_wrapper(layer)
        return layer
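
This works because the parent TransformerEncoder.__init__ builds the layer list with self.build_encoder_layer(args, layer_idx=ii) (see the traceback in the original post). ConformerEncoder's override never used layer_idx, so accepting and ignoring it via **kwargs is the minimal fix.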

MagicLuke commented 2 months ago

@Lauler Hi, which version of fairseq are you using? I faced the same problem and looked into the code snippet, which looks different from what you have here. Mine looks like this:

    def build_encoder_layer(self, args, **kwargs):
        if args.layer_type == "transformer":
            layer = TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=args.encoder_ffn_embed_dim,
                num_attention_heads=args.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                activation_fn=args.activation_fn,
                layer_norm_first=args.layer_norm_first,
            )
        elif args.layer_type == "conformer":
            layer = ConformerWav2Vec2EncoderLayer(
                embed_dim=self.embedding_dim,
                ffn_embed_dim=args.encoder_ffn_embed_dim,
                attention_heads=args.encoder_attention_heads,
                dropout=args.dropout,
                depthwise_conv_kernel_size=args.depthwise_conv_kernel_size,
                activation_fn="swish",
                attn_type=args.attn_type,
                use_fp16=args.fp16,
                pos_enc_type="abs",
            )
        elif args.layer_type == "trf_adp":
            use_adp = False
            if args.adp_trf_idx == "all":
                use_adp = True
            else:
                adp_trf_idx = list(range(*[int(g) for g in args.adp_trf_idx.split(":")]))
                if kwargs.get("layer_idx", None) in adp_trf_idx:
                    use_adp = True
            if use_adp:
                layer = TransformerSentenceEncoderWithAdapterLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    adapter_num=args.adp_num,
                    adapter_dim=args.adp_dim,
                    adapter_act_fn=args.adp_act_fn,
                )
            else:
                layer = TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                )

        layer = fsdp_wrap(layer)
        if args.checkpoint_activations:
            layer = checkpoint_wrapper(layer)
        return layer

But following your suggestion still gives me the same error.

MagicLuke commented 2 months ago

@Lauler Hi Lauler, I am sorry, I just found the snippet you mentioned and changed it as you said, and it works! Thanks!
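
(For anyone else confused: the long snippet I pasted earlier is the parent TransformerEncoder.build_encoder_layer, which already accepts **kwargs. The method that actually needs the patch is the ConformerEncoder override further down in wav2vec2.py.)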

MagicLuke commented 2 months ago

Additionally, we probably need to change one more place for the script to run:

    # def extract_features(self, x, padding_mask=None, tgt_layer=None): # original
    def extract_features(self, x, padding_mask=None, tgt_layer=None, **kwargs):
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # B X T X C here
        position_emb = None
        if self.pos_enc_type == "rel_pos":
            position_emb = self.embed_positions(x)

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        layer_results = []
        r = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z = layer(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    position_emb=position_emb,
                )
                if tgt_layer is not None:
                    layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results
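
This second **kwargs is presumably needed for the same reason: callers can pass extra keyword arguments to extract_features (for example a corpus key, given the multi_corpus_keys option visible in the config dump above) that the conformer version does not use, so absorbing them keeps the override call-compatible with the parent.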

So after altering the two places mentioned above, the ConformerEncoder class will look like this:

class ConformerEncoder(TransformerEncoder):
    # def build_encoder_layer(self, args): # original
    def build_encoder_layer(self, args, **kwargs):
        layer = ConformerWav2Vec2EncoderLayer(
            embed_dim=self.embedding_dim,
            ffn_embed_dim=args.encoder_ffn_embed_dim,
            attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            depthwise_conv_kernel_size=args.depthwise_conv_kernel_size,
            activation_fn="swish",
            attn_type=args.attn_type,
            pos_enc_type=args.pos_enc_type,
            use_fp16=args.fp16,  # only used for rope
        )
        layer = fsdp_wrap(layer)
        if args.checkpoint_activations:
            layer = checkpoint_wrapper(layer)
        return layer

    def __init__(self, args):
        super().__init__(args)
        self.args = args
        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim
        self.pos_enc_type = args.pos_enc_type
        max_source_positions = self.max_positions()

        if self.pos_enc_type == "rel_pos":
            self.embed_positions = RelPositionalEncoding(
                max_source_positions, self.embedding_dim
            )
        elif self.pos_enc_type == "rope":
            self.embed_positions = None
        else:
            raise Exception("Unsupported positional encoding type")

        self.layers = nn.ModuleList(
            [self.build_encoder_layer(args) for _ in range(args.encoder_layers)]
        )
        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

    # def extract_features(self, x, padding_mask=None, tgt_layer=None): # original
    def extract_features(self, x, padding_mask=None, tgt_layer=None, **kwargs):
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # B X T X C here
        position_emb = None
        if self.pos_enc_type == "rel_pos":
            position_emb = self.embed_positions(x)

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        layer_results = []
        r = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z = layer(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    position_emb=position_emb,
                )
                if tgt_layer is not None:
                    layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results
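
After patching, a quick sanity check (a sketch, assuming you run it in the environment where your modified fairseq checkout is installed) confirms that both overrides now accept extra keyword arguments:

import inspect
from fairseq.models.wav2vec.wav2vec2 import ConformerEncoder

# Both patched methods must expose a VAR_KEYWORD (**kwargs) parameter.
for name in ("build_encoder_layer", "extract_features"):
    params = inspect.signature(getattr(ConformerEncoder, name)).parameters.values()
    assert any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params), name
print("both overrides accept **kwargs")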