This PR updates the reset function to convert the `dones` mask to a boolean mask before using it to reset hidden states. Previously, the function indexed with the 0/1 mask directly. An integer mask is interpreted as a list of indices, so it selected indices 0 and 1 along the second-to-last dimension instead of the positions where the mask is 1. With a boolean mask, the hidden states are reset at exactly the positions marked done.
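
The difference between the two indexing modes can be illustrated with a minimal sketch. It assumes NumPy-style indexing semantics and a simplified `(num_envs, hidden_size)` layout; the function names, shapes, and values are hypothetical and are not taken from the code changed in this PR.

```python
import numpy as np


def reset_hidden_int_mask(hidden, dones):
    """Old behaviour (illustrative): the 0/1 integer mask is used as indices."""
    out = hidden.copy()
    # Integer arrays index by position, so this writes to rows 0 and 1 only,
    # no matter which entries of `dones` are set.
    out[dones] = 0.0
    return out


def reset_hidden_bool_mask(hidden, dones):
    """Fixed behaviour (illustrative): the mask is converted to bool first."""
    out = hidden.copy()
    # Boolean arrays select the rows where the mask is True.
    out[dones.astype(bool)] = 0.0
    return out


# Hypothetical example: 4 environments, hidden size 3; environments 2 and 3 are done.
hidden = np.ones((4, 3))
dones = np.array([0, 0, 1, 1])

buggy = reset_hidden_int_mask(hidden, dones)   # zeroes rows 0 and 1
fixed = reset_hidden_bool_mask(hidden, dones)  # zeroes rows 2 and 3
```

In this sketch the integer mask resets the first two rows and leaves the finished environments untouched, while the boolean mask resets only the rows where `dones` is 1.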