dmankins closed this issue 5 years ago.
Here we only run the tracing to check, as early as possible, whether the chosen model can actually be traced.
With `--jit=onsave`, we "Use regular Python model for training, but trace it on-demand for saving training state" (see https://github.com/ELEKTRONN/elektronn3/blob/35dbc1bc/examples/train_unet_neurodata.py#L46). The rationale for this mode is that JIT-traced models sometimes have issues when they are used for training. These are mainly control-flow issues (a trace only records the path taken for the example input) and slightly different behavior caused by graph optimizations, etc. So in many cases it's a better idea to train the Python model and only JIT-trace it on demand when saving it to disk.
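A minimal sketch of the on-save idea, for illustration only. `TinyNet` is a hypothetical stand-in for the U-Net, and `save_traced` is a hypothetical helper, not elektronn3's actual saving code. Tracing in eval mode under `torch.no_grad()` is an assumption of this sketch. Training runs on the eager module, and a trace is created only at save time:

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the model trained in the example script."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


def save_traced(model, example_input, f):
    """Hypothetical helper: trace the eager model on demand and save the trace.

    The eager model is not replaced, so training keeps using the Python module.
    """
    was_training = model.training
    model.eval()  # assumption: trace in inference mode (matters for dropout/batchnorm)
    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)
    torch.jit.save(traced, f)  # f may be a path or a file-like object
    model.train(was_training)  # restore the previous mode
    return traced


model = TinyNet()
example_input = torch.randn(1, 1, 8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Training uses the regular Python model.
for _ in range(3):
    optimizer.zero_grad()
    loss = model(example_input).pow(2).mean()
    loss.backward()
    optimizer.step()

# Only when saving is the model traced.
buffer = io.BytesIO()
save_traced(model, example_input, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)
```

The saved trace can be loaded without the Python class definition, while training itself never touches the traced graph.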
Replacing the `model` with the `tracedmodel` would have the same effect as the `--jit=train` option.
I will change this line to `_ = torch.jit.trace(model, example_input.to(device))` to make it clearer that we don't actually want to use the resulting trace.
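To illustrate the difference between the two modes, here is a sketch with a hypothetical `prepare_model` helper that is not taken from the script. The device transfer (`example_input.to(device)`) is omitted, and treating any other `--jit` value as "no tracing" is an assumption:

```python
import torch
import torch.nn as nn


def prepare_model(model, example_input, jit_mode):
    """Hypothetical helper mirroring the two --jit modes discussed here."""
    if jit_mode == 'onsave':
        # Trace once as an early sanity check; the result is discarded,
        # so training still uses the regular Python model.
        _ = torch.jit.trace(model, example_input)
        return model
    if jit_mode == 'train':
        # Replace the model with its trace; training then runs on the traced module.
        return torch.jit.trace(model, example_input)
    return model  # assumption: any other value means no tracing


layer = nn.Linear(4, 2)
x = torch.randn(3, 4)
eager = prepare_model(layer, x, 'onsave')  # same object as `layer`
traced = prepare_model(layer, x, 'train')  # a torch.jit.ScriptModule
```

If the model cannot be traced, `torch.jit.trace` raises in the `'onsave'` branch at startup rather than only when the first checkpoint is written.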
Around line 99 of train_unet_neurodata.py, this code appears:
Is a `model = tracedmodel` assignment missing from the true branch of `if args.jit == 'onsave'`? That is, should the code read: