LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Immutable Arrays #8

Open avik-pal opened 2 years ago

avik-pal commented 2 years ago

Testing out the Immutable Arrays from https://github.com/JuliaLang/julia/pull/44381 with #7

TL;DR: Performance is a pain point right now (broadcasting seems to be the culprit), but supporting these will be very straightforward once the functionality is available in Base.

EDIT: Code updated to work for Lux 0.4.*

Trial 1: From the Usage Example

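using Lux, Random, Functors, BenchmarkTools

# ImmutableArray is only available on a Julia build that includes the PR linked above
make_immutable(x::AbstractArray) = ImmutableArray(copy(x))
make_immutable(x) = x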

# Construct the layer
model = Chain(BatchNorm(128), Dense(128, 256, tanh), BatchNorm(256),
              Chain(Dense(256, 1, tanh), Dense(1, 10)))

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps)
st_immutable = fmap(make_immutable, st)

# Dummy Input
x = randn(Float32, 128, 1024)
x_immutable = make_immutable(x)

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 1296 samples with 1 evaluation.
 Range (min … max):  2.125 ms … 26.658 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     3.096 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.836 ms ±  2.313 ms  ┊ GC (mean ± σ):  2.58% ± 7.71%

    ▂█                                                        
  ▆▄██▇▆▄▄▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▂▂▁▂▂▂▂▁▁▂▂▂▂▂▁▂▂▂▂▁▂▂▂ ▃
  2.13 ms        Histogram: frequency by time        14.1 ms <

 Memory estimate: 3.60 MiB, allocs estimate: 144.

Immutable Arrays

BenchmarkTools.Trial: 41 samples with 1 evaluation.
 Range (min … max):  107.855 ms … 159.665 ms  ┊ GC (min … max): 3.98% … 2.64%
 Time  (median):     119.911 ms               ┊ GC (median):    3.54%
 Time  (mean ± σ):   123.706 ms ±  10.746 ms  ┊ GC (mean ± σ):  3.54% ± 0.67%

              ▂█▄                                                
  ▄▁▁▁▁▁▁▁▄▆▄█████▄▁▄▆▄▆▁▁▄▁▁▄▁▁▁▁▁▁▄▁▁▁▁▄▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▆ ▁
  108 ms           Histogram: frequency by time          160 ms <

 Memory estimate: 58.32 MiB, allocs estimate: 3418558.

Trial 2: Only a Dense Layer

# Construct the layer
model = Dense(128, 256)

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);

# Dummy Input
x = randn(Float32, 128, 1024);
x_immutable = make_immutable(x);

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 4469 samples with 1 evaluation.
 Range (min … max):  483.810 μs … 30.894 ms  ┊ GC (min … max): 0.00% …  0.00%
 Time  (median):     716.669 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.100 ms ±  1.501 ms  ┊ GC (mean ± σ):  5.01% ± 12.19%

  █▆▆▅▄▃▂▂▂▂▃▃▃▂▁                                              ▁
  █████████████████▇▇▇▆▇▆▅▅▃▃▄▅▅▄▃▅▁▁▆▄▅▁▃▃▃▃▅▁▃▃▃▃▁▃▁▁▃▁▁▁▁▃▅ █
  484 μs        Histogram: log(frequency) by time      7.69 ms <

 Memory estimate: 2.00 MiB, allocs estimate: 4.

Immutable Arrays

BenchmarkTools.Trial: 259 samples with 1 evaluation.
 Range (min … max):  15.392 ms … 52.229 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     17.997 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   19.327 ms ±  4.194 ms  ┊ GC (mean ± σ):  1.72% ± 4.44%

    ▃▆█ ▂                                                      
  ▃▆███▆█▇▅▇▇▄▆▃▆▄▄▅▄▄▄▄▄▃▄▄▃▂▁▃▃▂▁▃▂▁▁▂▂▁▂▂▂▁▃▁▃▂▂▁▁▁▂▂▁▂▂▁▂ ▃
  15.4 ms         Histogram: frequency by time        32.6 ms <

 Memory estimate: 7.00 MiB, allocs estimate: 262153.

It seems that most of the time is spent broadcasting the bias, which looks like a problem with broadcasting in general:

julia> @benchmark $ps_immutable.weight * $x_immutable
BenchmarkTools.Trial: 4032 samples with 1 evaluation.
 Range (min … max):  346.287 μs … 51.079 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     540.489 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.224 ms ±  1.854 ms  ┊ GC (mean ± σ):  2.36% ± 8.18%

  █▆▄▄▃▁▁▁ ▂▂▁▁▁▂▂▁▁  ▁▁                                       ▁
  █████████████████████████▇▇▇▆▇▆▇▆▆▃▆▆▆▅▅▅▅▄▅▅▅▆▅▅▅▅▅▅▄▃▁▁▁▃▃ █
  346 μs        Histogram: log(frequency) by time      8.78 ms <

 Memory estimate: 1.00 MiB, allocs estimate: 5.

julia> @benchmark $ps_immutable.weight * $x_immutable .+ $ps_immutable.bias
BenchmarkTools.Trial: 338 samples with 1 evaluation.
 Range (min … max):  11.177 ms … 33.105 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     13.699 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   14.792 ms ±  3.901 ms  ┊ GC (mean ± σ):  2.43% ± 5.87%

   █▃                                                          
  ▅██▇▇▅▅▇▅▇▅▅▄▅▅▄▃▃▄▄▃▂▃▃▁▂▃▁▃▂▃▃▃▁▃▂▂▂▁▃▁▁▂▂▂▂▁▁▁▁▁▁▁▁▁▃▂▁▂ ▃
  11.2 ms         Histogram: frequency by time        30.9 ms <

 Memory estimate: 7.00 MiB, allocs estimate: 262153.
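
A rough reading of the two results above, as plain arithmetic rather than a profile: the Dense(128, 256) output for a batch of 1024 has 256 × 1024 elements, and the allocation count reported for the broadcast version is within a handful of that figure. That would be consistent with roughly one allocation per output element, though this sketch does not confirm where the allocations come from. The variable names are illustrative only.

```julia
# Back-of-the-envelope check using only the figures reported above.
out_features, batch = 256, 1024

n_out = out_features * batch                 # elements in the Dense output
buffer_mib = n_out * sizeof(Float32) / 2^20  # size of one Float32 output buffer in MiB

reported_allocs = 262_153                    # "allocs estimate" for weight * x .+ bias
extra_allocs = reported_allocs - n_out       # allocations beyond one per output element

println("output elements:    ", n_out)
println("one output buffer:  ", buffer_mib, " MiB")
println("allocs beyond that: ", extra_allocs)
```

The size of a single output buffer lines up with the 1.00 MiB memory estimate of the plain matrix multiply, so the extra memory and allocations appear only once the bias broadcast is added.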

Trial 3: No broadcasting

model = Dense(128, 256; bias=false)

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 5501 samples with 1 evaluation.
 Range (min … max):  295.161 μs … 23.801 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     451.402 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   899.925 μs ±  1.386 ms  ┊ GC (mean ± σ):  3.10% ± 8.68%

  █▆▆▄▃▂▁▂▁▁▁▂▂▂▂▁ ▁                                           ▁
  ██████████████████▇█▇█▇▇▆▆▇▇▆▆▆▆▆▆▅▅▅▆▅▅▁▆▄▆▅▃▅▄▅▄▆▄▅▁▄▆▅▅▃▅ █
  295 μs        Histogram: log(frequency) by time      6.98 ms <

 Memory estimate: 1.00 MiB, allocs estimate: 2.

Immutable Arrays

BenchmarkTools.Trial: 5303 samples with 1 evaluation.
 Range (min … max):  311.574 μs … 26.953 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     436.316 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   930.509 μs ±  1.488 ms  ┊ GC (mean ± σ):  3.23% ± 8.75%

  █▆▅▃▂▁   ▁▁▂▁▁                                               ▁
  █████████████████▆█▇▇▆▆▆▆▆▆▆▆▆▅▅▅▅▅▅▄▄▅▅▅▅▂▅▂▄▅▄▅▄▄▃▂▃▄▄▂▃▂▃ █
  312 μs        Histogram: log(frequency) by time      7.61 ms <

 Memory estimate: 1.00 MiB, allocs estimate: 5.

Trial 4: Nested Chain of Dense Layers without Bias

model = Chain(Dense(128, 256; bias=false),
              Chain(Dense(256, 512; bias=false), Dense(512, 10; bias=false)))

# Parameter and State Variables
ps, st = Lux.setup(MersenneTwister(0), model)
ps_immutable = fmap(make_immutable, ps);
st_immutable = fmap(make_immutable, st);

# Run the model
@benchmark $model($x, $ps, $st)
@benchmark $model($x_immutable, $ps_immutable, $st_immutable)

Standard Abstract Arrays

BenchmarkTools.Trial: 1372 samples with 1 evaluation.
 Range (min … max):  1.380 ms … 49.871 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     2.918 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.615 ms ±  3.116 ms  ┊ GC (mean ± σ):  2.42% ± 7.94%

  ▅█    ▃                                                     
  ███▇▆▇██▇▆▅▄▄▄▃▃▃▃▂▃▃▃▂▃▂▃▂▂▂▂▁▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▂▂ ▃
  1.38 ms        Histogram: frequency by time        15.8 ms <

 Memory estimate: 3.04 MiB, allocs estimate: 6.

Immutable Arrays

BenchmarkTools.Trial: 894 samples with 1 evaluation.
 Range (min … max):  1.505 ms … 66.104 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     4.153 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.561 ms ±  5.432 ms  ┊ GC (mean ± σ):  1.87% ± 7.54%

  █▆▅▅▅▄▅▆▆▅▄▄▂▂▂▂▁     ▁  ▁     ▁                            
  █████████████████▇█▆███▆▇█▅▆▇███▆▇█▄▇▇▇▅▄▆▅▅▁▄▁▆▄▁▅▇▅▄▄▆▁▅ █
  1.5 ms       Histogram: log(frequency) by time     23.1 ms <

 Memory estimate: 3.04 MiB, allocs estimate: 17.
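
To put the four trials side by side, the sketch below divides the immutable median by the standard median for each trial. The numbers are copied from the benchmark outputs above and converted to milliseconds; the trial labels and variable names are illustrative, and medians from a single noisy run should be read as rough indicators only.

```julia
# Median times copied from the benchmark outputs above, in milliseconds:
# (label, standard arrays, immutable arrays)
medians_ms = [
    ("Trial 1: BatchNorm + Dense chain", 3.096,    119.911),
    ("Trial 2: Dense with bias",         0.716669, 17.997),
    ("Trial 3: Dense without bias",      0.451402, 0.436316),
    ("Trial 4: Dense chain, no bias",    2.918,    4.153),
]

# Ratio > 1 means the immutable version is slower.
slowdowns = [immutable / standard for (_, standard, immutable) in medians_ms]

for ((name, _, _), s) in zip(medians_ms, slowdowns)
    println(rpad(name, 34), round(s; digits=2), "x")
end
```

By this measure the two trials that involve broadcasting come out roughly 39x and 25x slower with immutable arrays, while the bias-free trials sit at about 0.97x and 1.4x.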

cc @ChrisRackauckas @ianatol @aviatesk

ianatol commented 2 years ago

I think the poor broadcasting performance likely has to do with some missed chance to perform our memory optimization in the broadcast logic somewhere (i.e., we think it is unsafe to optimize in a place where it's actually safe to do so). I will take a look into this when I get a chance, but thanks for putting this together and providing a nice, realistic benchmark for performance going forward!

ianatol commented 2 years ago

Also, minor nit, but ImmutableArray will copy by itself if we can't optimize, so I don't think the copy is necessary here:

make_immutable(x::AbstractArray) = ImmutableArray(copy(x))

avik-pal commented 2 years ago

I added it because ReshapedArray doesn't seem to have a dispatch for that.
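
For context on why a copy sidesteps the missing dispatch: in stock Julia, copy on a lazily reshaped array returns a plain dense Array, so whatever receives the result never sees the wrapper type. The sketch below uses only Base and does not call ImmutableArray, which needs a Julia build from the PR linked at the top of the thread; the variable names are illustrative.

```julia
# A lazily reshaped range is a Base.ReshapedArray, not a plain Array.
r = reshape(1:6, 2, 3)
println(r isa Base.ReshapedArray)

# copy materializes it into a dense Matrix with the same contents.
c = copy(r)
println(typeof(c))
println(c == r)
```

The trade-off is one extra allocation per array at setup time, which does not affect the forward-pass benchmarks above since the conversion happens before timing.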