caelum02 / Lux-AI-Season-2

Lux AI Season 2 - NeurIPS Stage | Team Martian

Jax, Jux tips #3

Open caelum02 opened 11 months ago

caelum02 commented 11 months ago

Jax tutorial?

Glossary

DeviceArray

JAX’s analog of numpy.ndarray. See jaxlib.xla_extension.DeviceArray. (In JAX 0.4.1 and later, DeviceArray is superseded by the unified jax.Array type.)
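A minimal sketch of how the array type behaves in practice, assuming a recent JAX release (0.4.1 or later), where the concrete array type is exposed as `jax.Array` rather than `DeviceArray`. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(4.0)   # JAX array, placed on the default device (CPU/GPU/TPU)
y = np.arange(4.0)    # ordinary host-side NumPy array

print(isinstance(x, jax.Array))   # True
print(isinstance(x, np.ndarray))  # False: JAX arrays are not ndarray subclasses

# Unlike ndarray, JAX arrays are immutable; .at[...].set returns a new array.
x2 = x.at[0].set(10.0)
print(x[0], x2[0])                # 0.0 10.0

# np.asarray copies the data back to host memory as a regular ndarray.
z = np.asarray(x)
print(isinstance(z, np.ndarray), np.array_equal(z, y))  # True True
```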

Tracer

An object used as a stand-in for a JAX DeviceArray in order to determine the sequence of operations performed by a Python function. Internally, JAX implements this via the jax.core.Tracer class.
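A small sketch showing the stand-in at work: inside a `jax.jit`-compiled function the argument is a Tracer rather than a concrete array. The function `f` and the `seen_types` list are illustrative only, and the exact tracer subclass name (e.g. `DynamicJaxprTracer`) is an internal detail that may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

seen_types = []

@jax.jit
def f(x):
    # This Python body runs only while tracing; x is a Tracer standing in
    # for the real array, so we record its type as a side effect.
    seen_types.append(type(x))
    print("inside jit:", x)  # shows a Traced<ShapedArray(...)> value, not data
    return x * 2 + 1

out = f(jnp.arange(3.0))
print(out)  # [1. 3. 5.]
```

Because the side effects happen only during tracing, calling `f` again with the same input shape and dtype reuses the compiled program and does not append to `seen_types`.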

JIT & Jaxpr

References: JIT, Jaxpr, Jax core explained, different-kinds-of-jax-values
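A minimal sketch connecting the two ideas: `jax.make_jaxpr` exposes the Jaxpr produced by tracing, and `jax.jit` runs the same tracing step before handing the Jaxpr to XLA. The function `f` is a made-up example, and the printed Jaxpr text format varies across JAX versions.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * y + 1.0

# make_jaxpr traces f with abstract (shape/dtype-only) inputs and returns
# the recorded sequence of primitive operations as a (closed) Jaxpr.
closed_jaxpr = jax.make_jaxpr(f)(1.0, 2.0)
print(closed_jaxpr)

# jax.jit performs the same tracing step, then compiles the Jaxpr with XLA.
# Later calls with the same input shapes/dtypes reuse the compiled program.
f_jit = jax.jit(f)
result = f_jit(1.0, 2.0)
print(result)
```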

jux.tree_utils.map_to_aval

```python
import jax

def map_to_aval(pytree):
    return jax.tree_util.tree_map(lambda x: x.aval, pytree)  # leaf -> abstract value (shape, dtype)
```
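A usage sketch, assuming the helper is applied to a pytree of arrays inside a traced function, where every leaf is a Tracer carrying an `.aval`. The `state` dict and `step` function are hypothetical, loosely modelled on a game-state pytree; the helper is restated with `jax.tree_util.tree_map` so the block runs on its own.

```python
import jax
import jax.numpy as jnp

def map_to_aval(pytree):
    # Replace each leaf with its abstract value (shape and dtype, no data).
    return jax.tree_util.tree_map(lambda x: x.aval, pytree)

# Hypothetical state pytree, for illustration only.
state = {
    "pos": jnp.zeros((4, 2), dtype=jnp.int32),
    "hp": jnp.ones(4, dtype=jnp.float32),
}
captured = []

@jax.jit
def step(s):
    # While tracing, each leaf is a Tracer; its aval describes shape and dtype only.
    avals = map_to_aval(s)
    captured.append(avals)
    print(avals)
    return jax.tree_util.tree_map(lambda v: v + 1, s)

new_state = step(state)
```

Printing avals this way is a cheap way to check shapes and dtypes of a large state pytree without materializing any values.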

PyTree
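A pytree is any nested structure of containers (lists, tuples, dicts, registered classes) whose leaves are arrays or other non-container values. A minimal sketch of the core `jax.tree_util` operations; the contents of `tree` are made up for illustration.

```python
import jax
import jax.numpy as jnp

# A pytree: nested dict/list containers with array leaves.
tree = {"units": [jnp.ones(3), jnp.zeros(2)], "step": jnp.array(7)}

# Flatten into a list of leaves plus a treedef describing the structure.
# Dict keys are visited in sorted order, so "step" comes before "units".
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(len(leaves))  # 3
print(treedef)

# tree_map applies a function to every leaf and keeps the structure.
doubled = jax.tree_util.tree_map(lambda x: x * 2, tree)
print(doubled["step"])  # 14

# tree_unflatten rebuilds the original container structure from the leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(jax.tree_util.tree_structure(rebuilt) == treedef)  # True
```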

Weakly typed values
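Python scalars are weakly typed in JAX: they adopt the dtype of the array they are combined with instead of forcing promotion to a wider type. A minimal sketch, assuming default settings (64-bit mode disabled):

```python
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float16)

# A Python float is weakly typed, so the result keeps the array's dtype.
print((x + 1.0).dtype)               # float16

# A strongly typed float32 value participates fully in promotion.
print((x + jnp.float32(1.0)).dtype)  # float32

# The same holds for integers: a Python int does not widen an int8 array.
i8 = jnp.arange(3, dtype=jnp.int8)
print((i8 + 3).dtype)                # int8
```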

Type Promotion
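JAX defines its own promotion lattice, which differs from NumPy's in places. A minimal sketch comparing the two through `jnp.promote_types` and `np.promote_types`; the dtype pairs are arbitrary examples.

```python
import jax.numpy as jnp
import numpy as np

# JAX's lattice avoids widening to 64-bit types where NumPy would.
print(jnp.promote_types(jnp.int8, jnp.uint8))        # int16 (same as NumPy)
print(jnp.promote_types(jnp.int32, jnp.float16))     # float16
print(np.promote_types(np.int32, np.float16))        # float64 under NumPy's rules
print(jnp.promote_types(jnp.bfloat16, jnp.float16))  # float32

# The same rules apply to binary operations on arrays.
a = jnp.ones(2, dtype=jnp.int32)
b = jnp.ones(2, dtype=jnp.float16)
print((a + b).dtype)                                 # float16
```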

Jax Profiling

https://jax.readthedocs.io/en/latest/profiling.html
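A minimal sketch in the spirit of the linked guide, using `jax.profiler.trace` to capture a trace and `jax.profiler.TraceAnnotation` to label a region. The `work` function and the temporary log directory are placeholders; viewing the result (TensorBoard profile plugin or Perfetto) is covered on the linked page.

```python
import tempfile

import jax
import jax.numpy as jnp

@jax.jit
def work(x):
    return (x @ x.T).sum()

x = jnp.ones((256, 256))
# Run once first so compilation time does not dominate the trace.
work(x).block_until_ready()

log_dir = tempfile.mkdtemp(prefix="jax-trace-")
with jax.profiler.trace(log_dir):
    # TraceAnnotation adds a named span to the captured timeline.
    with jax.profiler.TraceAnnotation("matmul-step"):
        # block_until_ready ensures the async computation finishes inside the trace.
        result = work(x).block_until_ready()

print(result, "trace written under", log_dir)
```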

Transfer Guard
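A minimal sketch of the `jax.transfer_guard` context manager, based on its documented levels (`"allow"`, `"log"`, `"disallow"`, and the `_explicit` variants): explicit transfers via `jax.device_put` / `jax.device_get` stay allowed under `"disallow"`, while implicit ones, such as converting a not-yet-fetched device array to a Python float, are meant to be rejected. The variable names are illustrative, and since the exact exception type is not something to rely on, the sketch catches a broad `Exception`.

```python
import jax
import jax.numpy as jnp
import numpy as np

y = jnp.array(2)
z = jnp.array(3)
results = {}

with jax.transfer_guard("disallow"):
    # Explicit host-to-device and device-to-host transfers are allowed.
    x_dev = jax.device_put(np.arange(3))
    results["y"] = jax.device_get(y)
    try:
        # Implicit device-to-host transfer of a value not yet fetched to host;
        # under "disallow" this is documented to be rejected.
        results["z"] = float(z)
    except Exception as err:
        results["z"] = f"blocked: {type(err).__name__}"

print(results)
```

Setting the guard to `"log"` first is a gentler way to find unintended transfers in an existing training loop before switching to `"disallow"`.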