Hi,

I implemented a basic transformer block with residual connections and am getting the following error:

```
NotImplementedError: `FanInSum` is only implemented for the case where all input layers guaranteed to be mean-zero Gaussian, i.e. having all `is_gaussian` set to `True`, got [True, False].
```
It appears that it's due to `stax.Identity()`.

Here is the implementation:
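A minimal sketch of what the block looks like (the widths, head count, and the final `Dense` readout below are illustrative placeholders, not the exact code):

```python
from neural_tangents import stax


def TransformerBlock(width=32, n_heads=2):
  # Self-attention sub-layer wrapped in a residual (skip) connection:
  # FanOut -> parallel(attention branch, Identity branch) -> FanInSum.
  attn = stax.GlobalSelfAttention(
      n_chan_out=width,   # widths and head count are placeholders
      n_chan_key=width,
      n_chan_val=width,
      n_heads=n_heads)
  return stax.serial(
      stax.FanOut(2),
      stax.parallel(attn, stax.Identity()),
      stax.FanInSum())


init_fn, apply_fn, kernel_fn = stax.serial(
    TransformerBlock(),
    stax.Dense(1))
```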
And then taking the example data from the cookbook:
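(Placeholder inputs standing in for the cookbook data: random arrays with a sequence axis, shaped `(batch, seq_len, channels)`, so the attention layer applies.)

```python
from jax import random

# Placeholder inputs shaped (batch, seq_len, channels); the cookbook's
# actual data and sizes are not reproduced here.
key = random.PRNGKey(10)
key_train, key_test = random.split(key)
x_train = random.normal(key_train, (4, 16, 32))
x_test = random.normal(key_test, (8, 16, 32))

# The NotImplementedError above is raised inside this kernel_fn call.
nngp_test_train = kernel_fn(x_test, x_train, 'nngp')
```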
where the error occurs in the `kernel_fn` calculation.

What is odd is that the `ResBlock` works in the cookbook:
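Paraphrasing the cookbook's `ResBlock` (the width and `W_std`/`b_std` values below are placeholders and may differ from the cookbook's exact ones):

```python
from neural_tangents import stax


def ResBlock():
  # Nonlinearity + Dense on one branch, Identity skip on the other,
  # summed back together by FanInSum.
  return stax.serial(
      stax.FanOut(2),
      stax.parallel(
          stax.serial(
              stax.Erf(),
              stax.Dense(512, W_std=1.1, b_std=0.05)),  # placeholder values
          stax.Identity()),
      stax.FanInSum())
```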
And it appears that with `linear_scaling=True`, `is_gaussian=True` is set by this line: https://github.com/google/neural-tangents/blob/c17e770bb74f1771da7be4a69fabfa68b6078960/neural_tangents/_src/stax/linear.py#L2464C14-L2468C39

Eventually I would also like to include causal masking, and if you have pointers there that would also be great, as it is also not clear how to do an upper triangular mask over the sequence length in the infinite-width case.