FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Performance issues when switching from local machine to HPC cluster. #2484

Open SGHoekstra opened 1 week ago

SGHoekstra commented 1 week ago

Hey everyone,

First time posting here so let me know if I did not do it correctly or need to add more information.

I am dealing with a performance issue when switching from my local machine to an HPC cluster. I found that the bottleneck is Flux, and I ran a minimal working example on both machines.

This is the minimal working example: [Edited to use three backticks for a code block]

using Flux

t1 = time();

xtrain     = Float32[0.1 0.2; 0.3 0.5; 0.4 0.1; 0.5 0.4; 0.7 0.9; 0.2 0.1]
ytrain     = Float32[0.3; 0.8; 0.5; 0.9; 1.6; 0.3]
xtest      = Float32[0.5 0.6; 0.14 0.2; 0.3 0.7]
ytest      = Float32[1.1; 0.36; 1.0]

model = Dense(2, 1) # Use Chain if you want to stack layers

loss(x, y) = Flux.mse(model(x), y)
ps = Flux.params(model)
dataset = [(xtrain', ytrain')] # Use DataLoader for easy minibatching
opt = ADAGrad()

for i = 1:1000
    Flux.train!(loss, ps, dataset, opt)
end
elapsed_time = time() - t1;

println("Elapsed time: ", elapsed_time, " seconds");

On both machines I use Julia v1.10.4 and Flux v0.14.19.

On my local machine (OSX 14.5) I get Elapsed time: 0.5202600955963135 seconds. On the HPC cluster (Linux 5.14.0-427.31.1.el9_4.x86_64) I get 12.309736967086792 seconds.

I run many small models simultaneously and sequentially, so this slowdown makes the simulations I need to run infeasible. Does anyone have an idea what would cause this slowdown?

Thank you in advance. Best, Steven

mcabbott commented 1 week ago

Timing like this mostly measures compilation time. There's some chance this is much slower on your server because it's the first run with these package versions, or something like that. I'd suggest as a first step timing separately the first and subsequent runs, like so:

@time for i = 1:2   # this will print only compilation time
    Flux.train!(loss, ps, dataset, opt)
end
@time for i = 3:1000  # this will print the time to run after compilation
    Flux.train!(loss, ps, dataset, opt)
end

mcabbott commented 1 week ago

Aside, I highly recommend changing this to use "explicit" style, as in current docs https://fluxml.ai/Flux.jl/stable/guide/training/training/ . This "implicit" style with Flux.params will stop working soon.

loss(m, x, y) = Flux.mse(m(x), y)  # no longer closes over global variable `model`

opt_state = Flux.setup(ADAGrad(), model)

for i = 1:1000
    Flux.train!(loss, model, dataset, opt_state)
end

SGHoekstra commented 5 days ago

Thank you for your suggestions. I have implemented them, and the Linux HPC cluster is now faster than my local machine for the test script.

I cannot achieve the same speedup for the model I am actually interested in training. I have created a minimal example of the training process below. For this script, the HPC cluster is still about 10 times slower than my local machine. Any help would be much appreciated!

Output local OSX machine:
7.833103 seconds (30.90 M allocations: 1.934 GiB, 4.84% gc time, 99.69% compilation time)
2.098892 seconds (1.22 M allocations: 6.829 GiB, 7.78% gc time)

Output HPC:
23.756881 seconds (30.78 M allocations: 1.924 GiB, 5.56% gc time, 99.32% compilation time)
20.209329 seconds (1.22 M allocations: 6.810 GiB, 0.66% gc time)

# code [edited to add code block, three `s]

using Flux

network_width = 32

perceptron = Chain(
    Dense(6, network_width, leakyrelu), 
    Dense(network_width, network_width, leakyrelu), 
    Dense(network_width, network_width, leakyrelu), 
    Dense(network_width, network_width, leakyrelu),  
    Dense(network_width, 3, relu)      
  )

opt_state =  Flux.Optimiser(Flux.Adam(1e-6),ClipValue(1e-5))
state = Flux.setup(opt_state, perceptron)

k = Float32.(Vector(range(1,100,1000)))
b = Float32.(Vector(range(1,100,1000)))
w = Float32.(Vector(range(1,100,1000)))
r_k = Float32.(vcat(fill(0.1, length(k))...))
r_b = Float32.(vcat(fill(0.01, length(k))...))    
p = Float32.(vcat(fill(1, length(k))...))    
pi_ = Float32.(vcat(fill(0.01, length(k))...))

function abs_appr(x)
    y = sqrt.(x.^2 .+ Float32(1e-6)) 
    return y 
end

function Residuals(perceptron, r_k, r_b, k, b, w, p, pi_, weights)
    n = size(w, 1) 

    s = hcat(r_k, r_b, k, b, w, p)'  
    x = perceptron(s)  

    c  = x[1, :] 
    k1 = x[2, :]
    b1 = x[3, :]

    d = k1 .- (1 .+ r_k) .* k

    rknext = Float32.(max.(exp.(log.(1 .+ r_k) .* 0.9 .+ 0.1 .+ 0.1 * randn(Float32, n)) .- 1,0))
    rbnext = Float32.(exp.(log.(1 .+ r_b) .* 0.9 .+ 0.1 .+ 0.1 * randn(Float32, n)) .- 1)
    pinext = Float32.(exp.(log.(1 .+ pi_) .* 0.9 .+ 0.1 .+ 0.1 * randn(Float32, n)) .- 1)
    wnext = Float32.(w .* 0.9 .+ 0.1 .+ 0.1 * randn(Float32, n))
    p1 = Float32.(p .* (1 .+ pinext))

    s = hcat(rknext, rbnext, k1, b1, wnext, p1)'  
    x = perceptron(s)  
    c1  = x[1, :] 
    k2  = x[2, :]

    d1 = k2 .- (1 .+ rknext) .* k1

    R1 =  Float32.(1 .- 0.95 .* (1 .+ rbnext) .* (c1  ./ c ).^(-1.5) .* (p ./ p1))
    R2 =  Float32.(w .+ (1 .+ r_b) .* b .+ (1 .+ r_k) .* k .- c .* p .- b1 .- 0.01 .* abs_appr.(d).^1.5 .- k1)
    R3 =  Float32.(1 .+ d .* 0.01 .* 1.5 .* abs_appr.(d).^(1.5 - 2) .- 0.95 .* (1 .+ rknext) .* (c1 ./ c ).^(-1.5) .* (p ./ p1) .* (1 .+ d1 .* 0.01 .* 1.5 .* abs_appr.(d1).^(1.5 - 2)))

    R_squared = sum(weights[1] * R1.^2 + weights[2] *R2.^2 + weights[3] *R3.^2)/n

    return R_squared
end

function train_me!(epochs, perceptron, w, k, b, r_k, r_b, p, pi_, state; weights = [1,1,1])

    for epoch in 1:epochs
        # Compute the value and gradients of the loss function
        val, grads = Flux.withgradient(perceptron) do m

            loss = Residuals(m, r_k, r_b, k, b, w, p, pi_, weights)

        end

        Flux.update!(state, perceptron, grads[1])

    end
end

@time train_me!(2, perceptron, w, k, b, r_k, r_b, p, pi_, state; weights = [1, 0.1, 1]);

@time train_me!(1000, perceptron, w, k, b, r_k, r_b, p, pi_, state; weights = [1, 0.1, 1]);

mcabbott commented 5 days ago

Xref https://discourse.julialang.org/t/flux-slows-down-by-10x-when-moving-from-local-system-to-high-performance-cluster/119753

No time to look closely now. But there is some chance your mac is just fast! M-processor memory is very quick. Some chance matmul isn't taking advantage of threads -- what does LinearAlgebra.BLAS.get_num_threads() say? Could investigate AppleAccelerate / MKL.jl, as Julia's default is OpenBLAS, which is sub-optimal on different processors to varying degrees.
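For anyone following along, a minimal sketch of how to inspect the BLAS setup on each machine. It uses only the `LinearAlgebra` standard library and assumes Julia 1.7 or later (the thread uses v1.10.4); the variable names are illustrative only.

```julia
using LinearAlgebra

# Number of threads the BLAS library uses for calls such as matrix multiplication.
blas_threads = BLAS.get_num_threads()

# Which BLAS backend is currently loaded (OpenBLAS unless another one was swapped in).
blas_config = BLAS.get_config()

# Julia's own task threads (set with e.g. `julia -t 8`) are a separate pool from BLAS threads.
julia_threads = Threads.nthreads()

println("BLAS threads:  ", blas_threads)
println("BLAS backend:  ", blas_config)
println("Julia threads: ", julia_threads)
println("CPU threads:   ", Sys.CPU_THREADS)
```

Comparing this output between the two machines shows whether they differ in BLAS backend or thread count before any timing is done.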

SGHoekstra commented 5 days ago

I don't think it is necessarily because my Mac is faster. The first example I gave was faster on the cluster than on my Mac after I implemented the changes you suggested. The cluster took about 0.027025 seconds (97.80 k allocations: 6.853 MiB, 38.61% gc time) for 1000 training steps while my Mac took 0.469306 seconds (399.92 k allocations: 41.605 MiB, 0.97% gc time). This makes me think the performance differences are due to some non-optimal code...

I get 8 threads on my Mac and 48 threads on the Linux machine. What does this mean in terms of performance?

Finally, I am new to asking for help on forums like GitHub. Is cross-posting bad form? I thought it might bring in some extra insights. Thanks!

mcabbott commented 5 days ago

Fine to ask in a few places, I just like cross-linking so that anyone can check whether what they're about to say has already been typed up nicely elsewhere.

I'm not so surprised if relative performance looks different at different sizes. For large enough matrices, matmul will usually dominate the time, and this will depend on what BLAS library & how it works with your processor. Here at sizes like 32x1000 I'm not super-sure.

Standard advice would be to profile & see where time is spent, but IMO this is seldom revealing once Zygote is involved. I don't see obvious performance-killing mistakes.
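As a rough sketch of what that standard advice looks like, the snippet below profiles a stand-in workload with Julia's `Profile` standard library. The function `standin_workload` is hypothetical and is not the model from this thread; it only mimics a 32x32 weight matrix applied to a 32x1000 batch followed by elementwise work. With Zygote in the picture the output is much noisier, as noted above.

```julia
using Profile
using LinearAlgebra

# Hypothetical stand-in workload: matmul at the sizes discussed here plus elementwise work.
function standin_workload(W, X, iters)
    acc = 0.0f0
    for _ in 1:iters
        Y = W * X                      # BLAS matrix multiplication
        acc += sum(abs2, tanh.(Y))     # elementwise work and a reduction
    end
    return acc
end

W = randn(Float32, 32, 32)
X = randn(Float32, 32, 1000)

standin_workload(W, X, 1)              # warm up so compilation is not profiled
Profile.clear()
@profile standin_workload(W, X, 2000)

# Flat listing sorted by sample count; entries with fewer than 10 samples are hidden.
Profile.print(format = :flat, sortedby = :count, mincount = 10)
```

If the matmul lines dominate the listing, the BLAS library and its threading are the place to look; if the elementwise lines dominate, BLAS settings are unlikely to matter much.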

SGHoekstra commented 4 days ago

The user p_f suggested increasing the number of BLAS threads. Strangely, lowering the number of BLAS threads improved performance.

I experimented with different thread counts and got the following results:

Running on node: fcn1
Number of processors allocated by SLURM: 16

Number of threads: 1
21.707765 seconds (31.16 M allocations: 1.916 GiB, 6.08% gc time, 99.87% compilation time)
2.532860 seconds (1.01 M allocations: 6.150 GiB, 4.12% gc time)
Number of threads: 2
3.861321 seconds (1.01 M allocations: 6.150 GiB, 3.04% gc time)
Number of threads: 4
4.540994 seconds (1.01 M allocations: 6.150 GiB, 2.26% gc time)
Number of threads: 8
5.932519 seconds (1.01 M allocations: 6.150 GiB, 2.47% gc time)
Number of threads: 16
7.079593 seconds (1.01 M allocations: 6.150 GiB, 1.41% gc time)
Number of threads: 32
11.259781 seconds (1.01 M allocations: 6.150 GiB, 0.86% gc time)

Apparently fewer threads are better in this case? Setting the number of BLAS threads to one makes the HPC cluster as fast as my Mac (excluding compilation time).
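A small sketch of how such a sweep can be reproduced in isolation, assuming the matmul at these sizes is what reacts to the thread count. It times only a plain 32x32 by 32x1000 multiplication rather than the full training loop, so the absolute numbers will not match the ones above. The helper `time_matmul!` and the chosen thread counts are illustrative, not part of Flux.

```julia
using LinearAlgebra

# Stand-in for one layer's forward pass at the sizes used in this thread.
W = randn(Float32, 32, 32)
X = randn(Float32, 32, 1000)
Y = similar(X)

# Time `reps` in-place matmuls with a given BLAS thread count (hypothetical helper).
function time_matmul!(Y, W, X, nthreads; reps = 2000)
    BLAS.set_num_threads(nthreads)
    mul!(Y, W, X)                      # warm-up call, excluded from the timing
    return @elapsed for _ in 1:reps
        mul!(Y, W, X)
    end
end

original = BLAS.get_num_threads()      # remember the current setting
timings = Dict{Int,Float64}()
for n in (1, 2, 4, 8)
    timings[n] = time_matmul!(Y, W, X, n)
    println("BLAS threads = ", n, ": ", round(timings[n]; digits = 4), " s")
end
BLAS.set_num_threads(original)         # restore the previous setting
```

If one thread also wins here, calling `BLAS.set_num_threads(1)` at the top of the training script is the simplest way to apply the setting on the cluster.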