google / objax

Prototype automatic tracing of variables used by a function #197

Closed: AlexeyKurakin closed this issue 3 years ago

AlexeyKurakin commented 3 years ago

Typically, the user has to specify the list of variables used by a function explicitly:

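import objax
from objax.zoo.wide_resnet import WideResNet

wrn_depth, wrn_width = 28, 2  # hypothetical values; the issue does not specify them
model = WideResNet(nin=3, nclass=10, depth=wrn_depth, width=wrn_width)

# Every variable that loss_fn reads must be passed in explicitly.
@objax.Function.with_vars(model.vars())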
def loss_fn(x, label):
  logit = model(x, training=True)
  return objax.functional.loss.cross_entropy_logits_sparse(
      logit, label).mean()

The goal of tracing is to detect the used variables automatically, so that the code could look like the following:

model = WideResNet(nin=3, nclass=10, depth=wrn_depth, width=wrn_width)

@objax.Function.auto_vars
def loss_fn(x, label):
  logit = model(x, training=True)
  return objax.functional.loss.cross_entropy_logits_sparse(
      logit, label).mean()
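One way such tracing could work is sketched below in plain Python, without objax or JAX. It assumes that every variable reports each read of its `.value` to an active tracer, and that the decorated function is run once under that tracer to collect what it touched. `TrackedVar`, `_Tracer` and this toy `auto_vars` are hypothetical names for illustration only; they are not objax's API and not necessarily the approach taken in this prototype.

```python
import contextlib
import functools
from typing import Any, Callable, Dict, Iterator, Optional


class _Tracer:
    """Collects every variable read while this tracer is active."""

    def __init__(self, parent: Optional["_Tracer"] = None) -> None:
        self.parent = parent
        self.seen: Dict[int, "TrackedVar"] = {}

    def record(self, var: "TrackedVar") -> None:
        # Key by id() so each variable appears once, in first-use order.
        self.seen.setdefault(id(var), var)
        # Forward to an enclosing tracer so reads inside nested traced calls
        # are also visible to the outer call.
        if self.parent is not None:
            self.parent.record(var)


_active_tracer: Optional[_Tracer] = None


@contextlib.contextmanager
def _tracing() -> Iterator[_Tracer]:
    """Install a fresh tracer for the duration of the `with` block."""
    global _active_tracer
    previous = _active_tracer
    tracer = _Tracer(parent=previous)
    _active_tracer = tracer
    try:
        yield tracer
    finally:
        _active_tracer = previous


class TrackedVar:
    """Hypothetical variable type (not objax.TrainVar) that reports reads."""

    def __init__(self, name: str, value: Any) -> None:
        self.name = name
        self._value = value

    @property
    def value(self) -> Any:
        if _active_tracer is not None:
            _active_tracer.record(self)
        return self._value


def auto_vars(fn: Callable) -> Callable:
    """Toy decorator: on the first call, run fn under a tracer and store the
    variables it read in wrapper.vars; later calls run fn directly."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if wrapper.vars is None:
            with _tracing() as tracer:
                result = fn(*args, **kwargs)
            wrapper.vars = list(tracer.seen.values())
            return result
        return fn(*args, **kwargs)

    wrapper.vars = None
    return wrapper


# Usage: linear reads only w and b, so only they are discovered.
w = TrackedVar("w", 3.0)
b = TrackedVar("b", 1.0)
unused = TrackedVar("unused", 0.0)


@auto_vars
def linear(x: float) -> float:
    return w.value * x + b.value


print(linear(2.0))                    # 7.0
print([v.name for v in linear.vars])  # ['w', 'b']
```

Because variables are discovered by running the function once, this sketch only records variables read on the code path taken by that first call; a variable used only in a branch not taken then would be missed. The module-level tracer is also not thread-safe; Python's `contextvars` would be one way to address that.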