ezyang opened 7 years ago
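```python
import torch
from torch.autograd import Variable

x = Variable(torch.FloatTensor([1]))

def fn(x):
    return x

# Trace fn on input x and print the recorded trace
trace, _ = torch.jit.record_trace(fn, x)
print(str(trace))
```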
This outputs:
```
<expired TracingState>
```
It still fails even if you call `x.add_(2)` or some other in-place operation inside `fn`.
From @apaszke:
> That's expected. Either make the Variable volatile or pass `num_derivatives=0`. The graph died before you got to the first backward, so it expired.
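To make the explanation above concrete, here is a toy model in plain Python. It is not PyTorch's tracer: `ToyTracingState`, `ToyGraph`, `toy_record_trace` and `describe` are hypothetical names. The model assumes that a trace which still has backward passes to record is kept alive only by the graph that would run them. It also relies on CPython freeing objects as soon as their reference count drops to zero. In this model, passing `num_derivatives=0` (or using a volatile input, which builds no graph) means no backward stage is left to record, so the trace does not depend on the graph surviving.

```python
import weakref


class ToyTracingState:
    """Hypothetical stand-in for a tracer's state (not PyTorch's class)."""

    def __init__(self, num_derivatives):
        # Backward passes the trace still has to record after the forward pass.
        self.pending_backwards = num_derivatives


class ToyGraph:
    """Hypothetical autograd graph that owns the tracing state while alive."""

    def __init__(self, state):
        self.state = state


def toy_record_trace(num_derivatives):
    """Toy 'trace' of a forward pass; returns a callable handle to the state."""
    graph = ToyGraph(ToyTracingState(num_derivatives))
    if num_derivatives == 0:
        # The forward pass was the only stage, so the trace is already
        # complete and the caller may hold it directly.
        finished = graph.state
        return lambda: finished
    # Otherwise the caller only gets a weak handle: the state lives exactly as
    # long as the graph. Nothing else references the graph, so (under CPython
    # reference counting) it is freed when this function returns.
    return weakref.ref(graph.state)


def describe(handle):
    """Mimic printing the trace: report whether the state is still alive."""
    state = handle()
    if state is None:
        return "<expired TracingState>"
    return "<TracingState pending_backwards=%d>" % state.pending_backwards


print(describe(toy_record_trace(num_derivatives=1)))  # <expired TracingState>
print(describe(toy_record_trace(num_derivatives=0)))  # <TracingState pending_backwards=0>
```

The weak reference is the design point of the sketch: the trace does not keep the graph alive, so when nobody holds the graph long enough to call backward, the trace is reported as expired rather than silently incomplete.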