In transformer.py, line 87:
mask = Lambda(lambda x:K.repeat_elements(x, n_head, 0))(mask)
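This line gives the mask (in readout_model) the shape `(batch_size*n_head, x, x)`, while the output of reshape1 has the shape `(n_head*batch_size, x, x)`. The shapes are the same, but the elements are not aligned. repeat_elements along axis 0 repeats each batch entry n_head times in a row (b0, b0, ..., b1, b1, ...). In the output of reshape1, the head index varies slowest, so consecutive entries cycle through the batch (b0, b1, ..., b0, b1, ...). Whenever both batch_size and n_head are greater than 1, some heads are therefore masked with another sample's mask.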
Maybe repeat_elements should be replaced with tile?
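Here is a small NumPy sketch of the ordering problem. It assumes that reshape1 splits `(batch, len, n_head*d_k)` into heads and then moves the head axis to the front. `head_major_split` is a hypothetical stand-in for it, not the code in transformer.py. `np.repeat(..., axis=0)` has the same ordering as `K.repeat_elements(x, n_head, 0)`, and `np.tile(..., (n_head, 1, 1))` has the same ordering as `K.tile(x, [n_head, 1, 1])`. Every value in batch row b is set to b, which makes it possible to read off which sample lands at each position of axis 0.

```python
import numpy as np

batch_size, n_head, seq_len, d_k = 2, 3, 4, 5

def head_major_split(x, n_head):
    # Hypothetical stand-in for reshape1 (assumed layout):
    # (batch, len, n_head*d_k) -> (n_head*batch, len, d_k), head axis outermost.
    b, l, d_model = x.shape
    x = x.reshape(b, l, n_head, d_model // n_head)
    x = x.transpose(2, 0, 1, 3)  # (n_head, batch, len, d_k)
    return x.reshape(-1, l, d_model // n_head)

# Toy activations and mask where every value in batch row b equals b.
acts = np.stack([np.full((seq_len, n_head * d_k), b) for b in range(batch_size)])
mask = np.stack([np.full((seq_len, seq_len), b) for b in range(batch_size)])

split = head_major_split(acts, n_head)
repeated = np.repeat(mask, n_head, axis=0)  # same order as K.repeat_elements(x, n_head, 0)
tiled = np.tile(mask, (n_head, 1, 1))       # same order as K.tile(x, [n_head, 1, 1])

print(split[:, 0, 0])     # [0 1 0 1 0 1]
print(repeated[:, 0, 0])  # [0 0 0 1 1 1]  rows do not line up with split
print(tiled[:, 0, 0])     # [0 1 0 1 0 1]  rows line up with split
```

If reshape1 really is head-major, the line could become `mask = Lambda(lambda x: K.tile(x, [n_head, 1, 1]))(mask)`. This is untested against the repository.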