google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Handle cases where memstats are not available for the device. #739

Closed: lukebaumann closed this 2 weeks ago

lukebaumann commented 3 weeks ago

Memstats are not guaranteed to be available for every device. Requesting them can raise an exception or return None. This change handles both cases: jaxlib.xla_extension.XlaRuntimeError, raised when the device is not a PjRt addressable device, and KeyError, raised when memstats are not available and the call returns None.
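A minimal sketch of this kind of guard follows. It assumes devices expose JAX's Device.memory_stats(), which returns a dict or None, and that the dict uses the keys bytes_in_use and bytes_limit (common on TPU/GPU, but an assumption here). The helper names device_mem_gib and report_mem_stats are hypothetical and are not the functions changed in this PR. The sketch catches RuntimeError, of which XlaRuntimeError is a subclass in the jaxlib versions we are aware of. It also checks for None explicitly rather than relying on a downstream KeyError.

```python
from __future__ import annotations

from typing import Iterable, Optional


def device_mem_gib(device) -> Optional[tuple[float, float]]:
    """Return (used_gib, limit_gib) for one device, or None if memstats are unavailable.

    Hypothetical helper; assumes `device` has a JAX-style memory_stats() method.
    """
    try:
        stats = device.memory_stats()
    except RuntimeError:
        # jaxlib's XlaRuntimeError subclasses RuntimeError; it can be raised,
        # e.g., when the device is not a PjRt addressable device.
        return None
    if stats is None:
        # Some backends (e.g. CPU) report no memory statistics at all.
        return None
    try:
        # Key names are an assumption about what the backend reports.
        used = stats["bytes_in_use"] / 2**30
        limit = stats["bytes_limit"] / 2**30
    except KeyError:
        return None
    return round(used, 2), round(limit, 2)


def report_mem_stats(devices: Iterable) -> list[str]:
    """Build one human-readable line per device, never raising on missing memstats."""
    lines = []
    for d in devices:
        mem = device_mem_gib(d)
        if mem is None:
            lines.append(f"device {d.id}: memstats unavailable")
        else:
            used, limit = mem
            lines.append(f"device {d.id}: {used} / {limit} GiB")
    return lines
```

With JAX installed, something like `for line in report_mem_stats(jax.local_devices()): print(line)` would print one line per local device. Taking the devices as a parameter keeps the helper testable with stub objects on machines without accelerators.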