Jutho / TensorOperations.jl

Julia package for tensor contractions and related operations
https://jutho.github.io/TensorOperations.jl/stable/

Runtime dispatch overhead for small tensor contractions using the `StridedBLAS` backend #189

Closed: leburgel closed this issue 2 weeks ago

leburgel commented 2 weeks ago

When profiling some code involving many contractions with small tensors, I noticed that there is a lot of overhead from runtime dispatch, and from the resulting allocations and garbage collection, when using the StridedBLAS backend. I've figured out part of it, but I thought it would be good to report it here to see if more can be done.

I profiled a dummy example contraction of small complex tensors, which roughly reproduces the issue:

```julia
using TensorOperations
using ProfileView  # provides @profview (the VS Code Julia REPL also defines one)

T = ComplexF64
L = randn(T, 8, 8, 7)
R = randn(T, 8, 8, 7)
O = randn(T, 4, 4, 7, 7)

# Note: L, O and R are (non-constant) globals captured by local_update!
function local_update!(psi::Array{T,3})::Array{T,3} where {T}
    @tensor begin
        psi[-1, -2, -3] =
            psi[1, 3, 5] *
            L[-1, 1, 2] *
            O[-2, 3, 2, 4] *
            R[-3, 5, 4]
    end
    return psi
end

psi0 = randn(T, 8, 4, 8)
@profview begin
    for _ in 1:100000
        local_update!(psi0)
    end
end
```

which gives the following profile:

[image: small_contract_profile]

So there is a lot of runtime dispatch overhead in TensorOperations.tensoradd! and Strided._mapreduce_order!. The effect becomes negligible for large tensor dimensions, but it turns out to be a real pain when a lot of these small contractions are performed.

For TensorOperations.tensoradd!, I managed to trace the problem to an ambiguity caused by patterns like flag2op(conjA)(A): at compile time, the return type can be a Union of two concrete StridedView types, whose op fields have types typeof(identity) and typeof(conj), respectively. This leads to an issue here: https://github.com/Jutho/TensorOperations.jl/blob/c1e37eca9c2d1ab468ddd1beb2ffb9eb0993bb4b/src/implementation/strided.jl#L99-L103 where the last argument in the call to Strided._mapreducedim! is a Tuple whose type parameters mix concrete types with this abstract Union type, which seems to mess things up.

At that level I managed to fix things by simply splitting the code into two branches to de-confuse the compiler:

```julia
opA = flag2op(conjA)
# Branch on the op, so that inside each branch the compiler knows its concrete type
if opA isa typeof(identity)
    A′ = permutedims(identity(A), linearize(pA))
    op1 = Base.Fix2(scale, α)
    op2 = Base.Fix2(scale, β)
    Strided._mapreducedim!(op1, +, op2, size(C), (C, A′))
    return C
elseif opA isa typeof(conj)
    # Same as above, but with conj(A) instead of identity(A)
    A′ = permutedims(conj(A), linearize(pA))
    op1 = Base.Fix2(scale, α)
    op2 = Base.Fix2(scale, β)
    Strided._mapreducedim!(op1, +, op2, size(C), (C, A′))
    return C
end
```

which gets rid of the runtime dispatch in tensoradd! and already makes a big difference for small contractions.
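The same manual union-splitting pattern in a self-contained form might look like the sketch below. It is a hypothetical illustration, not TensorOperations code: `kernel!` stands in for Strided._mapreducedim! and `scaled_add!` for tensoradd!. The point is that in each branch the op passed to the kernel is a compile-time constant, so the kernel call is statically dispatched.

```julia
# Hypothetical kernel standing in for Strided._mapreducedim!:
# computes C = β*C + α*op.(A), specialising on `op` via the type parameter F.
function kernel!(op::F, α, β, C, A) where {F}
    @inbounds for i in eachindex(C, A)
        C[i] = β * C[i] + α * op(A[i])
    end
    return C
end

# Manual union splitting: each branch passes a concretely typed op.
function scaled_add!(C, A, conjA::Bool, α, β)
    if conjA
        return kernel!(conj, α, β, C, A)
    else
        return kernel!(identity, α, β, C, A)
    end
end
```

The duplication is the price of keeping each call site concretely typed; with several flags the number of branches grows multiplicatively.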

I haven't descended all the way down into Strided._mapreducedim!, so I don't know what the issue is there.

So in the end my questions are:

Jutho commented 2 weeks ago

So I keep getting sucked into the lowest levels of the code, never having time to properly study PEPSKit.jl 😄.

Jutho commented 2 weeks ago

Certainly, for conjA, the hope was that constant propagation would resolve the ambiguity, but apparently it does not. Maybe we need some more aggressive constant propagation.

Jutho commented 2 weeks ago

I assume the actual coding pattern does not include a function local_update! that captures global variables, right?

Jutho commented 2 weeks ago

Ok, it seems adding @constprop :aggressive before every tensorcontract! and tensoradd! definition solves the problem, but it is of course not particularly nice.
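A minimal sketch of what the annotation does, using a hypothetical `OpWrap` type whose type parameter records the op (similar in spirit to the op field of StridedView). Without constant information the return type is a Union of two concrete types; when the flag is a literal at the call site, constant propagation can resolve the branch. This is a single-level toy; the real problem involves the flag being passed through several layers of methods, which is why every tensorcontract!/tensoradd! definition needed the annotation.

```julia
# Hypothetical wrapper: the type of `op` becomes part of the wrapper's type.
struct OpWrap{F,T}
    op::F
    x::T
end

# Base.@constprop :aggressive asks the compiler to always try constant propagation here.
Base.@constprop :aggressive function wrap_flag(x, conjx::Bool)
    op = conjx ? conj : identity
    return OpWrap(op, x)
end

# The flag is a literal constant here, so inference can pick the `conj` branch.
wrap_conj(x) = wrap_flag(x, true)
```

Checking with `Base.return_types`, `wrap_flag` with a runtime `Bool` infers a two-element Union, while `wrap_conj` infers the concrete `OpWrap{typeof(conj),ComplexF64}`.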

An alternative would be to explicitly encode the conjugation flag in the type domain, e.g. by having a parametric singleton struct with instances ConjugationFlag{true}() and ConjugationFlag{false}(). I am interested in hearing the opinion of @lkdvos.
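A sketch of what a type-domain flag could look like. `ConjugationFlag` is only the name floated in this thread, not an existing TensorOperations type, and `flag2op_typed` is a hypothetical counterpart of flag2op. Because the op is selected by dispatch on the flag's type, each method has a concrete return type; the instability only moves to wherever a runtime `Bool` is converted into a flag.

```julia
# Hypothetical singleton flag type; B is `true` or `false`.
struct ConjugationFlag{B} end

# Converting a runtime Bool is type unstable unless `b` is a compile-time constant.
ConjugationFlag(b::Bool) = ConjugationFlag{b}()

# Dispatch on the type parameter selects the op.
flag2op_typed(::ConjugationFlag{false}) = identity
flag2op_typed(::ConjugationFlag{true}) = conj

apply_flag(flag::ConjugationFlag, x) = flag2op_typed(flag)(x)
```

This mirrors how a `Val{true}`/`Val{false}` argument behaves, which is the pattern already used for the istemp flag.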

With Cthulhu.jl, I couldn't spot an additional ambiguity/dynamic dispatch in mapreduce_order!, so I am not sure if those two are related or if that is yet another issue.

lkdvos commented 2 weeks ago

We are already doing this for the istemp flag with a Val{true}, so in principle I have nothing against this change. It's a bit unfortunate that this has to be breaking, so we might want to investigate a bit more... Do you know if we can do const propagation for single arguments?

Jutho commented 2 weeks ago

No, that doesn't work. I also tried annotating only the flag2op function, and that doesn't work either. Maybe there is some in-between balance where you only need to annotate some of the tensor operation definitions, but I think that can only be found by trial and error.

Another non-breaking solution is to manually split the code based on the flag values, but flag2op is called in quite a few places, and in the contraction there are two of those calls, so then one ends up with 4 different cases.
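One way to get the four cases without writing them out by hand is a small higher-order helper that calls a closure with a concrete op in each branch. This is a hypothetical sketch, not what the eventual PR does: `with_flag_op` and `dot_like` are invented names, and `dot_like` only stands in for a contraction with two conjugation flags. Each closure is compiled separately for `typeof(identity)` and `typeof(conj)`, so nesting two calls yields the four specialisations.

```julia
# Hypothetical helper: calls `f` with a concretely typed op in each branch.
@inline with_flag_op(f::F, conjflag::Bool) where {F} = conjflag ? f(conj) : f(identity)

# Stand-in for a contraction with two flags: sum_i opA(A[i]) * opB(B[i]).
function dot_like(A, B, conjA::Bool, conjB::Bool)
    with_flag_op(conjA) do opA
        with_flag_op(conjB) do opB
            s = zero(promote_type(eltype(A), eltype(B)))
            for (a, b) in zip(A, B)
                s += opA(a) * opB(b)
            end
            s
        end
    end
end
```

Since all four branches return the same type here, the overall return type stays concrete.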

lkdvos commented 2 weeks ago

To be fair, if just splitting flag2op into if statements works, that might be the cleanest. I do see how that function is inherently type unstable, so either this information needs to be in the type domain, or we do a manual union splitting, which I had hoped the compiler would figure out by itself, but apparently it doesn't.

Jutho commented 2 weeks ago

Ok, I'll take a stab at it asap.

leburgel commented 2 weeks ago

> I assume the actual coding pattern does not include a function local_update! that captures global variables, right?

No, in the realistic setting everything is passed through locally and all of the concrete array types are properly inferred in the call to tensorcontract!. I wasn't really paying attention to the globals in the example, but it runs into the same problem further down.

leburgel commented 2 weeks ago

> With Cthulhu.jl, I couldn't spot an additional ambiguity/dynamic dispatch in mapreduce_order!, so I am not sure if those two are related or if that is yet another issue.

After splitting the branches with an if statement, the ambiguity in tensoradd! is completely removed, but the runtime dispatch in mapreduce_order! remains. So that might be a separate issue, but I couldn't really figure it out.

Jutho commented 2 weeks ago

I think the dynamic dispatch in the mapreduce lowering chain comes from the fact that in _mapreduce_fuse! and _mapreduce_block!, the compiler doesn't specialise on the function arguments f, op and initop. I think this is a known behaviour of the Julia compiler to reduce compile times: if you want a method to specialise on a function argument, you should make the type of that argument a type parameter of the method. In this case, the non-specialisation is in principle fine, as it has no effect on the types of any of the variables in those function bodies, so everything is perfectly type stable.
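The heuristic in question is described in the Julia manual's performance tips ("Be aware of when Julia avoids specializing"): a `Function` argument that a method only passes through, without calling it, may not be specialised on. A toy illustration with hypothetical names, showing the two method signatures side by side:

```julia
# Innermost call: `f` is actually called here, so this method is specialised on it.
inner(f, x) = f(x)

# `f` is only passed through: Julia's heuristics may compile this method once for a
# generic `f::Function`, in which case the call to `inner` is dynamically dispatched.
outer_generic(f, x) = inner(f, x)

# Making the type of `f` a method type parameter forces one specialisation per function.
outer_specialised(f::F, x) where {F} = inner(f, x)
```

Both compute the same result; the difference only shows up in the generated code (e.g. via `@code_typed` or Cthulhu.jl) and in compile time, which matches the trade-off measured further down in the thread.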

But then, when _mapreduce_kernel is finally selected, it does specialise on f etc. again, since it is a generated function, and so there is some dynamic dispatch going on. Not sure what the impact of that is.

Jutho commented 2 weeks ago

Ok, I am not sure how reliable this is: these numbers come from simply timing the first call (so timing compilation) and then benchmarking the call on my computer. Both include the fix from the PR, but the first one uses the vanilla version of Strided.jl:

```julia
julia> @time local_update!(psi0);
 19.850534 seconds (34.72 M allocations: 2.236 GiB, 4.29% gc time, 100.00% compilation time)

julia> @benchmark local_update!($psi0)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  30.459 μs … 123.357 ms  ┊ GC (min … max):  0.00% … 99.86%
 Time  (median):     41.032 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   71.495 μs ±   1.255 ms  ┊ GC (mean ± σ):  32.50% ±  4.91%

  ▄▇▄▂▂▅█▅▁▂▂▄▃                     ▁▂▂▂▃▂▂▂▂▂▁                ▂
  ██████████████▇▆▇▆▇▇▆▆▇▆▄▄▆▅▄▆▅▆▇██████████████▇▇▇▆▆▆▅▄▄▅▁▃▅ █
  30.5 μs       Histogram: log(frequency) by time       130 μs <

 Memory estimate: 135.19 KiB, allocs estimate: 59.
```

and then with a modified Strided.jl where f::F1, op::F2 and initop::F3 are now explicit type parameters of the relevant methods:

```julia
julia> @time local_update!(psi0);
 36.712724 seconds (27.60 M allocations: 1.740 GiB, 1.54% gc time, 100.00% compilation time)

julia> @benchmark local_update!($psi0)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  26.941 μs … 122.038 ms  ┊ GC (min … max):  0.00% … 99.86%
 Time  (median):     38.208 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   69.025 μs ±   1.243 ms  ┊ GC (mean ± σ):  33.64% ±  4.82%

  ▃▆▇▃▂▄█▇▃▁▂▄▂                     ▁▁▂▂▂▂▂▂▂▁▁▁               ▂
  ██████████████▇▇█▇▅▆▆▇▆▆▅▅▄▃▂▄▄▅▆█████████████████▇▇▆▅▆▄▅▅▄▄ █
  26.9 μs       Histogram: log(frequency) by time       127 μs <

 Memory estimate: 131.98 KiB, allocs estimate: 15.
```

So allocations and runtime are a little bit lower, but compilation time is up by a factor of almost two.
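For reference, the ratios behind that summary, computed only from the two runs quoted above (single measurements on one machine, so indicative at best):

```julia
# Numbers copied from the two runs above (vanilla Strided.jl vs. specialised methods).
compile_s = (vanilla = 19.850534, specialised = 36.712724)  # first-call @time, seconds
median_us = (vanilla = 41.032, specialised = 38.208)        # @benchmark median, μs
allocs    = (vanilla = 59, specialised = 15)                # allocs estimate per call

compile_ratio = compile_s.specialised / compile_s.vanilla      # ≈ 1.85
median_change = 1 - median_us.specialised / median_us.vanilla  # ≈ 0.069, i.e. ~7% faster

println("compile time ×", round(compile_ratio; digits = 2),
        ", median runtime −", round(100 * median_change; digits = 1), "%",
        ", allocs ", allocs.vanilla, " → ", allocs.specialised)
```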

leburgel commented 2 weeks ago

Thanks for figuring out the mapreduce issue! It's a small effect for a lot more compilation time, so I'm not sure whether you think it's worth actually making the change in Strided.jl.

I can anyway make the change locally now that I understand what's going on, so as far as I'm concerned my problems are solved :)

Jutho commented 2 weeks ago

Ok, let us know if you encounter cases where specialising the methods in Strided does actually lead to a noticeable speedup. I will close this and merge the PR.

lkdvos commented 2 weeks ago

We could consider putting the extra compilation time into a precompile call in TensorOperations: we keep the non-specialised generic implementation in Strided.jl, but make a specialised method in TensorOperations, which we precompile there. I do feel like this might be worthwhile, as in TensorKit, with the symmetries we call these methods quite often, so it might actually add up.

We should probably test this first however, as this is quite a bit of extra complexity to maintain.
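A rough sketch of the "specialised method plus precompile" idea, with everything hypothetical: `SpecialisedKernels` stands in for a module inside TensorOperations and `scaled_map!` for a specialised kernel. Base's `precompile(f, argtypes)` compiles a given method instance ahead of time and returns `true` on success. When such statements live in a package, they run during package precompilation, so on recent Julia versions (package images, 1.9 and later) the native code can be cached rather than recompiled in every session.

```julia
module SpecialisedKernels  # hypothetical stand-in for code inside TensorOperations

# Specialised on the elementwise function through the type parameter F.
function scaled_map!(dst, f::F, src, α) where {F}
    @inbounds for i in eachindex(dst, src)
        dst[i] = α * f(src[i])
    end
    return dst
end

# Compile the instances expected to be hot ahead of time.
const V = Vector{ComplexF64}
for op in (identity, conj)
    precompile(scaled_map!, (V, typeof(op), V, ComplexF64))
end

end # module
```

The catch, as noted above, is choosing which concrete argument types to precompile; anything outside that list still pays the compilation cost on first use.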

Jutho commented 2 weeks ago

Wouldn't this de facto mean a decoupling of TensorOperations from Strided? I don't really see how to have a separate call stack for TensorOperations without duplicating most of the Strided.jl code and simply adding the specialisations.

lkdvos commented 2 weeks ago

I guess so, but it would also mean that not everyone necessarily has to pay the compilation time price... I'm not entirely sure how much of the call stack would need to be duplicated, but for these low-level features it might just be worth it. Of course, another option is to just keep the extra compile time in Strided, precompile the specific TensorOperations kernels to mitigate it, and then see whether it becomes a hindrance for other Strided usage?

In a perfect world it should really be possible to use @tensor without allocations, with either preallocated temporaries or Bumper.jl, which is currently not the case. Keeping in mind that this function is called once for every subblock associated with a fusion tree, this can easily add up to quite a large amount of overhead. We should probably try out some Hubbard-type symmetry, but it's not urgent; I can add it to the to-do list :)

As an aside, just playing games with the compiler here: what happens if you add a @nospecialize in the kernel as well? I guess that would get rid of the type instability but hurt performance too much? Does that even work?

Jutho commented 2 weeks ago

I guess that would mean that dynamic dispatch is happening in the hot loop of the kernel? Since it is the function that is called in the loop body itself that is unknown, whereas the variables in the loop are all inferred, it might be that the dynamic dispatch of selecting the actual function is hoisted out of the loop. Just speculating though.

There used to be @nospecialize annotations before the function arguments in the _mapreduce methods (not sure if that included the kernel), and at some point I removed them and was happy to see that this did not have any significant effect on compilation time. I guess I now understand why 😄.
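Whether the compiler hoists the dispatch out of the loop is speculation above. What can be relied on is the standard function-barrier pattern, which is roughly the structure described for the _mapreduce chain: an outer method that does not specialise on the function argument, and an inner kernel that does. The dynamic dispatch then happens once per call, at the barrier, not once per element. The names below are hypothetical.

```julia
# Outer method: @nospecialize asks the compiler not to specialise on `f`.
function apply_all!(dst, @nospecialize(f), src)
    # The type of `f` is unknown here, so this call is dispatched at runtime,
    # but only once per call of apply_all!...
    return apply_kernel!(dst, f, src)
end

# ...while the kernel is compiled for the concrete type of `f`,
# so the call f(src[i]) inside the loop is statically dispatched.
function apply_kernel!(dst, f::F, src) where {F}
    @inbounds for i in eachindex(dst, src)
        dst[i] = f(src[i])
    end
    return dst
end
```

For large arrays the single dispatch is negligible; for very small arrays, like the ones in this issue, that fixed per-call cost is exactly what shows up in the profile.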

Jutho commented 2 weeks ago

Also, there are other small allocations in Strided.jl, which have to do with how to block the problem or how to distribute it across threads.

lkdvos commented 2 weeks ago

But I guess for problems of this size the multithreading does not yet kick in?

Jutho commented 2 weeks ago

No, and that anyway only applies when there is more than one thread. I am not sure whether the extra allocations are only there for multithreading, though; I should check this. There might also be some for other parts of the blocking and loop reordering, even though I tried to base everything on type-stable tuples.

amilsted commented 2 weeks ago

@Jutho Was the thing you tried literally just to add type parameters on those function arguments down the _mapreduce callchain? If so, I guess we can reproduce that easily and see if it helps our case. Might it be enough to only introduce the specialization on the methods that actually call the functions?

Wait, I see you mention this above. I guess this is currently what happens and you get dynamic dispatch when calling the kernel.

amilsted commented 2 weeks ago

I think the overhead is kind of important, as it prevents Strided's version of permutedims from being a general replacement. I didn't use it in QuantumOpticsBase in the LazyTensor implementation because this overhead is noticeable for small systems. Maybe some explicit inlining could bring down the compile time a bit?

Jutho commented 2 weeks ago

@amilsted and @leburgel: yes, the runtime dynamic dispatch was eliminated by fully specializing the f, op and initop arguments, i.e. by giving them type parameters in the method. I have pushed this to a branch to make it easier to try out: https://github.com/Jutho/Strided.jl/tree/jh/specialize

However, I am not sure whether the overhead you see comes from that, or simply from the kind of computations that Strided does to find a good loop order and blocking scheme. That overhead can be non-negligible for small systems. Maybe it pays off to have a specific optimized implementation for tensor permutation. I'd certainly be interested in collaborating on that.

lkdvos commented 2 weeks ago

Do we know how something like LoopVectorization.jl stacks up? From what I know, that should also attempt to find some loop order and blocking scheme and unrolling etc, which might be important for these kinds of systems...

Jutho commented 2 weeks ago

Last time I checked (which is a while ago), LoopVectorization was great at loop ordering, unrolling and inserting vectorised operations, but it did not do blocking. Octavian does this separately, I think. The other major issue was that LoopVectorization did not support composite types such as Complex, which is kinda important :-).

Finally, LoopVectorization would also lead to a separate compilation stage for every new permutation. I don't know if we want to go that way. Maybe that's fine for a custom backend, but not for the default one.

The other thing to revive is HPTT; I once wrote the jll package for that using BinaryBuilder, so it shouldn't be much work to write the wrappers to make an HPTTBLAS backend (meaning HPTT for permutations and (permute+BLAS) for contractions).

Jutho commented 2 weeks ago

Also, on the Strided side, maybe there is a way to "cache" the blocking scheme etc, but I am not really sure if that is worth it.
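A minimal sketch of the caching idea, under loud assumptions: the "plan" here is just a loop order chosen by a toy heuristic (innermost loop over the smallest source stride), and `PLAN_CACHE` and `loop_order` are hypothetical names. Strided.jl's actual blocking and loop-ordering logic is more involved and also depends on sizes, which a real cache key would have to include.

```julia
# Hypothetical cache: (source strides, permutation) => loop order over destination dims.
# Not thread safe; a real implementation would need a lock or per-thread caches.
const PLAN_CACHE = Dict{Tuple{Dims,Dims},Vector{Int}}()

# Toy heuristic: destination dim i reads source dim perm[i]; order destination dims
# from smallest to largest source stride (innermost loop first).
# The returned vector is shared with the cache, so callers must not mutate it.
function loop_order(srcstrides::Dims{N}, perm::Dims{N}) where {N}
    get!(PLAN_CACHE, (srcstrides, perm)) do
        permuted = ntuple(i -> srcstrides[perm[i]], N)
        sortperm(collect(permuted))
    end
end
```

Whether a dictionary lookup is actually cheaper than recomputing the plan for tiny arrays is exactly the open question; the lookup itself hashes the key on every call.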

lkdvos commented 2 weeks ago

And of course compare with TBLIS as well; I don't think I benchmarked that for smaller systems (and I think it leverages HPTT under the hood).

Let me also revive https://github.com/Jutho/TensorOperations.jl/tree/ld/benchmark at some point, to benchmark all implementations and changes more rigorously. I have a set of permutations and contractions there, but @amilsted or @leburgel, if you could supply some typical contractions and/or array sizes, we can tailor our efforts a bit more towards those.

amilsted commented 2 weeks ago

Is it worth considering switching to Julia's own permutedims for small arrays?

lkdvos commented 2 weeks ago

You should be able to try this out already; the BaseCopy and BaseView backends should do this, if I'm not mistaken.

amilsted commented 2 weeks ago

@leburgel looks like BaseCopy would be the one to try.

leburgel commented 2 weeks ago

I'll give it a go!

Jutho commented 2 weeks ago

The Base... methods were not really meant for performance. The problem with using Base.permutedims is that it doesn't support the alpha and beta parameters. And while in a tensor network contraction we typically don't need them (i.e. they take their default values), it is annoying to have to check and correct for that. Maybe copying the Base implementation and generalising it to include alpha and beta is one way forward. But as far as I remember from looking at it quite some time ago, it goes via PermutedDimsArray, which encodes the permutation in its type parameters, and thus also leads to new compilation for every new permutation. Maybe that is the way to go, and it is not excessive in real-life applications.
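A minimal sketch of what "Base permutation plus alpha and beta" could look like on top of Base's `PermutedDimsArray`. The helper `permute_add!` is hypothetical and not part of any of the packages discussed; it computes `C = β*C + α*permutedims(A, perm)`. Because `PermutedDimsArray` stores the permutation as a type parameter, the broadcast is compiled anew for each distinct permutation, which is the trade-off described above.

```julia
# Hypothetical helper: C = β*C + α*permutedims(A, perm), via a lazy permuted view.
function permute_add!(C::AbstractArray{<:Any,N}, A::AbstractArray{<:Any,N},
                      perm::NTuple{N,Int}, α::Number, β::Number) where {N}
    Ap = PermutedDimsArray(A, perm)  # `perm` becomes part of the view's type
    if iszero(β)
        C .= α .* Ap                 # do not read C when β == 0
    else
        C .= β .* C .+ α .* Ap
    end
    return C
end
```

The `iszero(β)` branch avoids reading (possibly uninitialised) data from `C` in the common case where the destination is simply overwritten.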

I often benchmarked a number of randomly generated permutations that were each run just once, and then of course such approaches fail miserably. But I guess that is not a very realistic benchmark.