LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

[WIP] Kolmogorov-Arnold Networks #627

Closed: avik-pal closed this 1 month ago

avik-pal commented 2 months ago

Very Alpha Implementation based on https://github.com/KindXiaoming/pykan

New API Introduced

TODOs

All forms of pruning and similar features can mostly be omitted.

codecov[bot] commented 2 months ago

Codecov Report

Attention: Patch coverage is 22.13740% with 102 lines in your changes missing coverage. Please review.

Project coverage is 83.47%. Comparing base (13ad936) to head (7b67734). Report is 2 commits behind head on main.

:exclamation: Current head 7b67734 differs from pull request most recent head b000086

Please upload reports for the commit b000086 to get more accurate results.

| Files | Patch % | Lines |
| --- | --- | --- |
| src/layers/kan.jl | 0.00% | 91 Missing :warning: |
| src/utils.jl | 0.00% | 11 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #627      +/-   ##
==========================================
- Coverage   86.86%   83.47%    -3.40%
==========================================
  Files          50       51        +1
  Lines        2482     2584      +102
==========================================
+ Hits         2156     2157        +1
- Misses        326      427      +101
```

:umbrella: View full report in Codecov by Sentry.

vpuri3 commented 2 months ago

@avik-pal I see you have added the RBF version as well. Could you allow for the use of different normalizations in place of LayerNorm? I'm getting good results with tanh; sigmoid and softsign are also possible options.
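Roughly what I have in mind at the call site, reusing the `KANDense` constructor and `normalizer` keyword from this PR (a sketch only, untested; `sigmoid` and `softsign` are the standard NNlib activations):

```julia
using Lux, NNlib  # sigmoid and softsign come from NNlib (re-exported by Lux)

# Hypothetical: the same RBF-based layer with three candidate input normalizers
layer_tanh     = KANDense(:RBF, 10, 10, 10; normalizer=tanh)
layer_sigmoid  = KANDense(:RBF, 10, 10, 10; normalizer=sigmoid)
layer_softsign = KANDense(:RBF, 10, 10, 10; normalizer=softsign)
```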

avik-pal commented 2 months ago

Current timings:

```julia
using Lux, KolmogorovArnold, Random
using LuxCUDA
using Zygote
using BenchmarkTools  # provides the @btime macros below

rng = MersenneTwister(1234)
device = cpu_device()

mlp = Chain(Dense(1, 32, tanh), Dense(32, 32, tanh), Dense(32, 1)) # 1_153 parameters
kan = Chain(KDense(1, 10, 10), KDense(10, 10, 10), KDense(10, 1, 10)) # 1_320 parameters plus 30 states
kan_lux = Chain(KANDense(:RBF, 1, 10, 10; normalizer=tanh),
    KANDense(:RBF, 10, 10, 10; normalizer=tanh),
    KANDense(:RBF, 10, 1, 10; normalizer=tanh)) # 1_320 parameters plus 36 states

x = rand32(rng, 1, 1000) |> device
pM, stM = Lux.setup(rng, mlp) |> device
pK, stK = Lux.setup(rng, kan) |> device
pKL, stKL = Lux.setup(rng, kan_lux) |> device

lossfn(m, x, p, st) = sum(abs2, first(m(x, p, st)))

# CPU Timings
@btime $mlp($x, $pM, $stM)
@btime $kan($x, $pK, $stK)
@btime $kan_lux($x, $pKL, $stKL)

@btime Zygote.gradient($lossfn, $mlp, $x, $pM, $stM)
@btime Zygote.gradient($lossfn, $kan, $x, $pK, $stK)
@btime Zygote.gradient($lossfn, $kan_lux, $x, $pKL, $stKL)

## Forward pass
# 130.254 μs (11 allocations: 254.39 KiB)
# 1.915 ms (39 allocations: 1.20 MiB)
# 1.849 ms (58 allocations: 1.20 MiB)     -- RBF version
# 31.953 ms (182 allocations: 26.90 MiB)  -- BSpline version

## Backward pass
# 642.610 μs (36 allocations: 767.95 KiB)
# 5.527 ms (629 allocations: 8.21 MiB)
# 3.651 ms (624 allocations: 6.06 MiB)
# 37.945 ms (616 allocations: 63.20 MiB)

device = gpu_device()

x = x |> device;
pM, stM = pM |> device, stM |> device;
pK, stK = pK |> device, stK |> device;
pKL, stKL = pKL |> device, stKL |> device;

# CUDA Timings
@btime CUDA.@sync $mlp($x, $pM, $stM)
@btime CUDA.@sync $kan($x, $pK, $stK)
@btime CUDA.@sync $kan_lux($x, $pKL, $stKL)

@btime CUDA.@sync Zygote.gradient($lossfn, $mlp, $x, $pM, $stM)
@btime CUDA.@sync Zygote.gradient($lossfn, $kan, $x, $pK, $stK)
@btime CUDA.@sync Zygote.gradient($lossfn, $kan_lux, $x, $pKL, $stKL)

## Forward pass
# 63.207 μs (237 allocations: 5.70 KiB)
# 312.889 μs (1052 allocations: 24.83 KiB)
# 274.267 μs (1308 allocations: 29.91 KiB) -- RBF version
# 2.179 ms (8881 allocations: 193.25 KiB)  -- BSpline version

## Backward pass
# 405.359 μs (1235 allocations: 41.66 KiB)
# 1.795 ms (5671 allocations: 154.55 KiB)
# 1.677 ms (5610 allocations: 681.97 KiB)
# 4.589 ms (17719 allocations: 477.38 KiB)  -- BSpline version
```

To match the weights, do:

```julia
using Setfield

@set! pKL.layer_1.base_linear.weight = pK.layer_1.W2
@set! pKL.layer_1.main_model.spline_linear.weight = pK.layer_1.W1
@set! pKL.layer_2.base_linear.weight = pK.layer_2.W2
@set! pKL.layer_2.main_model.spline_linear.weight = pK.layer_2.W1
@set! pKL.layer_3.base_linear.weight = pK.layer_3.W2
@set! pKL.layer_3.main_model.spline_linear.weight = pK.layer_3.W1
```
vpuri3 commented 2 months ago

Looks great. Can you also do a comparison with use_base_act = false, which disables the silu component and the corresponding linear transformation?

I'm getting a 2x speedup in that case which is weird because that part should not be accounting for 50% of the compute.
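Concretely, something like this should do it, reusing the names from the timing script above (a sketch only, untested here; use_base_act is the KDense keyword mentioned above):

```julia
# Same KDense architecture as in the timings above, but with the base (silu)
# branch disabled via use_base_act = false
kan_nobase = Chain(KDense(1, 10, 10; use_base_act=false),
    KDense(10, 10, 10; use_base_act=false),
    KDense(10, 1, 10; use_base_act=false))
pKn, stKn = Lux.setup(rng, kan_nobase) |> device

@btime $kan_nobase($x, $pKn, $stKn)
@btime Zygote.gradient($lossfn, $kan_nobase, $x, $pKn, $stKn)
```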

avik-pal commented 2 months ago

> I'm getting a 2x speedup in that case which is weird because that part should not be accounting for 50% of the compute.

Not that weird, though: CPU broadcasting is actually quite a big bottleneck in Julia. But none of the multithreaded versions work with Zygote.
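For a rough illustration (made-up sizes, not timings from this PR), compare the single-threaded element-wise activation broadcast against the small matmul it accompanies:

```julia
using BenchmarkTools, NNlib, Random

rng = MersenneTwister(1234)
x = rand(rng, Float32, 10, 1000)   # same batch size as the timings above
W = rand(rng, Float32, 10, 10)

@btime swish.($x)   # element-wise base-activation broadcast (single-threaded)
@btime $W * $x      # small matmul (BLAS, typically multi-threaded)
```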

vpuri3 commented 2 months ago

I saw the 2x improvement in the forward pass and 1.5x in the backward pass on the GPU. This is from my README, where I tested KDense on a 2080 Ti.

```julia
# (f_kan and f are loss closures over the respective parameter sets)

# use_base_act = true (default)
@btime CUDA.@sync $kan($x, $pK, $stK) # 155.781 μs (565 allocations: 17.50 KiB)
@btime CUDA.@sync Zygote.gradient($f_kan, $pK) # 1.250 ms (3879 allocations: 136.06 KiB)

# use_base_act = false
@btime CUDA.@sync $kan($x, $p, $st) # 83.275 μs (310 allocations: 10.00 KiB)
@btime CUDA.@sync Zygote.gradient($f, $p) # 874.364 μs (2746 allocations: 99.70 KiB)
```

I'm curious to see if you get similar results on GPU.

avik-pal commented 1 month ago

Closing this here; I am restructuring Boltz.jl a bit and will move this implementation there. Also, most of my experiments with Neural ODEs and related problems seem to suggest that KANs don't work well there.