LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Add CartesianEmbedding layer #668

Closed ldeso closed 1 month ago

ldeso commented 1 month ago

Lux currently offers an Embedding layer that calls the NNlib.gather method with linear indices as arguments.

This pull request adds a CartesianEmbedding layer that uses the other NNlib.gather method, which takes Cartesian indices as arguments. This eliminates the need to convert manually from linear to Cartesian indices to use NNlib.gather, which significantly improves performance and reduces the number of allocations, especially on the GPU.

Example

# Julia v1.10.3
using Pkg
Pkg.activate(; temp=true)
Pkg.add(; url="https://github.com/ldeso/Lux.jl", rev="add-cartesianembedding-layer")
using Lux, Random

weights = [10.0 20.0 30.0
           40.0 50.0 60.0]
i = [1, 1, 1, 2, 2]
j = [1, 2, 3, 1, 2]
idx = (i, j)

rng = Random.default_rng()
model = CartesianEmbedding(size(weights) => 1; init_weight=Returns(weights))

ps, st = Lux.setup(rng, model)

model(idx, ps, st)[1]  # returns [10.0, 20.0, 30.0, 40.0, 50.0]
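
For readers less familiar with the two index styles, the following is a minimal sketch in plain Julia (Base only, no Lux or NNlib) of what the (i, j) pairs in the example select, and of the manual conversion to linear indices that the Cartesian path avoids. It uses ordinary array indexing as a stand-in for NNlib.gather, so it illustrates the indexing only and says nothing about performance; the variable names `cart`, `lin`, `via_cartesian`, and `via_linear` are illustrative and do not come from the pull request.

```julia
# Plain-Julia illustration of linear vs. Cartesian indexing (no Lux or NNlib needed).
weights = [10.0 20.0 30.0
           40.0 50.0 60.0]
i = [1, 1, 1, 2, 2]   # row indices
j = [1, 2, 3, 1, 2]   # column indices

# Cartesian route: pair the row and column vectors element-wise and index directly.
cart = CartesianIndex.(i, j)
via_cartesian = weights[cart]

# Linear route: first convert each (row, column) pair to a column-major linear
# index, then index the flattened array. This is the extra manual step.
lin = LinearIndices(weights)[cart]
via_linear = vec(weights)[lin]

# Both routes select the same five entries.
@assert via_cartesian == via_linear
```

Because Julia arrays are column-major, the linear indices here are not simply 1 through 5, which is part of why doing the conversion by hand is easy to get wrong.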

Closes #670

ldeso commented 1 month ago

Force-pushed to pass the method ambiguity check; also moved the tests to the "Miscellaneous Layers" test item.

avik-pal commented 1 month ago

Does this need to be a separate layer? We can just use multiple dispatch to combine this with the Embedding layer.

ldeso commented 1 month ago

I created a separate pull request with multiple dispatch to let you compare (#670). While the solution with multiple dispatch results in less code, it also requires a modification of the Embedding struct, which might be a bit risky (even though it passes all tests).
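
As a rough illustration of the multiple-dispatch idea discussed here, the sketch below defines one hypothetical function, `lookup`, with two methods selected by the type of the index argument. It is not the code from either pull request and does not touch the Embedding struct; the function name, the `(out_dims, in_dims...)` table layout, and the use of plain array indexing in place of NNlib.gather are all assumptions made for the example.

```julia
# Hypothetical stand-in for an embedding lookup; not Lux's actual implementation.
# `table` is assumed to have size (out_dims, in_dims...), so the first axis
# holds the embedding vector and the remaining axes are the lookup axes.

# Linear method: one integer per lookup selects a column of a 2-D table.
lookup(table::AbstractMatrix, idx::AbstractVector{<:Integer}) = table[:, idx]

# Cartesian method: a tuple of integer vectors, one vector per lookup axis.
function lookup(table::AbstractArray, idx::Tuple{Vararg{AbstractVector{<:Integer}}})
    cart = CartesianIndex.(idx...)   # zip the vectors into Cartesian indices
    return table[:, cart]            # each Cartesian index spans the lookup axes
end

# Usage: the same function name handles both index styles.
flat_table = [1.0 2.0 3.0
              4.0 5.0 6.0]                      # size (2, 3)
grid_table = reshape(collect(1.0:12.0), 2, 2, 3) # size (2, 2, 3)

flat_result = lookup(flat_table, [3, 1])             # columns 3 and 1
grid_result = lookup(grid_table, ([1, 2], [3, 1]))   # entries (1, 3) and (2, 1)
```

The two methods cannot be ambiguous with each other because a vector and a tuple are disjoint argument types; a real layer would additionally have to decide how the stored input dimensions are represented, which is the struct change mentioned above.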

avik-pal commented 1 month ago

https://github.com/LuxDL/Lux.jl/pull/670 supersedes this