HMUNACHI / nanodl

A Jax-based library for designing and training transformer models from scratch.
MIT License

Create a custom dropout layer that abstracts the complicated Dropout in Flax/JAX #9

Closed: HMUNACHI closed this issue 8 months ago

HMUNACHI commented 9 months ago

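Dropout in Flax/JAX takes extra setup. Flax's built-in `nn.Dropout` needs a `deterministic` flag and its own `'dropout'` PRNG stream, which must be passed in through `apply(..., rngs={'dropout': key})`. A custom dropout layer that hides this bookkeeping would make dropout easier to use across different JAX setups.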
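Below is a minimal sketch of what such a layer could look like in plain JAX. It assumes standard inverted dropout and a layer that owns and advances its own PRNG key. The names `dropout` and `KeyedDropout` are hypothetical and are not nanodl's actual API.

```python
import jax
import jax.numpy as jnp


def dropout(x: jax.Array, rate: float, key: jax.Array, training: bool) -> jax.Array:
    """Inverted dropout: zero each element with probability `rate`, scale the rest."""
    if not 0.0 <= rate <= 1.0:
        raise ValueError(f"rate must be in [0, 1], got {rate}")
    if not training or rate == 0.0:
        # Evaluation mode (or no dropout): the layer is the identity.
        return x
    if rate == 1.0:
        # Everything is dropped; also avoids dividing by keep_prob == 0.
        return jnp.zeros_like(x)
    keep_prob = 1.0 - rate
    # Each mask entry is True with probability keep_prob.
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    # Scale kept activations by 1/keep_prob so the expected value is unchanged.
    return jnp.where(mask, x / keep_prob, jnp.zeros_like(x))


class KeyedDropout:
    """Hypothetical wrapper that owns its PRNG key, so callers never pass one.

    Intended for eager (non-jit) use: the key is Python-side state, so under
    jax.jit the split below would run only once, at trace time.
    """

    def __init__(self, rate: float, seed: int = 0):
        self.rate = rate
        self._key = jax.random.PRNGKey(seed)

    def __call__(self, x: jax.Array, training: bool = False) -> jax.Array:
        if not training:
            return x
        # Split so each call gets a fresh subkey and the stored key advances.
        self._key, subkey = jax.random.split(self._key)
        return dropout(x, self.rate, subkey, training=True)


layer = KeyedDropout(rate=0.5, seed=42)
x = jnp.ones((4, 8))
y_train = layer(x, training=True)  # every entry is 0.0 or 2.0
y_eval = layer(x, training=False)  # identical to x
```

The pure `dropout` function stays safe to use inside `jax.jit` because the key is passed explicitly. The stateful wrapper trades that for convenience, which is the kind of abstraction the issue asks for. A jit-friendly variant would need to return the advanced key alongside the output.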

HMUNACHI commented 8 months ago

Done in the dev branch.