lifelong-learning-systems / rlblocks

Reinforcement Learning Blocks for Researchers

SCP doesn't aggregate matrices across batches in DqnScp #28

Closed: coreylowman closed this issue 2 years ago

coreylowman commented 2 years ago

The current usage of SCP in DqnScp is:

    self.scp_loss_fn.set_anchors(key)
    for transitions in self.replay_sampler.generate_batches(128, True):
        batch = collate(transitions)
        self.scp_loss_fn.store_synaptic_response(key, batch.state)

Note that store_synaptic_response is called once for each batch the sampler generates.

However, at the top of store_synaptic_response, the synaptic weight matrices for the task are re-initialized, so each call erases the work of the previous calls:

    def store_synaptic_response(self, key, batch_state):
        ## Initialize the Synaptic matrix per task
        # NOTE: this reset runs on every call, discarding whatever earlier
        # calls accumulated for `key`.
        self._synaptic_response[key] = {}
        for name, curr_param in self._model.named_parameters():
            self._synaptic_response[key][name] = torch.zeros(curr_param.shape)
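
One way to fix this (a sketch only, not the project's actual patch; the accumulation logic that follows the initialization is elided) would be to make the initialization idempotent, so that repeated calls accumulate across batches:

    def store_synaptic_response(self, key, batch_state):
        # Sketch of a possible fix: create the per-task matrices only on the
        # first call for this key, so later batches add to them instead of
        # wiping them out.
        if key not in self._synaptic_response:
            self._synaptic_response[key] = {}
            for name, curr_param in self._model.named_parameters():
                self._synaptic_response[key][name] = torch.zeros(curr_param.shape)
        # ... existing per-batch accumulation over batch_state continues here ...

Alternatively, the reset could live in set_anchors, which the usage above calls exactly once per task before the batch loop.
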
coreylowman commented 2 years ago

@neilfendley just realized this, thoughts?

neilfendley commented 2 years ago

It's a bug, will fix asap!