lanl / scico

Scientific Computational Imaging COde
BSD 3-Clause "New" or "Revised" License

Incorrect handling of scale in `Loss.grad` #468

Closed bwohlberg closed 10 months ago

bwohlberg commented 10 months ago

There is a bug in `Loss.grad`'s handling of the `scale` attribute, but only when it is set via scalar multiplication:

from scico.loss import SquaredL2Loss
from scico.functional import L2Norm
import scico.numpy as snp

f = SquaredL2Loss(y=snp.zeros((4,)))
g = SquaredL2Loss(y=snp.zeros((4,)), scale=5)
h = 10 * f

# __call__ is correct
f(snp.ones((4,)))
>> Array(2., dtype=float32)
g(snp.ones((4,)))
>> Array(20., dtype=float32)
h(snp.ones((4,)))
>> Array(20., dtype=float32)

# grad is broken
f.grad(snp.ones((4,)))
>> Array([1., 1., 1., 1.], dtype=float32)
g.grad(snp.ones((4,)))
>> Array([10., 10., 10., 10.], dtype=float32)
h.grad(snp.ones((4,)))
>> Array([1., 1., 1., 1.], dtype=float32)
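
For reference, differentiating the scaled loss directly with `jax.grad` gives the value `h.grad` should return. This is a quick sanity check rather than scico API usage; the `0.5` factor is the default `scale` of `SquaredL2Loss` (consistent with the `__call__` outputs above), and `y = 0` here.

import jax
import scico.numpy as snp

# h = 10 * f with f(x) = 0.5 * ||x - y||^2 and y = 0, so the gradient of
# h at x is 10 * x, i.e. [10., 10., 10., 10.] at x = ones.
expected = jax.grad(lambda x: 10 * 0.5 * snp.sum(x**2))(snp.ones((4,)))
print(expected)  # [10. 10. 10. 10.]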

The same bug is not present in `Functional.grad`:

f = L2Norm()
g = 10 * f

f.grad(snp.ones((4,)))
>> Array([0.5, 0.5, 0.5, 0.5], dtype=float32)
g.grad(snp.ones((4,)))
>> Array([5., 5., 5., 5.], dtype=float32)
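
One plausible reason for the difference (a minimal sketch using a hypothetical `MiniFunctional` class, not the actual `Functional` implementation): if scalar multiplication constructs a new object, `__init__` runs again and `_grad` is rebuilt from the correctly scaled `__call__`.

import jax
import jax.numpy as jnp

class MiniFunctional:
    """Hypothetical stand-in for an L2 norm functional."""

    def __init__(self, scale=1.0):
        self.scale = scale
        # _grad is built from the bound __call__ of this instance, so it
        # sees this instance's scale.
        self._grad = jax.grad(self.__call__)

    def __call__(self, x):
        return self.scale * jnp.sqrt(jnp.sum(x**2))

    def grad(self, x):
        return self._grad(x)

    def __rmul__(self, c):
        # Scaling constructs a fresh object, so __init__ runs and _grad
        # is rebuilt from the rescaled __call__.
        return MiniFunctional(scale=c * self.scale)

f = MiniFunctional()
g = 10 * f
print(f.grad(jnp.ones((4,))))  # [0.5 0.5 0.5 0.5]
print(g.grad(jnp.ones((4,))))  # [5. 5. 5. 5.]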
bwohlberg commented 10 months ago

The bug turns out to be due to a combination of this https://github.com/lanl/scico/blob/5ffd1f9046fa8e7e2e2f7d37846993a3ab2221dd/scico/functional/_functional.py#L35-L36 and this https://github.com/lanl/scico/blob/5ffd1f9046fa8e7e2e2f7d37846993a3ab2221dd/scico/loss.py#L126-L129. The `copy` call does not result in an `__init__` call, so the new `Loss` object ends up with `_grad` set to the function that was constructed when `__init__` ran for the original, unscaled `Loss` object.
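
The mechanism is easy to reproduce outside of scico. A minimal sketch using a hypothetical `MiniLoss` class (not the actual `Loss` code): `_grad` is built once in `__init__` from the bound `__call__`, and the shallow copy carries that stale closure along.

import copy
import jax
import jax.numpy as jnp

class MiniLoss:
    """Hypothetical stand-in for a squared L2 loss with y = 0."""

    def __init__(self, scale=0.5):
        self.scale = scale
        # _grad closes over the bound __call__ of the instance being
        # initialized, and hence over that instance's scale attribute.
        self._grad = jax.grad(self.__call__)

    def __call__(self, x):
        return self.scale * jnp.sum(x**2)

    def grad(self, x):
        return self._grad(x)

    def __rmul__(self, c):
        # Mimics the copy-based scaling: __init__ never runs for the new
        # object, so its _grad still points at the ORIGINAL instance's
        # __call__ and therefore at the original, unscaled value of scale.
        new = copy.copy(self)
        new.scale = c * self.scale
        return new

f = MiniLoss()
h = 10 * f
print(h(jnp.ones((4,))))       # 20.0 -- __call__ sees the new scale
print(h.grad(jnp.ones((4,))))  # [1. 1. 1. 1.] -- grad sees the old scale

Rebuilding the closure after rescaling, e.g. setting new._grad = jax.grad(new.__call__) inside __rmul__, restores consistency in this sketch.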

PR #470 has a simple fix, but this issue raises a few broader design questions: