-
The [example in the docs](https://nemos.readthedocs.io/en/latest/generated/api_guide/plot_05_batch_glm/) currently uses a custom loop to implement stochastic gradient descent.
An alternative would …
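For readers unfamiliar with the pattern, a custom stochastic gradient descent loop fits in a few lines. This is a minimal pure-Python sketch, not the nemos example itself. It assumes a one-feature linear model with a squared-error loss, and the names `sgd_fit`, `lr` and `batch_size` are illustrative only.

```python
import random

def sgd_fit(xs, ys, lr=0.1, batch_size=4, epochs=200, seed=0):
    """Fit y ~ w * x + b by minibatch SGD on the mean squared error."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    idx = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(idx)  # visit the samples in a new order every epoch
        for start in range(0, len(idx), batch_size):
            batch = idx[start:start + batch_size]
            # Gradients of mean((w * x + b - y) ** 2) over this minibatch.
            gw = sum(2 * (w * xs[i] + b - ys[i]) * xs[i] for i in batch) / len(batch)
            gb = sum(2 * (w * xs[i] + b - ys[i]) for i in batch) / len(batch)
            w -= lr * gw
            b -= lr * gb
    return w, b

xs = [i / 10 for i in range(20)]
ys = [3.0 * x + 1.0 for x in xs]  # noise-free data with w = 3, b = 1
w, b = sgd_fit(xs, ys)
```

The loop owns shuffling, batching and the parameter update. That flexibility is why a hand-written loop is attractive, and it is also why it is worth comparing against an off-the-shelf optimizer.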
-
I am building a GPT-style model in Equinox, and right now the forward pass is extremely slow compared to my PyTorch implementation. I think this is one of the cases where I would like to attach a profiler…
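Before reaching for a full profiler, it helps to rule out two measurement pitfalls that affect JAX-versus-PyTorch comparisons. First, JAX compiles a jitted function on its first call. Second, JAX dispatches work asynchronously, so a timer can stop before the computation has finished. The sketch below works around both. It assumes a hypothetical helper, `benchmark`, which is not part of JAX or Equinox. The helper warms up first and then synchronises on each output through a caller-supplied `sync` function. For JAX arrays, `lambda out: out.block_until_ready()` would be a natural choice of `sync`. For an actual trace, JAX also provides `jax.profiler.trace`.

```python
import statistics
import time

def benchmark(fn, *args, sync=lambda out: out, warmup=2, repeats=10):
    """Return the median wall-clock time (seconds) of fn(*args).

    `sync` is called on every output so asynchronous work finishes before
    the timer stops. Warm-up calls are not timed, so one-time compilation
    does not inflate the result.
    """
    for _ in range(warmup):
        sync(fn(*args))
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        sync(fn(*args))
        times.append(time.perf_counter() - start)
    return statistics.median(times)

# A plain Python function stands in for a model's forward pass.
def forward(n):
    return sum(i * i for i in range(n))

t = benchmark(forward, 10_000)
```

The median is used because it is less sensitive than the mean to occasional slow runs.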
-
Currently, `optax.MultiSteps` cannot be used together with Equinox. It fails with the error `TypeError: Cannot interpret value of type as an abstract array; it does not have a dtype att…`
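An Equinox model is a pytree that can hold non-array leaves, such as activation functions. The usual pattern is to pass only the array leaves to the optimizer, e.g. `optim.init(eqx.filter(model, eqx.is_array))`. The pure-Python sketch below mimics that filtering step on nested dicts so the idea is visible without JAX. Note the following assumptions:

- `filter_arrays` is a hypothetical helper.
- Python numbers stand in for arrays.
- Whether filtering alone avoids the `MultiSteps` error above is an assumption; the issue does not confirm it.

```python
import math
from numbers import Number

def filter_arrays(tree):
    """Replace non-array leaves with None, keeping the tree structure.

    Numbers stand in for arrays here; with a real Equinox model one would
    use eqx.filter(model, eqx.is_array) instead.
    """
    if isinstance(tree, dict):
        return {k: filter_arrays(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(filter_arrays(v) for v in tree)
    return tree if isinstance(tree, Number) else None

# The activation function is a leaf JAX could not interpret as an array.
model = {"linear": {"weight": 0.5, "bias": 0.1}, "activation": math.tanh}
params = filter_arrays(model)
```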
-
Hello, what's the current roadmap for jaxopt migration into optax? Will the scope of jaxopt be maintained, or will a trimming/expansion of features happen?
-
Note the redundancy and all the "Alias for field number X" entries in classes such as `optax.MultiStepsState`: https://optax.readthedocs.io/en/latest/api/optimizer_wrappers.html#optax.MultiStepsState
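The "Alias for field number X" text appears to be the default docstring that Python's `collections.namedtuple` assigns to each field. Classes defined with `typing.NamedTuple` inherit it, and Sphinx autodoc then renders it as-is. The sketch below uses a hypothetical `ExampleState`, not optax's own class. It shows the default text and the override described in the `collections` documentation, which is to assign to the field's `__doc__`.

```python
from typing import NamedTuple

class ExampleState(NamedTuple):
    """Hypothetical state container, not optax.MultiStepsState."""
    mini_step: int
    gradient_step: int

default_doc = ExampleState.mini_step.__doc__
# Field docstrings are writable, so a library can replace the default text.
ExampleState.mini_step.__doc__ = "Number of mini-steps accumulated so far."
```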
-
Is there an equivalent to `flax.optim.WeightNorm`? As `flax.optim` is effectively deprecated in favor of optax, I would like to see it implemented in optax.
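For reference, weight normalization reparameterizes a weight vector as `w = g * v / ||v||`, so the scale `g` and the direction `v` are learned separately. Below is a minimal pure-Python sketch of that reparameterization. The function name `weight_norm` is hypothetical, and the sketch is neither optax nor `flax.optim` API.

```python
import math

def weight_norm(v, g):
    """Return w = g * v / ||v|| for a direction vector v and a scalar scale g."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0.0:
        raise ValueError("direction vector v must be non-zero")
    return [g * x / norm for x in v]

w = weight_norm([3.0, 4.0], g=2.0)
# The norm of w equals g, whatever the length of v.
```

An optimizer wrapper built on this idea would keep `g` and `v` as the trainable quantities and recompute `w` from them.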
-
### Description
When using optax, I found unexpectedly large memory consumption in the `optax.MultiSteps` wrapper (https://github.com/google-deepmind/optax/issues/472).
Digging deeper, the problem …
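Some memory cost is inherent to gradient accumulation, which is what `optax.MultiSteps` implements. Between updates, it must keep an accumulated-gradient buffer with the same shape as the parameters, in addition to the inner optimizer's own state. Below is a minimal pure-Python sketch of the idea. `GradAccumulator` and its methods are hypothetical, and this is not optax's implementation. The sketch therefore only shows the baseline cost of one extra parameter-sized copy, not the additional usage this issue investigates.

```python
class GradAccumulator:
    """Average gradients over `every_k` mini-steps, then emit one update.

    The buffer `acc` holds one entry per parameter, so it costs as much
    memory as the parameters themselves.
    """

    def __init__(self, num_params, every_k):
        self.every_k = every_k
        self.mini_step = 0
        self.acc = [0.0] * num_params

    def update(self, grads):
        """Add grads to the buffer; every k steps return their mean, else None."""
        self.acc = [a + g for a, g in zip(self.acc, grads)]
        self.mini_step += 1
        if self.mini_step < self.every_k:
            return None  # still accumulating; no parameter update yet
        mean = [a / self.every_k for a in self.acc]
        self.acc = [0.0] * len(self.acc)  # reset the buffer for the next cycle
        self.mini_step = 0
        return mean

acc = GradAccumulator(num_params=2, every_k=3)
outputs = [acc.update(g) for g in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0])]
```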
-
If I use Adafactor with `MultiSteps` on a bfloat16 model, I get this strange error (the full error is extremely long, so I truncated it to fit in the issue; the model is T5-small):
```
Traceback (mo…
```
-
I am trying to solve an inverse problem with the jaxpi package. How would I go about defining the trainable inverse parameter? Would I need to change the source code of the library?
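The question does not show jaxpi's internals. In general, a common way to learn an unknown physical coefficient is to store it in the same parameter container as the network weights, so the optimizer updates it alongside them. Below is a minimal pure-Python sketch with the following assumptions:

- A toy decay law, `du/dt = -k * u`.
- Noise-free observations generated with `k = 2`.
- A hypothetical `params` dict.

The sketch does not answer whether jaxpi supports this without changes to its source code.

```python
import math

# Observations from u(t) = exp(-2 t), i.e. the true coefficient is k = 2.
ts = [i / 20 for i in range(21)]
u = [math.exp(-2.0 * t) for t in ts]
du_dt = [-2.0 * ui for ui in u]

params = {"k": 0.5}  # trainable inverse parameter, stored like any weight
lr = 1.0

for _ in range(200):
    # Physics residual r = du/dt + k * u; the loss is mean(r ** 2).
    residuals = [d + params["k"] * ui for d, ui in zip(du_dt, u)]
    # Gradient of the loss with respect to k: mean(2 * r * u).
    grad_k = sum(2.0 * r * ui for r, ui in zip(residuals, u)) / len(u)
    params["k"] -= lr * grad_k
```

In a physics-informed network, `u` and `du_dt` would come from the network and its derivative rather than from data. The same gradient step would then update both the network weights and `k`.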
-
I was trying to run the ViViT model following the _quick start_ and got this error:
**TypeError: get_optimizer() missing 1 required positional argument: 'learning_rate_fn'**
Since the train_lib_…
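The error itself means that `get_optimizer` was called without a value for its `learning_rate_fn` parameter. That usually points to a caller and a callee whose signatures have drifted apart. The sketch below shows the mechanics. It assumes a hypothetical stand-in, `get_optimizer(config, learning_rate_fn)`, and a hypothetical constant schedule; neither is the project's actual code.

```python
def get_optimizer(config, learning_rate_fn):
    """Hypothetical stand-in: describe an optimizer built from a schedule."""
    return {"name": config["optimizer"], "lr_at_step_0": learning_rate_fn(0)}

def constant_schedule(step):
    """Hypothetical learning-rate schedule mapping a step number to a rate."""
    return 1e-3

config = {"optimizer": "adam"}

try:
    get_optimizer(config)  # reproduces the reported TypeError
except TypeError as err:
    message = str(err)

# Passing a schedule explicitly satisfies the required argument.
optimizer = get_optimizer(config, learning_rate_fn=constant_schedule)
```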