poets-ai / elegy

A High Level API for Deep Learning in JAX
https://poets-ai.github.io/elegy/
MIT License

[Bug] Example is broken on GPU #238

Open CForgie opened 2 years ago

CForgie commented 2 years ago

The example listed in https://poets-ai.github.io/elegy/getting-started/high-level-api/ doesn't work on GPU with the latest version of elegy (0.8.6). When calling model.fit, training starts but hangs before completing the first epoch. Downgrading to 0.8.4 seems to fix this.

Library Info

elegy 0.8.6
flax 0.4.2
jax 0.3.13
jax-metrics 0.1.2
jaxlib 0.3.10
optax 0.1.2
treex 0.6.10

CUDA Version: 11.6
GPU: Tesla T4

Running AMI: Deep Learning Base AMI (Amazon Linux 2) Version 53*
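
In case it helps anyone compare environments, here is a minimal sketch of how the version list above could be collected. The helper names (collect_versions, describe_devices) are made up for illustration and are not part of elegy; the package names are the ones listed under Library Info.

```python
from importlib import metadata

# Package names taken from the "Library Info" list in this report.
PACKAGES = ["elegy", "flax", "jax", "jax-metrics", "jaxlib", "optax", "treex"]


def collect_versions(packages=PACKAGES):
    """Return {package: installed version, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def describe_devices():
    """Return the devices JAX can see, or a short note if JAX is missing."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    return [str(device) for device in jax.devices()]


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(name, version or "not installed")
    print("devices:", describe_devices())
```

If the devices line shows only CPU devices, JAX is not picking up the GPU at all, which would likely be a separate problem from the hang described here.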

CForgie commented 2 years ago

Seems similar to #234