Closed Joshuaalbert closed 1 week ago
Take directional accumulation outside the vmap. Shard over channels.
    vis = 0
    for dir in directions:
        # vmap boundary
        vis += predict(k, freq_shard, times)  # [T, B, C, 2, 2]
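A minimal JAX sketch of the idea above, under stated assumptions. The toy `predict_one_time` model, `NUM_BASELINES`, `predict_shard`, and `predict_all` are hypothetical stand-ins, not the project's actual code, and `k` is treated as one scalar per direction purely for illustration. The direction loop is a `jax.lax.scan` that carries the running `vis`, so the vmap only covers the time axis and no `[D, T, B, C, 2, 2]` intermediate is built. Channel sharding is shown as a plain chunked loop over frequency shards; on multiple devices the shards could instead be placed with JAX's sharding tools.

```python
import jax
import jax.numpy as jnp

NUM_BASELINES = 3  # hypothetical baseline count for the toy model


def predict_one_time(k, freqs, t):
    """Toy single-direction, single-time model (assumption). Returns [B, C, 2, 2]."""
    b = jnp.arange(NUM_BASELINES, dtype=freqs.dtype)
    phase = 2.0 * jnp.pi * k * (t + b[:, None]) * freqs[None, :]  # [B, C]
    scalar = jnp.exp(1j * phase)  # [B, C]
    return scalar[:, :, None, None] * jnp.eye(2)  # [B, C, 2, 2]


def predict(k, freq_shard, times):
    """vmap over times only; the direction axis stays outside. Returns [T, B, C, 2, 2]."""
    return jax.vmap(predict_one_time, in_axes=(None, None, 0))(k, freq_shard, times)


def predict_shard(directions, freq_shard, times):
    """Accumulate over directions outside the vmap for one channel shard."""
    num_times, num_chan = times.shape[0], freq_shard.shape[0]
    vis0 = jnp.zeros((num_times, NUM_BASELINES, num_chan, 2, 2), dtype=jnp.complex64)

    def body(vis, k):
        # One direction at a time: only a single [T, B, C, 2, 2] term is live.
        return vis + predict(k, freq_shard, times), None

    vis, _ = jax.lax.scan(body, vis0, directions)
    return vis


def predict_all(directions, freqs, times, num_shards):
    """Process channel shards independently and concatenate along the channel axis."""
    shards = jnp.array_split(freqs, num_shards)
    return jnp.concatenate(
        [predict_shard(directions, s, times) for s in shards], axis=2
    )
```

Because each channel shard is independent, the per-shard results only need to be concatenated along the `C` axis; no cross-shard reduction is required.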