Open haozheji opened 3 years ago
It depends on the implementation. i) Algorithm 1 takes O(N) time. ii) Formula (8) takes O(1) time. The CUDA kernel uses the first implementation, so it is O(N). It could be made O(1) if necessary (though this is impractical for long sequences).
It seems that the loop over the sequence length (N) is not distributed across multiple threads (int t is a local variable), so the parallel complexity is actually O(N), compared to O(1) for the original transformer.
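
To make the distinction concrete, here is a minimal sketch in plain Python. It uses a running sum as a stand-in recurrence. This is an assumption for illustration only, not the repository's kernel, Algorithm 1, or Formula (8), and the function names are hypothetical. The first function mirrors a loop in which one worker walks t from 0 to N-1, so each step waits on the previous one. The second computes every position directly from the inputs, so each position could in principle be handed to its own thread.

```python
def sequential_scan(x):
    """Recurrent form: a single worker walks t = 0..N-1.

    Step t needs the state left by step t-1, so the chain of
    dependent steps has length N regardless of how many threads exist.
    """
    out = []
    state = 0.0
    for t in range(len(x)):  # t is local to this one worker
        state += x[t]        # depends on the previous iteration
        out.append(state)
    return out


def per_position(x, t):
    """Closed form for one position: reads only the inputs x[0..t]."""
    return sum(x[: t + 1])


def parallel_form(x):
    """Every position is independent of every other output.

    The list comprehension stands in for launching one thread per t.
    """
    return [per_position(x, t) for t in range(len(x))]


x = [1.0, 2.0, 3.0, 4.0]
print(sequential_scan(x))
print(parallel_form(x))
```

Both functions return the same values. The difference is in the dependency structure: the closed form removes the chain between positions, but it does O(N) work per position and O(N^2) work in total, which is one plausible reason such a form becomes unattractive for long sequences. Calling it O(1) also follows the thread's convention of treating each position's reduction as a single parallel step.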