Refactored the `GradientComputer` class to have:

- a class variable `is_functional` (we use it to choose the implementation of the `ModelOutput` class);
- either `model` (non-functional computer) or `func_model`, `func_weights`, and `func_buffers` (functional computer) as attributes;
- a method `load_model_params` that gets called whenever we switch checkpoints (in `TRAKer.load_checkpoint`) and updates those attributes.
This way, there is no need for explicit `if` statements in the `TRAKer` class (or in user code) to check whether we're using the functional gradient computer.
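A minimal sketch of this pattern in plain Python, so it runs without PyTorch. `ToyModel`, `toy_func_model`, `IterativeGradientComputer`, `grads_for_checkpoint`, and the squared-loss gradients are hypothetical illustrations, not the library's actual code. Only `is_functional`, `model`, `func_model`, `func_weights`, `func_buffers`, `load_model_params`, `compute_per_sample_grad`, and `AbstractGradientComputer` come from the description above. The point is that the caller never branches on which subclass it holds.

```python
from abc import ABC, abstractmethod
from typing import ClassVar, Dict, List, Sequence, Tuple

Params = Dict[str, float]
Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


class ToyModel:
    """Hypothetical stand-in for a real network: predicts w * x + b."""

    def __init__(self, params: Params) -> None:
        self.params = dict(params)

    def __call__(self, x: float) -> float:
        return self.params["w"] * x + self.params["b"]


def toy_func_model(weights: Params, buffers: Params, x: float) -> float:
    """Pure-function version of ToyModel: parameters are passed in explicitly."""
    return weights["w"] * x + weights["b"]  # buffers unused by this toy model


class AbstractGradientComputer(ABC):
    is_functional: ClassVar[bool]  # callers read this flag instead of type-checking

    @abstractmethod
    def load_model_params(self, model: ToyModel) -> None:
        """Refresh the attributes this computer needs for the current checkpoint."""

    @abstractmethod
    def compute_per_sample_grad(self, batch: Batch) -> List[Params]:
        """Identical signature in every subclass: one gradient dict per example."""


class IterativeGradientComputer(AbstractGradientComputer):
    is_functional = False

    def load_model_params(self, model: ToyModel) -> None:
        self.model = model  # keeps the stateful model object

    def compute_per_sample_grad(self, batch: Batch) -> List[Params]:
        grads = []
        for x, y in batch:
            r = self.model(x) - y  # residual of the loss (pred - y) ** 2
            grads.append({"w": 2 * r * x, "b": 2 * r})
        return grads


class FunctionalGradientComputer(AbstractGradientComputer):
    is_functional = True

    def load_model_params(self, model: ToyModel) -> None:
        # keeps a pure function plus copies of the weights and buffers
        self.func_model = toy_func_model
        self.func_weights = dict(model.params)
        self.func_buffers: Params = {}

    def compute_per_sample_grad(self, batch: Batch) -> List[Params]:
        grads = []
        for x, y in batch:
            r = self.func_model(self.func_weights, self.func_buffers, x) - y
            grads.append({"w": 2 * r * x, "b": 2 * r})
        return grads


def grads_for_checkpoint(
    computer: AbstractGradientComputer, model: ToyModel, batch: Batch
) -> List[Params]:
    # No `if computer.is_functional:` branch: both subclasses share one interface.
    computer.load_model_params(model)
    return computer.compute_per_sample_grad(batch)


if __name__ == "__main__":
    model = ToyModel({"w": 2.0, "b": 1.0})
    batch = [(1.0, 2.0), (3.0, 5.0)]
    for computer in (IterativeGradientComputer(), FunctionalGradientComputer()):
        print(grads_for_checkpoint(computer, model, batch))
        # prints [{'w': 2.0, 'b': 2.0}, {'w': 12.0, 'b': 4.0}] both times
```

Both computers return the same per-sample gradients; they differ only in which attributes `load_model_params` fills in.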
Additionally, the signatures of the `compute_per_sample_grad` methods now match across subclasses (as they should). Finally, in `TRAKer.__init__`, instead of the `functional: bool` argument, we pass in `gradient_computer: AbstractGradientComputer`.
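A second sketch, again plain Python, illustrates the constructor change: `TRAKer` receives the gradient computer instead of a `functional: bool` flag, and uses the computer's `is_functional` class variable to pick a model-output implementation. `IterativeGradientComputer`, `FunctionalGradientComputer`, the two placeholder `*ModelOutput` classes, and the `MODEL_OUTPUT_BY_FUNCTIONAL` registry are hypothetical. The sketch also assumes the caller passes the computer *class* and `TRAKer` instantiates it; the description above does not say whether a class or an instance is expected.

```python
from abc import ABC, abstractmethod
from typing import ClassVar, Dict, Type

Params = Dict[str, float]


class AbstractGradientComputer(ABC):
    is_functional: ClassVar[bool]

    @abstractmethod
    def load_model_params(self, params: Params) -> None: ...


class IterativeGradientComputer(AbstractGradientComputer):
    is_functional = False

    def load_model_params(self, params: Params) -> None:
        self.model = dict(params)  # stand-in for loading a stateful model


class FunctionalGradientComputer(AbstractGradientComputer):
    is_functional = True

    def load_model_params(self, params: Params) -> None:
        self.func_weights = dict(params)
        self.func_buffers: Params = {}


class IterativeModelOutput:
    """Hypothetical placeholder for the non-functional model-output code."""


class FunctionalModelOutput:
    """Hypothetical placeholder for the functional model-output code."""


# Hypothetical registry keyed by the computer's class variable.
MODEL_OUTPUT_BY_FUNCTIONAL: Dict[bool, type] = {
    False: IterativeModelOutput,
    True: FunctionalModelOutput,
}


class TRAKer:
    def __init__(self, gradient_computer: Type[AbstractGradientComputer]) -> None:
        # Replaces the old `functional: bool` argument.
        self.gradient_computer = gradient_computer()
        # The class variable selects the model-output implementation, no if/else.
        self.model_output = MODEL_OUTPUT_BY_FUNCTIONAL[gradient_computer.is_functional]()
        self.checkpoint_id = None

    def load_checkpoint(self, checkpoint: Params, checkpoint_id: int) -> None:
        self.checkpoint_id = checkpoint_id
        # Each computer refreshes its own attributes (model vs. func_weights/func_buffers).
        self.gradient_computer.load_model_params(checkpoint)
```

With this shape, supporting a new kind of gradient computer means writing a subclass that sets `is_functional` and implements the shared methods; `TRAKer` itself does not change.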