avik-pal closed this pull request 1 month ago.
Attention: Patch coverage is 22.13740%, with 102 lines in your changes missing coverage. Please review.

Project coverage is 83.47%. Comparing base (13ad936) to head (7b67734). Report is 2 commits behind head on main.

:exclamation: Current head 7b67734 differs from pull request most recent head b000086. Please upload reports for the commit b000086 to get more accurate results.
Files | Patch % | Lines |
---|---|---|
src/layers/kan.jl | 0.00% | 91 Missing :warning: |
src/utils.jl | 0.00% | 11 Missing :warning: |
:umbrella: View full report in Codecov by Sentry.
@avik-pal I see you have added the RBF version as well. Could you allow for the use of different normalizations in place of `LayerNorm`? I'm getting good results with `tanh`; `sigmoid` and `softsign` are also possible options.
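For context, here is a minimal sketch of what a pluggable normalizer does in this setting (hypothetical names, not the actual `KANDense` internals): a bounded map squashes the raw input into the RBF grid range before the basis functions are evaluated, and `tanh`, `sigmoid`, and `softsign` are all bounded maps that can play that role.

```julia
# Hypothetical sketch, not the package's real implementation: a bounded
# `normalizer` maps the raw input into the grid interval before the
# radial basis functions are evaluated.
softsign(x) = x / (1 + abs(x))            # maps into (-1, 1)
sigmoid(x) = 1 / (1 + exp(-x))            # maps into (0, 1); rescale if needed

# Gaussian RBF features over a fixed grid on [-1, 1]
function rbf_basis(x::Real; grid_len::Int = 10)
    grid = range(-1, 1; length = grid_len)
    h = 2 / (grid_len - 1)                # grid spacing used as the bandwidth
    return [exp(-((x - g) / h)^2) for g in grid]
end

# Normalize first so any real input lands inside the grid range
rbf_features(x; normalizer = tanh) = rbf_basis(normalizer(x))

feats = rbf_features(5.0; normalizer = softsign)  # 10 bounded features
```

Swapping the normalizer only changes how inputs are compressed onto the grid, so it is a cheap knob to expose as a keyword argument.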
Current timings:

```julia
using Lux, KolmogorovArnold, Random
using LuxCUDA
using Zygote
using BenchmarkTools  # for @btime

rng = MersenneTwister(1234)
device = cpu_device()

mlp = Chain(Dense(1, 32, tanh), Dense(32, 32, tanh), Dense(32, 1)) # 1_153 parameters
kan = Chain(KDense(1, 10, 10), KDense(10, 10, 10), KDense(10, 1, 10)) # 1_320 parameters plus 30 states
kan_lux = Chain(KANDense(:RBF, 1, 10, 10; normalizer=tanh),
    KANDense(:RBF, 10, 10, 10; normalizer=tanh),
    KANDense(:RBF, 10, 1, 10; normalizer=tanh)) # 1_320 parameters plus 36 states

x = rand32(rng, 1, 1000) |> device
pM, stM = Lux.setup(rng, mlp) |> device
pK, stK = Lux.setup(rng, kan) |> device
pKL, stKL = Lux.setup(rng, kan_lux) |> device

lossfn(m, x, p, st) = sum(abs2, first(m(x, p, st)))

# CPU Timings
@btime $mlp($x, $pM, $stM)
@btime $kan($x, $pK, $stK)
@btime $kan_lux($x, $pKL, $stKL)
@btime Zygote.gradient($lossfn, $mlp, $x, $pM, $stM)
@btime Zygote.gradient($lossfn, $kan, $x, $pK, $stK)
@btime Zygote.gradient($lossfn, $kan_lux, $x, $pKL, $stKL)

## Forward pass
# 130.254 μs (11 allocations: 254.39 KiB)
# 1.915 ms (39 allocations: 1.20 MiB)
# 1.849 ms (58 allocations: 1.20 MiB) -- RBF version
# 31.953 ms (182 allocations: 26.90 MiB) -- BSpline version

## Backward pass
# 642.610 μs (36 allocations: 767.95 KiB)
# 5.527 ms (629 allocations: 8.21 MiB)
# 3.651 ms (624 allocations: 6.06 MiB)
# 37.945 ms (616 allocations: 63.20 MiB)
```
```julia
device = gpu_device()
x = x |> device;
pM, stM = pM |> device, stM |> device;
pK, stK = pK |> device, stK |> device;
pKL, stKL = pKL |> device, stKL |> device;

# CUDA Timings
@btime CUDA.@sync $mlp($x, $pM, $stM)
@btime CUDA.@sync $kan($x, $pK, $stK)
@btime CUDA.@sync $kan_lux($x, $pKL, $stKL)
@btime CUDA.@sync Zygote.gradient($lossfn, $mlp, $x, $pM, $stM)
@btime CUDA.@sync Zygote.gradient($lossfn, $kan, $x, $pK, $stK)
@btime CUDA.@sync Zygote.gradient($lossfn, $kan_lux, $x, $pKL, $stKL)

## Forward pass
# 63.207 μs (237 allocations: 5.70 KiB)
# 312.889 μs (1052 allocations: 24.83 KiB)
# 274.267 μs (1308 allocations: 29.91 KiB) -- RBF version
# 2.179 ms (8881 allocations: 193.25 KiB) -- BSpline version

## Backward pass
# 405.359 μs (1235 allocations: 41.66 KiB)
# 1.795 ms (5671 allocations: 154.55 KiB)
# 1.677 ms (5610 allocations: 681.97 KiB)
# 4.589 ms (17719 allocations: 477.38 KiB) -- BSpline version
```
To match the weights, do:

```julia
using Setfield

@set! pKL.layer_1.base_linear.weight = pK.layer_1.W2
@set! pKL.layer_1.main_model.spline_linear.weight = pK.layer_1.W1
@set! pKL.layer_2.base_linear.weight = pK.layer_2.W2
@set! pKL.layer_2.main_model.spline_linear.weight = pK.layer_2.W1
@set! pKL.layer_3.base_linear.weight = pK.layer_3.W2
@set! pKL.layer_3.main_model.spline_linear.weight = pK.layer_3.W1
```
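For readers not familiar with Setfield: `@set!` just rebuilds the nested immutable NamedTuple with one leaf replaced. A toy equivalent using plain `merge`, with made-up parameter values standing in for the real `Lux.setup` output:

```julia
# Toy stand-ins for the real parameter NamedTuples returned by Lux.setup
pK  = (layer_1 = (W1 = [1.0 2.0], W2 = [3.0 4.0]),)
pKL = (layer_1 = (base_linear = (weight = zeros(1, 2),),
                  main_model = (spline_linear = (weight = zeros(1, 2),),)),)

# Equivalent of the two `@set!` lines for layer_1: rebuild the nested
# NamedTuple with the weight leaves swapped in from pK.
pKL = merge(pKL, (layer_1 = merge(pKL.layer_1,
    (base_linear = (weight = pK.layer_1.W2,),
     main_model = (spline_linear = (weight = pK.layer_1.W1,),))),))
```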
Looks great. Can you also do a comparison with `use_base_act = false`, which disables the SiLU component and the corresponding linear transformation? I'm getting a 2x speedup in that case, which is weird because that part should not be accounting for 50% of the compute.
Not very surprising, though; CPU broadcasting is actually quite a bottleneck in Julia. But none of the multithreaded versions work with Zygote.
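To make concrete what the flag toggles, here is a scalar toy version (hypothetical names and shapes, not the package internals): with `use_base_act = true` the layer adds a SiLU branch with its own weight on top of the spline/RBF branch, and `false` drops that whole term.

```julia
# Toy sketch of the use_base_act flag; names and shapes are made up.
silu(x) = x / (1 + exp(-x))
rbf(x, g; h = 0.25) = exp(-((x - g) / h)^2)

# One KDense-style unit: spline/RBF branch plus an optional SiLU branch.
function kdense_unit(x; w_spline, w_base, use_base_act::Bool = true)
    grid = range(-1, 1; length = length(w_spline))
    y = sum(w_spline[i] * rbf(tanh(x), grid[i]) for i in eachindex(w_spline))
    return use_base_act ? y + w_base * silu(x) : y  # branch removed by the flag
end

y_on  = kdense_unit(0.3; w_spline = ones(10), w_base = 0.5)
y_off = kdense_unit(0.3; w_spline = ones(10), w_base = 0.5, use_base_act = false)
```

The base branch is a single elementwise activation plus a matmul, so a priori it should be much cheaper than evaluating the full RBF grid; that is why the observed 2x is surprising.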
I saw the 2x improvement in the FWD pass and 1.5x in the BWD pass on the GPU. This is from my README, where I tested KDense with a 2080 Ti:
```julia
# use_base_act = true (default)
@btime CUDA.@sync $kan($x, $pK, $stK) # 155.781 μs (565 allocations: 17.50 KiB)
@btime CUDA.@sync Zygote.gradient($f_kan, $pK) # 1.250 ms (3879 allocations: 136.06 KiB)

# use_base_act = false
@btime CUDA.@sync $kan($x, $p, $st) # 83.275 μs (310 allocations: 10.00 KiB)
@btime CUDA.@sync Zygote.gradient($f, $p) # 874.364 μs (2746 allocations: 99.70 KiB)
```
I'm curious to see if you get similar results on GPU
Closing this here; I am restructuring Boltz.jl a bit and will move this there. Also, most of my experiments with Neural ODEs and the like seem to suggest KANs don't work well there.
Very alpha implementation based on https://github.com/KindXiaoming/pykan

New API introduced:

- `KANDense`

TODOs:

- All forms of pruning and such can be mostly omitted