cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License

loss_and_log fails when there is no loss #75

Open jiyuuchc opened 2 years ago


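I ran into this bug in a rare edge case. When self.losses is None, the fallback branch of compute() fails because jnp.zeros expects a shape, not a fill value.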

In loss_and_logs.py:

    def compute(self) -> tp.Tuple[jnp.ndarray, Logs, Logs]:

        if self.losses is not None:
            loss, losses_logs = self.losses.compute()
        else:
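            # Was: jnp.zeros(0.0, dtype=jnp.float32), which fails because
            # jnp.zeros takes a shape as its first argument, not a value.
            loss = jnp.array(0.0, dtype=jnp.float32)  # scalar zero loss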
            losses_logs = {}
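A minimal, self-contained sketch of the proposed fallback follows. HypotheticalLosses and compute_loss_and_logs are hypothetical stand-ins written for illustration, not treex's actual classes or API. The sketch assumes only that a losses object returns a (total_loss, logs) pair. It shows that jnp.array(0.0, dtype=jnp.float32) produces a shape-() float32 zero, which is the same kind of scalar the non-empty branch returns. jnp.zeros((), dtype=jnp.float32) would be an equivalent alternative.

```python
import typing as tp

import jax.numpy as jnp

Logs = tp.Dict[str, jnp.ndarray]


class HypotheticalLosses:
    """Hypothetical stand-in for a losses collection (NOT treex's real class)."""

    def __init__(self, values: tp.Dict[str, float]):
        self.values = values

    def compute(self) -> tp.Tuple[jnp.ndarray, Logs]:
        # Each named loss becomes a float32 scalar; the total is their sum.
        logs = {name: jnp.asarray(v, dtype=jnp.float32) for name, v in self.values.items()}
        total = sum(logs.values(), jnp.array(0.0, dtype=jnp.float32))
        return total, logs


def compute_loss_and_logs(
    losses: tp.Optional[HypotheticalLosses],
) -> tp.Tuple[jnp.ndarray, Logs]:
    """Hypothetical helper mirroring the branch in the issue."""
    if losses is not None:
        loss, losses_logs = losses.compute()
    else:
        # Scalar zero with shape (); jnp.zeros would need an explicit shape,
        # e.g. jnp.zeros((), dtype=jnp.float32).
        loss = jnp.array(0.0, dtype=jnp.float32)
        losses_logs = {}
    return loss, losses_logs


loss, logs = compute_loss_and_logs(None)
print(loss.shape, loss.dtype, logs)  # () float32 {}
```

With no losses configured, the function returns a scalar zero and an empty log dict instead of raising, so downstream code can treat both branches uniformly.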