patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Export attention function similar to torch.nn.functional.scaled_dot_product_attention #764

Open Artur-Galstyan opened 2 months ago

Artur-Galstyan commented 2 months ago

Currently, these functions exist in the _attention.py file but are not explicitly exported. However, a lot of people want to write their own custom MHA implementations and could use these functions.

(I'm aware that I can simply import them anyway, but since they're not in the docs and not everyone goes through the source code, they're easily overlooked.)

WDYT? Other frameworks have a dedicated "functional" package; it'd be great to have something similar here.
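
For concreteness, a minimal sketch of what importing the private helpers looks like today. The names `dot_product_attention` / `dot_product_attention_weights` are taken from the current `equinox/nn/_attention.py` and are an assumption here, not a documented API; they may change precisely because they aren't exported yet.

```python
import jax.random as jr

# Importing from the private module, since there is no public export yet.
# Name and signature assumed from the current source of equinox/nn/_attention.py.
from equinox.nn._attention import dot_product_attention

q = jr.normal(jr.PRNGKey(0), (8, 16))  # (seq_len, head_dim)
k = jr.normal(jr.PRNGKey(1), (8, 16))
v = jr.normal(jr.PRNGKey(2), (8, 16))

out = dot_product_attention(q, k, v)   # (seq_len, head_dim)
```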

patrick-kidger commented 2 months ago

I'd be happy to add these to the public API. Just never had a request for that before!

On the topic of functional APIs: one of the nice things about the functional-programming nature of JAX + Equinox is that we kind of get that for free! If you want functions that look like this:

```python
weight_and_bias = init_params(...)
linear(weight_and_bias, x)
```

then these can be obtained as just:

```python
init_params = eqx.nn.Linear.__init__
linear = eqx.nn.Linear.__call__
```
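
A minimal runnable sketch of that pattern, as an illustration rather than an official recipe: in practice constructing the module itself plays the role of `init_params`, since an Equinox module is just a pytree of its parameters, and the unbound `__call__` plays the role of the apply function.

```python
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx

# "init": constructing the module returns the pytree of parameters.
init_params = eqx.nn.Linear        # the class itself acts as the init function
linear = eqx.nn.Linear.__call__    # the unbound __call__ acts as the apply function

weight_and_bias = init_params(in_features=3, out_features=4, key=jr.PRNGKey(0))

x = jnp.ones(3)
y = linear(weight_and_bias, x)     # shape (4,)

# Because the parameters are just a pytree, gradients come straight from JAX:
loss = lambda params, x: jnp.sum(linear(params, x) ** 2)
grads = eqx.filter_grad(loss)(weight_and_bias, x)
```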