mattiadg / FBK-Fairseq-ST

An adaptation of Fairseq to (End-to-end) speech translation.

Error when not using any distance penalty #12

Open · gegallego opened 3 years ago

gegallego commented 3 years ago

Hello @mattiadg

When I don't use the --distance-penalty flag, I get the following error:

File "~/FBK-Fairseq-ST/fairseq/models/s_transformer.py", line 472, in __init__
    init_variance=(args.init_variance if args.distance_penalty == 'gauss' else None)
TypeError: __init__() got an unexpected keyword argument 'penalty'

The problem comes from the following lines in the constructor of TransformerEncoderLayer:

attn = LocalMultiheadAttention if args.distance_penalty != False else MultiheadAttention    
self.self_attn = attn(
    self.embed_dim, args.encoder_attention_heads,
    dropout=args.attention_dropout, penalty=args.distance_penalty,
    init_variance=(args.init_variance if args.distance_penalty == 'gauss' else None)
)
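
A minimal standalone sketch reproduces the failure mode (the classes below are stand-ins, not the real fairseq ones, and the dimensions are made up; only the keyword-argument handling matters): the attention class is chosen conditionally, but penalty and init_variance are passed unconditionally.

class MultiheadAttention:  # stand-in for fairseq's MultiheadAttention
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        self.embed_dim = embed_dim

class LocalMultiheadAttention(MultiheadAttention):  # stand-in that accepts the extra args
    def __init__(self, embed_dim, num_heads, dropout=0.0,
                 penalty=None, init_variance=None):
        super().__init__(embed_dim, num_heads, dropout)
        self.penalty = penalty
        self.init_variance = init_variance

distance_penalty = False  # what args.distance_penalty holds without --distance-penalty
attn = LocalMultiheadAttention if distance_penalty != False else MultiheadAttention
attn(512, 8, dropout=0.1, penalty=distance_penalty, init_variance=None)
# TypeError: __init__() got an unexpected keyword argument 'penalty'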

The arguments penalty and init_variance do not exist in MultiheadAttention, so I replaced these lines with:

if args.distance_penalty != False:
    self.self_attn = LocalMultiheadAttention(
        self.embed_dim, args.encoder_attention_heads,
        dropout=args.attention_dropout, penalty=args.distance_penalty,
        init_variance=(args.init_variance if args.distance_penalty == 'gauss' else None)
    )
else:
    self.self_attn = MultiheadAttention(
        self.embed_dim, args.encoder_attention_heads,
        dropout=args.attention_dropout,
    )
mattiadg commented 3 years ago

Do you want to submit a pull request?

gegallego commented 3 years ago

👍