-
Hi, I'm trying to run the SayCan code, but I ran into some errors.
My setup:
- OS: Ubuntu 22.04
- GPU: RTX 3090, NVIDIA driver 535, CUDA 12.2, cuDNN 8.9.5
- chex 0.1.8
- optax …
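When reporting an environment like the one above, the Python package versions can be collected programmatically. This is a minimal sketch that uses only the standard library's `importlib.metadata`. `installed_versions` is a hypothetical helper name, and it does not cover driver, CUDA, or cuDNN versions, which come from tools such as `nvidia-smi`.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Example package list; adjust to whatever the bug report needs.
    for name, ver in installed_versions(["jax", "jaxlib", "chex", "optax"]).items():
        print(f"{name}: {ver or 'not installed'}")
```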
-
The [example in the docs](https://nemos.readthedocs.io/en/latest/generated/api_guide/plot_05_batch_glm/) currently uses a custom loop to implement stochastic gradient descent.
An alternative would …
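For reference, a custom stochastic gradient descent loop comes down to three steps: shuffle the samples, slice them into minibatches, and step against each minibatch gradient. The following sketch is plain Python on a hypothetical 1-D least-squares problem. `sgd_linear_regression` and its default hyperparameters are illustrative assumptions and are not taken from the nemos example.

```python
import random


def sgd_linear_regression(xs, ys, lr=0.05, epochs=500, batch_size=4, seed=0):
    """Fit y ~ w * x + b by minibatch SGD on the mean squared error."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    idx = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(idx)  # new sample order every epoch
        for start in range(0, len(idx), batch_size):
            batch = idx[start:start + batch_size]
            # Gradient of mean((w*x + b - y)^2) over the current minibatch.
            gw = gb = 0.0
            for i in batch:
                err = w * xs[i] + b - ys[i]
                gw += 2.0 * err * xs[i]
                gb += 2.0 * err
            gw /= len(batch)
            gb /= len(batch)
            # Plain SGD update: step against the minibatch gradient.
            w -= lr * gw
            b -= lr * gb
    return w, b
```

On noise-free data such as `ys = [3 * x + 1 for x in xs]`, the loop recovers `w` close to 3 and `b` close to 1. A library optimizer would replace only the two update lines at the end of the inner loop.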
-
Dear all,
I have a very simple question. I have two neural networks of type `MLP`, and I want to initialize an optimizer for them with `optax`.
With a single network, I do the following:
`pyt1 = eqx.filter(bn, e…
-
The newest version of jax seems to require jaxlib v0.3.7, which breaks the trainer script:
```bash
$ ./run_pretrain.sh
2022-04-16 23:34:03.151271: W tensorflow/stream_executor/platform/default/dso…
```
-
# Description
I would like to create a neural CDE for regression. For that, I have taken the example from [neural CDE for classification](https://docs.kidger.site/diffrax/examples/neural_cde/) and ad…
-
### System Info
all known
### Who can help?
@muellerz @SunMarc
### Information
- [X] The official example scripts
- [ ] My own modified scripts
### Tasks
- [ ] An officially supp…
-
**Describe the bug**
The AdamW implementation (see [here](https://github.com/NVIDIA/apex/blob/a7de60e57f0534266841e1733262601ad76aaa74/csrc/multi_tensor_adam.cu#L333)) does not truly decouple the weight…
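To make "decoupled" concrete, here is a scalar sketch in plain Python. It contrasts Adam with an L2 penalty folded into the gradient against AdamW-style decay applied directly to the weight. The function names and default hyperparameters are illustrative assumptions, and this is not the apex kernel. The AdamW step multiplies the decay by `lr`, following the PyTorch convention. In Loshchilov & Hutter's formulation, the decay is scaled by the schedule multiplier rather than the base step size, which is the stricter sense of "decoupled".

```python
import math


def adam_l2_step(p, g, m, v, t, lr=1e-3, wd=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """One scalar Adam step with an L2 penalty folded into the gradient.

    The decay term passes through the moment estimates, so it is rescaled
    by the adaptive denominator like any other gradient component.
    """
    g = g + wd * p
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


def adamw_step(p, g, m, v, t, lr=1e-3, wd=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """One scalar AdamW step: the moments see only the raw gradient, and the
    decay is applied directly to the weight (scaled by lr, PyTorch-style)."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    p = p - lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * p)
    return p, m, v
```

Consider a first step with zero gradient and `p = 1`. The L2 variant shrinks the weight by roughly `lr`, because the normalisation cancels the magnitude of `wd`. AdamW shrinks it by `lr * wd`, so the decay strength no longer depends on the gradient statistics.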
-
This topic comes to mind every once in a while. It has already been discussed extensively (e.g. https://github.com/deepmind/optax/issues/197#issuecomment-982548377), but I feel it needs new life becau…
-
```
The conflict is caused by:
praxis 1.4.0 depends on tfds-nightly==4.8.3.dev202303280045
The user requested jax==0.4.26
jax[cuda12] 0.4.26 depends on jax 0.4.26 (from https://pypi.tuna.ts…
```
-
It seems like the Flax examples could use a version bump?
### System information
- OS Platform and Distribution: Ubuntu 22.04.4 LTS
- Flax, jax, jaxlib versions:
```
pip show flax jax jaxlib
…
```