1. Currently, bmt.OpTransformerBlockList can only handle transformer blocks that return the hidden states alone.
The transformer block in the recently released flash_attn returns both hidden_states and residual so that it can fuse Dropout -> Add -> LN. Both values are then passed to the next block as input:
class Block(nn.Module):
    def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
                mixer_subset=None, mixer_kwargs=None):
        if self.prenorm:
            ...
            return hidden_states, residual
        ...
Blocks with this interface cannot currently be handled properly by bmt.OpTransformerBlockList.
2. Request: support transformer blocks that return multiple values, as in the case above.
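
To make the requested behaviour concrete, here is a minimal, framework-free sketch of a block list that threads both return values from one block into the next. Everything in it is an assumption for illustration: make_block and run_blocks are hypothetical helpers rather than BMTrain or flash_attn APIs, plain floats stand in for tensors, and a simple scaling stands in for LN plus the mixer. It does not try to model any of the memory or communication handling that the real block list performs.

```python
from typing import Callable, Optional, Sequence, Tuple

# Hypothetical stand-in for a block with the (hidden_states, residual) interface.
Block = Callable[[float, Optional[float]], Tuple[float, float]]


def make_block(scale: float) -> Block:
    """Build a toy block; `scale` stands in for LN + mixer (an assumption)."""

    def block(hidden_states: float, residual: Optional[float] = None) -> Tuple[float, float]:
        # "Add" step: fold the incoming hidden states into the running residual.
        residual = hidden_states if residual is None else hidden_states + residual
        # Stand-in for LN + mixer: any function of the updated residual.
        hidden_states = scale * residual
        return hidden_states, residual

    return block


def run_blocks(blocks: Sequence[Block], hidden_states: float,
               residual: Optional[float] = None) -> Tuple[float, float]:
    """Call each block in order, passing BOTH outputs on as the next inputs."""
    for block in blocks:
        hidden_states, residual = block(hidden_states, residual)
    return hidden_states, residual


blocks = [make_block(2.0), make_block(0.5)]
out, res = run_blocks(blocks, 1.0)
print(out, res)
```

The point of the sketch is only the loop in run_blocks: the list unpacks a tuple from each block and forwards every element, instead of assuming a single hidden-states value.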