chengchingwen / Transformers.jl

Julia Implementation of Transformer models

No need to reshape? #26

Closed: darsnack closed this issue 4 years ago

darsnack commented 4 years ago

It appears that, to pass input to the LayerNorm, the tensor is reshaped into a 2D matrix of size feature_size × (sequence_length × batch_size) and then reshaped back after all the norm layers are done operating? I think this happens in multiple places (e.g. the @toNd macro); a sketch of the pattern follows the link below.

https://github.com/chengchingwen/Transformers.jl/blob/fbc8bb3582189d770778377a96994685d2c0b41c/src/basic/transformer.jl#L61
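For context, here is a minimal sketch (not Transformers.jl's actual code) of the reshape-and-restore pattern in question; the helper name apply_flat is hypothetical:

using Flux

# Hypothetical helper: flatten the sequence and batch dimensions,
# apply a layer that only understands matrices, then restore the shape.
function apply_flat(layer, x::AbstractArray{T,3}) where T
    f, s, b = size(x)                  # feature × sequence × batch
    y = layer(reshape(x, f, s * b))    # layer sees feature × (sequence * batch)
    return reshape(y, :, s, b)         # recover (out_features, sequence, batch)
end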

Based on a recent Zulip topic, I think this isn't required due to Julia's broadcasting machinery.

julia> using Flux

julia> x = rand(3, 2, 10);

julia> d = Flux.Diagonal(3)
Diagonal(3)

julia> d(Flux.normalise(x; dims = 1)) == reshape(d(Flux.normalise(reshape(x, 3, :); dims = 1)), 3, 2, 10)
true

As seen above, applying the LayerNorm body to the 3D tensor directly and to the reshaped tensor produces the same output.

chengchingwen commented 4 years ago

That's right, the LayerNorm doesn't need to reshape the input. However, Flux.Dense doesn't accept inputs with more than 2 dimensions, so those reshapes are not for LayerNorm but for Dense.
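For illustration, a minimal sketch of the Dense case, assuming a Flux version whose Dense only accepts vectors or matrices (newer Flux releases reshape higher-dimensional input internally):

using Flux

d = Dense(3, 5)                  # feature_size 3 → 5
x = rand(Float32, 3, 2, 10)      # feature × sequence × batch

# Dense expects a matrix, so flatten the trailing dims and restore them afterwards:
y = reshape(d(reshape(x, 3, :)), 5, 2, 10)
size(y)                          # (5, 2, 10)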