google / objax

Closure scoping for nested objax.Functions #237

Closed: mathDR closed this issue 2 years ago

mathDR commented 2 years ago

Hi, I am trying to optimize a list of Gaussian Process models: I create an optimizer for each model in the list, then run the optimizers in a loop.

This is not working, so I put together a small example that, if I can figure out how to make it work, should show a lot of what is wrong with my code.

I want to extend the objax.Function() example by doing something like:

import objax
import jax.numpy as jn

models = []
for i in range(2):
    models.append(objax.nn.Linear(2, 3))

def f1(x, y):
    return ((m(x) - y) ** 2).mean()

new_funcs = []
for m in models:
    new_funcs.append(objax.Function(f1,m.vars()))

I know this example doesn't work (for a lot of reasons), but I am trying to understand how to make it work. That is: how can I apply models[i].vars() to f1 inside the new_funcs loop, so that when I run new_funcs[0](x,y) and new_funcs[1](x,y), I get different values?

Because of Python's closure scoping, I think each function in new_funcs just ends up using the last model in models, right?
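For reference, the late-binding behaviour suspected here can be reproduced without objax. The following is only a minimal sketch: `make_scaler` and `build_late_bound` are hypothetical names, and the plain callables stand in for `objax.nn.Linear` models. It uses a lambda defined in the loop rather than a module-level `f1`, but the lookup rule is the same: `m` is resolved when the function is called, not when it is defined.

```python
# Stand-in "models": plain callables used instead of objax.nn.Linear (assumption).
def make_scaler(k):
    return lambda x: k * x


models = [make_scaler(1.0), make_scaler(10.0)]


def build_late_bound(models):
    funcs = []
    for m in models:
        # `m` is a free variable here; it is looked up when the lambda is
        # called, by which time the loop has finished.
        funcs.append(lambda x, y: (m(x) - y) ** 2)
    return funcs


new_funcs = build_late_bound(models)

# Both functions see the last model (scale factor 10.0), so they agree.
results = [f(2.0, 0.0) for f in new_funcs]
print(results)
```

Both entries of `results` come out identical, which is the symptom described above: every function evaluates the last model in the list.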

AlexeyKurakin commented 2 years ago

Would something like the following work?

import objax

models = []
for i in range(2):
    models.append(objax.nn.Linear(2, 3))

def loss(x, y, m):
    return ((m(x) - y) ** 2).mean()

new_funcs = []
for m in models:
    # model=m is evaluated when the lambda is created, so each lambda keeps its own model.
    new_funcs.append(objax.Function(lambda x, y, model=m: loss(x, y, model), m.vars()))

Another alternative:

import objax

models = []
for i in range(2):
    models.append(objax.nn.Linear(2, 3))

new_funcs = []
for m in models:
    # The default model=m binds the current model when loss is defined.
    def loss(x, y, model=m):
        return ((model(x) - y) ** 2).mean()
    new_funcs.append(objax.Function(loss, m.vars()))
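Both variants work because a default argument value is evaluated once, when the function is defined, so each function captures the model from its own loop iteration. A minimal plain-Python sketch of that mechanism follows. The names `make_scaler`, `build_bound`, and `make_loss` are hypothetical, and the plain callables again stand in for objax models. The factory function `make_loss` is an additional option not mentioned above; it gets the same effect by creating a fresh scope per model.

```python
# Stand-in "models": plain callables used instead of objax.nn.Linear (assumption).
def make_scaler(k):
    return lambda x: k * x


models = [make_scaler(1.0), make_scaler(10.0)]


def build_bound(models):
    funcs = []
    for m in models:
        # The default value model=m is evaluated now, at definition time,
        # so each loss keeps the model from its own iteration.
        def loss(x, y, model=m):
            return (model(x) - y) ** 2
        funcs.append(loss)
    return funcs


def make_loss(model):
    # Factory alternative: each call creates a new scope holding `model`.
    def loss(x, y):
        return (model(x) - y) ** 2
    return loss


values_default = [f(2.0, 0.0) for f in build_bound(models)]
values_factory = [make_loss(m)(2.0, 0.0) for m in models]
print(values_default, values_factory)
```

With either approach the two functions return different values, one per model. The factory version also keeps the call signature at `(x, y)`, with no extra keyword argument exposed.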

See also this example of how to make a Python loop work with closures:

mathDR commented 2 years ago

This is perfect. Thanks!