jax-ml / coix

Inference Combinators in JAX
https://coix.readthedocs.io/en/latest/
Apache License 2.0

Add aggregate argument for more flexible loss #38

Closed: fehiepsi closed this 6 months ago

fehiepsi commented 6 months ago

This is useful if we need to mask out part of the batch.
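The following is a minimal sketch of the idea, assuming a loss that first computes per-example values and then reduces them with a caller-supplied `aggregate` function. The names `hypothetical_loss` and `masked_mean` are illustrative only and are not coix's actual API. The default reduction is a plain mean. Passing a masked mean instead lets padded or otherwise invalid batch entries be excluded.

```python
import jax.numpy as jnp


def hypothetical_loss(per_example_loss, aggregate=jnp.mean):
    # Hypothetical loss: reduce per-example values with a pluggable aggregate.
    return aggregate(per_example_loss)


def masked_mean(mask):
    # Hypothetical helper: build an aggregate that averages only where mask is True.
    def aggregate(x):
        # jnp.where (rather than x * mask) keeps NaN/inf in masked entries from leaking in.
        total = jnp.sum(jnp.where(mask, x, 0.0))
        # Guard against division by zero when every entry is masked out.
        count = jnp.maximum(jnp.sum(mask), 1)
        return total / count

    return aggregate


losses = jnp.array([1.0, 2.0, 3.0, 4.0])
mask = jnp.array([True, True, False, False])

full = hypothetical_loss(losses)                        # mean over the whole batch
partial = hypothetical_loss(losses, masked_mean(mask))  # mean over the first two entries only
```

Taking a callable rather than a boolean flag keeps the loss agnostic to how the batch is reduced. A sum, a weighted mean, or a masked mean can all be supplied without further changes to the loss itself.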