lkct closed this 11 months ago
Unfortunately, it's not possible to fully decouple `reparam_*` and `fold_mask`.
The only obstacle is the `MixingLayer`, where the mask affects the re-normalization of the weights (only the reparams that require normalization need the mask info).
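Here is a minimal sketch of why a normalizing reparam cannot ignore the mask. It assumes softmax normalization and a folded mixing layer whose arity dim is padded to the largest arity among the folded layers. `masked_softmax` is a hypothetical helper, not the project's API, and pure Python is used only to keep the example self-contained:

```python
import math
from typing import List, Sequence


def masked_softmax(logits: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Normalize mixing weights over the real (unmasked) inputs only.

    Masked entries are padding introduced by folding layers of different
    arity: they get weight 0 and are excluded from the normalizer.
    """
    if not any(mask):
        raise ValueError("at least one entry must be unmasked")
    # Subtract the max over real entries for numerical stability.
    m = max(x for x, keep in zip(logits, mask) if keep)
    exps = [math.exp(x - m) if keep else 0.0 for x, keep in zip(logits, mask)]
    z = sum(exps)
    return [e / z for e in exps]


# One fold padded to arity 3, but only its first 2 inputs are real.
weights = masked_softmax([0.0, 0.0, 5.0], [True, True, False])
```

With logits `[0.0, 0.0, 5.0]` and mask `[True, True, False]`, the two real inputs get weight 0.5 each. An unmasked softmax would instead give most of the mass to the padding slot, so the real weights would no longer sum to 1.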
However, as discussed, for CP (and all layers that reduce the arity dim by product), the mask should still be applied in the layer (during the reduction) rather than on the params. A whole branch should be pruned, instead of assigning "magic" mask values to params so that the branch only appears pruned.
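For the product-reducing layers, here is a minimal sketch of masking inside the layer's reduction. It assumes the inputs are in log-space, so the product over the arity dim becomes a sum. `masked_log_product` and `folded_cp_reduce` are hypothetical names. The point is that a masked branch is skipped outright, so no param value ever has to be chosen to fake an identity output:

```python
import math
from typing import List, Sequence


def masked_log_product(log_inputs: Sequence[float], mask: Sequence[bool]) -> float:
    """Product over the arity dim, in log-space, pruning masked branches.

    Skipping a branch is equivalent to multiplying by 1 (adding 0 in
    log-space); no parameter is modified to make the branch "appear pruned".
    """
    return sum(x for x, keep in zip(log_inputs, mask) if keep)


def folded_cp_reduce(
    log_inputs: Sequence[Sequence[float]], mask: Sequence[Sequence[bool]]
) -> List[float]:
    """Apply the masked reduction independently to each fold."""
    return [masked_log_product(xs, ms) for xs, ms in zip(log_inputs, mask)]


# Two folds padded to arity 3. The first fold only has 2 real inputs,
# so its padding slot holds NaN and is masked out.
out = folded_cp_reduce(
    [[math.log(0.5), math.log(0.25), float("nan")],
     [math.log(0.5), math.log(0.5), math.log(0.5)]],
    [[True, True, False], [True, True, True]],
)
```

Because the padded entry is never read, it can hold any value, even NaN, without affecting the result. Both folds here evaluate to log(0.125).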
Following the discussion with @loreloc today:
- `fold_mask`, introduced in #129, should be removed from the `reparam_*` functions. The masking should be done in the layer class.
- `reparam_*` (freed from `fold_mask`) will be turned into a class that saves the parameter instances. Layers only refer to the reparam classes. This provides another level of abstraction for the computational graph of params, useful for chained transformations (e.g. expectation of square).
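Here is a minimal sketch of the planned design. It assumes weights are plain Python lists rather than tensors, and every class name in it (`Reparam`, `Leaf`, `Exp`, `Normalize`, `Outer`, `SumLayer`) is hypothetical. Each reparam holds references to its inputs, down to a leaf that owns the parameter values, and the layer only calls its reparam. Consistent with the first point, no reparam takes a mask. Squaring a sum layer serves as the chained transformation, because the squared layer's weights are the pairwise products w_i w_j of the original weights:

```python
import math
from typing import List, Sequence


class Reparam:
    """A node in the computational graph of params (hypothetical base class)."""

    def __call__(self) -> List[float]:
        raise NotImplementedError


class Leaf(Reparam):
    """Owns the raw, unconstrained parameter values."""

    def __init__(self, values: Sequence[float]) -> None:
        self.values = list(values)

    def __call__(self) -> List[float]:
        return list(self.values)


class Exp(Reparam):
    """Elementwise exp, mapping unconstrained values to positive ones."""

    def __init__(self, child: Reparam) -> None:
        self.child = child

    def __call__(self) -> List[float]:
        return [math.exp(x) for x in self.child()]


class Normalize(Reparam):
    """Divides by the sum; Normalize(Exp(leaf)) is a softmax. Takes no mask."""

    def __init__(self, child: Reparam) -> None:
        self.child = child

    def __call__(self) -> List[float]:
        xs = self.child()
        total = sum(xs)
        return [x / total for x in xs]


class Outer(Reparam):
    """Pairwise products w_i * w_j (row-major), e.g. weights of a squared sum layer."""

    def __init__(self, left: Reparam, right: Reparam) -> None:
        self.left = left
        self.right = right

    def __call__(self) -> List[float]:
        a, b = self.left(), self.right()
        return [x * y for x in a for y in b]


class SumLayer:
    """A layer that refers only to a reparam, never to raw parameter values."""

    def __init__(self, weight: Reparam) -> None:
        self.weight = weight

    def __call__(self, inputs: Sequence[float]) -> float:
        return sum(w * x for w, x in zip(self.weight(), inputs))


# One leaf, shared by the original layer and its squared counterpart.
leaf = Leaf([0.0, math.log(3.0)])
weight = Normalize(Exp(leaf))  # softmax -> approximately [0.25, 0.75]
layer = SumLayer(weight)
squared = SumLayer(Outer(weight, weight))

f = [2.0, 4.0]
ff = [a * b for a in f for b in f]  # inputs of the squared layer, row-major
```

Here `layer(f)` is 3.5 and `squared(ff)` is 12.25, i.e. 3.5 squared, since (Σ_i w_i f_i)² = Σ_i Σ_j w_i w_j f_i f_j. Updating `leaf.values` changes both layers at once, because they share the same parameter instance rather than copies. Further chained transformations would add more nodes on top of this graph in the same way.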