Implement biased pathwise (reparametrization) gradient estimation, instead of learning-signal-based optimization, for learning the controller.
A differentiable relaxation of the sampling provides a more controllable way of backpropagating gradients through the discrete architecture random variables, since the 'hard/soft' level of the relaxation controls the bias-variance trade-off of the learning process.
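To make the 'hard/soft' level concrete, here is a minimal pure-Python sketch of a Gumbel-softmax sample. It assumes the relaxation temperature plays the role of the hard/soft level (a common choice, not something fixed here). The function name `gumbel_softmax` and its parameters are illustrative only; a real implementation would live in an autodiff framework (PyTorch, for instance, provides `torch.nn.functional.gumbel_softmax`).

```python
import math
import random


def gumbel_softmax(logits, temperature=1.0, hard=False, rng=random):
    """Relaxed one-hot sample from Categorical(softmax(logits)).

    Illustrative sketch: lower temperatures push the sample toward a one-hot
    vector (the relaxation gets closer to the true categorical, but gradients
    get noisier); higher temperatures give smoother samples (more bias, less
    variance). hard=True returns the exact one-hot argmax as the forward value.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1),
        # clamped away from 0 and 1 to avoid log(0).
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noisy.append((logit - math.log(-math.log(u))) / temperature)
    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    m = max(noisy)
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    soft = [e / total for e in exps]
    if not hard:
        return soft
    idx = max(range(len(soft)), key=soft.__getitem__)
    return [1.0 if i == idx else 0.0 for i in range(len(soft))]
```

For example, `gumbel_softmax([1.0, 2.0, 0.5], temperature=5.0)` gives a smooth vector, while `temperature=0.1` gives a nearly one-hot one. In an autodiff framework, `hard=True` is usually paired with the straight-through estimator, so the forward pass is discrete while gradients flow through the soft sample.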
Details:
We need a new `DifferentiableSuperNet` weights manager and a new `DifferentiableController` controller.
In the `DifferentiableController`, the sampling probabilities of the node/op choices on each edge are modeled as global parameters. As a first step, we can sample only the operations.
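A minimal sketch of that first step, assuming one logit vector per edge is the only global state and the op on each edge is sampled independently with a Gumbel-softmax (input nodes stay fixed). `GlobalOpLogitsController`, `alphas`, `op_probabilities` and `sample` are hypothetical names, and plain lists stand in for trainable tensors.

```python
import math
import random


class GlobalOpLogitsController:
    """Hypothetical sketch of a DifferentiableController with global per-edge
    op logits and no RNN. Only the operation on each edge is sampled."""

    def __init__(self, num_edges, num_ops, temperature=1.0, seed=None):
        # In a real implementation this would be a trainable tensor of shape
        # [num_edges, num_ops]; zeros mean a uniform initial distribution.
        self.alphas = [[0.0] * num_ops for _ in range(num_edges)]
        self.temperature = temperature
        self.rng = random.Random(seed)

    @staticmethod
    def _softmax(xs):
        m = max(xs)
        exps = [math.exp(x - m) for x in xs]
        total = sum(exps)
        return [e / total for e in exps]

    def op_probabilities(self):
        """Sampling probability of every op on every edge."""
        return [self._softmax(edge_logits) for edge_logits in self.alphas]

    def sample(self):
        """One relaxed one-hot vector per edge, edges sampled independently."""
        samples = []
        for edge_logits in self.alphas:
            noisy = []
            for logit in edge_logits:
                # Gumbel noise, clamped to avoid log(0)
                u = min(max(self.rng.random(), 1e-12), 1.0 - 1e-12)
                noisy.append((logit - math.log(-math.log(u))) / self.temperature)
            samples.append(self._softmax(noisy))
        return samples
```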
In principle, the conditional dependency implied by an RNN-based controller is preferable, as the architecture decisions/actions should depend on each other, especially the choices of input nodes within one step. I also suspect that with independent node/op decisions at every step, the search will get trapped in a local minimum more quickly and will not explore well. But how can we sample all the architecture parameters when there are so many discrete decisions in the whole sampling process? Alternatively, we could use a network similar to the RL controller networks to sample a path, and use the global parameters as a learnable prior for every op/edge; when an edge is not sampled, we just use the prior op distribution on that edge...
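One possible reading of the hybrid idea, sketched under strong assumptions: a transition table conditioned on the previously sampled op stands in for the RNN controller, the per-edge global logits act as the learnable prior, and edges off the sampled path keep the prior distribution. `hybrid_edge_distributions` and `transition` are hypothetical, not part of any existing controller.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def hybrid_edge_distributions(prior_logits, sampled_edges, transition, rng):
    """Hypothetical sketch: per-edge op distributions with a learnable prior.

    prior_logits[e][k]: global (learnable) prior logit of op k on edge e.
    sampled_edges: set of edges on the path chosen by the sequential sampler.
    transition[j][k]: stand-in for an RNN; logit offset for op k given that
        op j was chosen at the previously sampled edge.
    """
    dists, prev_op = [], None
    for e, logits in enumerate(prior_logits):
        if e in sampled_edges:
            if prev_op is None:
                cond = list(logits)
            else:
                # Condition this edge's decision on the previous decision.
                cond = [a + b for a, b in zip(logits, transition[prev_op])]
            probs = softmax(cond)
            # Draw this edge's op so the next sampled edge can depend on it.
            prev_op = rng.choices(range(len(probs)), weights=probs)[0]
        else:
            # Edge not on the sampled path: fall back to the prior.
            probs = softmax(logits)
        dists.append(probs)
    return dists
```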
The rollout object passed between the weights manager and the controller also differs from the current `Rollout`, as the architecture representation is now different, so rollouts need their own subclasses too. The weights manager (the consumer/assembler) and the controller (the producer/sampler) are generally not agnostic to the rollout type, so it is reasonable to add an interface that specifies which rollout type a controller produces and which a weights manager can take. The main script can be responsible for checking that these rollout interfaces match. The trainer also needs to handle `DifferentiableRollout` differently: e.g., the mepa trainer should call `set_perf` with the in-graph loss tensor when using a differentiable rollout (eval should pass `self._criterion` instead of `_ce_loss_mean`), but call `set_perf` with the accuracy or a detached loss when using `DiscreteRollout`.
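The rollout-type interface and the trainer-side dispatch could look roughly like the sketch below. `DifferentiableController`, `DifferentiableSuperNet`, `DiscreteRollout`, `DifferentiableRollout` and `set_perf` are names from this issue, but their bodies here are toy stand-ins; `BaseRollout`, `rollout_type`, `supported_rollout_types`, `check_rollout_interface` and `report_perf` are hypothetical.

```python
class BaseRollout:
    """Hypothetical common base: a perf dict that the trainer fills in."""

    def __init__(self):
        self.perf = {}

    def set_perf(self, value, name="reward"):
        self.perf[name] = value


class DiscreteRollout(BaseRollout):
    """Architecture as discrete choices (e.g. one op index per edge)."""

    def __init__(self, arch):
        super().__init__()
        self.arch = arch


class DifferentiableRollout(BaseRollout):
    """Architecture as (relaxed) one-hot samples, one vector per edge."""

    def __init__(self, sampled):
        super().__init__()
        self.sampled = sampled


class DifferentiableController:
    rollout_type = DifferentiableRollout  # the rollout type this producer emits


class DifferentiableSuperNet:
    supported_rollout_types = (DifferentiableRollout,)  # types this consumer accepts


def check_rollout_interface(controller, weights_manager):
    """Run by the main script before training; fail fast on a mismatch."""
    if controller.rollout_type not in weights_manager.supported_rollout_types:
        accepted = [t.__name__ for t in weights_manager.supported_rollout_types]
        raise TypeError(
            f"{type(controller).__name__} produces "
            f"{controller.rollout_type.__name__}, but "
            f"{type(weights_manager).__name__} accepts {accepted}")


def report_perf(rollout, loss, acc):
    """Trainer-side dispatch on the rollout type."""
    if isinstance(rollout, DifferentiableRollout):
        # Keep the in-graph loss (in PyTorch: the loss tensor itself, not
        # loss.item()) so gradients can reach the controller parameters.
        rollout.set_perf(loss, name="loss")
    else:
        # A discrete rollout only needs a number (accuracy or detached loss).
        rollout.set_perf(float(acc), name="acc")
```

Declaring the types as class attributes lets the main script reject a mismatched controller/weights-manager pair before any training step runs.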
Reuse some of the controller network code, adding support for sampling and returning actions as one-hot samples (possibly "soft" relaxed ones, e.g., samples from a Gumbel-softmax as a relaxation of categorical samples). Actually, the code cannot be reused, as the differentiable relaxation needs a sample for every op and every edge.
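On the consumer side, a common choice is to use each edge's relaxed one-hot sample as mixing weights over all candidate ops, which is also why every op on every edge needs a sample. A toy sketch, assuming scalar inputs and plain Python callables as ops; `mixed_op` and `supernet_forward` are hypothetical and not the `DifferentiableSuperNet` API.

```python
def mixed_op(x, candidate_ops, weights):
    """Hypothetical edge forward pass: weighted sum over all candidate ops.

    weights is the (relaxed) one-hot sample for this edge. Every op is
    evaluated, so in an autodiff framework every weight would receive a
    gradient; with a hard one-hot sample the result equals the chosen op.
    """
    if len(candidate_ops) != len(weights):
        raise ValueError("need one weight per candidate op")
    out = 0.0
    for op, w in zip(candidate_ops, weights):
        out += w * op(x)
    return out


def supernet_forward(x, edges_ops, sampled):
    """Toy chain of edges: each edge applies its mixed op to the previous output."""
    for candidate_ops, weights in zip(edges_ops, sampled):
        x = mixed_op(x, candidate_ops, weights)
    return x
```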