The current usage of SCP in DqnScp is:

self.scp_loss_fn.set_anchors(key)
for transitions in self.replay_sampler.generate_batches(128, True):
    batch = collate(transitions)
    self.scp_loss_fn.store_synaptic_response(key, batch.state)
Note that store_synaptic_response is called once for each batch. However, at the top of store_synaptic_response the synaptic weight matrices for the task are cleared, so each call erases the work of all previous calls:
def store_synaptic_response(self, key, batch_state):
    # Initialize the synaptic matrix for this task
    self._synaptic_response[key] = {}
    for name, curr_param in self._model.named_parameters():
        self._synaptic_response[key][name] = torch.zeros(curr_param.shape)
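One way to avoid losing the earlier batches' contributions, assuming the response matrices are meant to accumulate across batches, is to initialize the matrices only on the first call for a given task key and add into them on every call. The sketch below uses a hypothetical minimal class (the `ScpLoss` name, constructor, and the gradient-magnitude accumulation term are illustrative assumptions, not the actual DqnScp implementation):

```python
import torch


class ScpLoss:
    """Sketch of the accumulation fix (hypothetical class, not the real one)."""

    def __init__(self, model):
        self._model = model
        self._synaptic_response = {}

    def store_synaptic_response(self, key, batch_state):
        # Initialize the per-task matrices only once, on the first call
        # for this key, instead of clearing them on every call.
        if key not in self._synaptic_response:
            self._synaptic_response[key] = {
                name: torch.zeros(param.shape)
                for name, param in self._model.named_parameters()
            }
        # Accumulate this batch's contribution. The gradient of the mean
        # output is a placeholder response term; the real quantity depends
        # on the SCP formulation.
        self._model.zero_grad()
        self._model(batch_state).mean().backward()
        for name, param in self._model.named_parameters():
            self._synaptic_response[key][name] += param.grad.detach().abs()
```

With this change the loop over `generate_batches` accumulates one response matrix per parameter across all batches, rather than keeping only the last batch's result.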