AlexeyKurakin closed this issue 3 years ago.
Typically, the user has to specify the list of variables used by a function explicitly:
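```python
import objax
from objax.zoo.wide_resnet import WideResNet

wrn_depth, wrn_width = 28, 2  # example values; the original does not specify them

model = WideResNet(nin=3, nclass=10, depth=wrn_depth, width=wrn_width)

# The variables used by loss_fn are passed in explicitly.
@objax.Function.with_vars(model.vars())
def loss_fn(x, label):
    logit = model(x, training=True)
    return objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()
```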
The goal of tracing is to detect the used variables automatically, so that the code could look like the following:
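```python
import objax
from objax.zoo.wide_resnet import WideResNet

wrn_depth, wrn_width = 28, 2  # example values; the original does not specify them

model = WideResNet(nin=3, nclass=10, depth=wrn_depth, width=wrn_width)

# Proposed decorator: the used variables are detected by tracing.
@objax.Function.auto_vars
def loss_fn(x, label):
    logit = model(x, training=True)
    return objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()
```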
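To illustrate what such tracing could involve, here is a minimal pure-Python sketch. It does not use Objax and is not how Objax implements this; `TrackedVar`, `auto_vars` and `Linear` are hypothetical names made up for the illustration. The idea is that every variable read goes through a property that reports to an active tracer, and the decorator runs the function with a tracer installed, recording which variables were touched.

```python
import functools
from typing import Callable, List, Optional


class TrackedVar:
    """Hypothetical variable whose reads are reported to an active tracer."""

    # Class-level tracer slot (assumption: single-threaded use only).
    _active_tracer: Optional[List["TrackedVar"]] = None

    def __init__(self, name: str, value: float):
        self.name = name
        self._value = value

    @property
    def value(self) -> float:
        tracer = TrackedVar._active_tracer
        if tracer is not None and self not in tracer:
            tracer.append(self)  # record the first read of this variable
        return self._value


def auto_vars(fn: Callable) -> Callable:
    """Hypothetical decorator: record which TrackedVars fn reads when called."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        previous = TrackedVar._active_tracer
        used: List[TrackedVar] = []
        TrackedVar._active_tracer = used
        try:
            result = fn(*args, **kwargs)
        finally:
            TrackedVar._active_tracer = previous
        # Propagate reads to an enclosing tracer so nested functions compose.
        if previous is not None:
            for v in used:
                if v not in previous:
                    previous.append(v)
        wrapper.vars = used  # variables observed on the most recent call
        return result

    wrapper.vars = []
    return wrapper


class Linear:
    """Toy model holding two tracked variables."""

    def __init__(self):
        self.w = TrackedVar("w", 2.0)
        self.b = TrackedVar("b", 0.5)

    def __call__(self, x: float) -> float:
        return self.w.value * x + self.b.value


model = Linear()
unused = TrackedVar("unused", 9.0)  # never read by loss_fn


@auto_vars
def loss_fn(x: float, label: float) -> float:
    pred = model(x)
    return (pred - label) ** 2


loss = loss_fn(1.0, 3.0)
print(loss, [v.name for v in loss_fn.vars])
```

Because tracing only observes the code path that actually executed, a variable read only inside a branch that was not taken (for example, one used only when `training=True`) would be missed; a real implementation would have to account for that.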