Once the initial thunk graph is defined by constructing a model through chained tensor ops, the graph doesn't need to be regenerated every time we execute the same forward (and backward) pass, and neither do the kernels. For inference this matters less: the inputs are always different, the model parameters are fixed, and there is no backward pass. But for training we update model parameters and repeatedly run forward/backward passes with the same kernels. Right now they all get regenerated whenever Tensor.realize is called. Ideally the optimizer would update the weights and we would recompute the pass with the already-generated kernels, since updating the weights only affects the buffers anyway.