QueuQ / CGLB


GEM Memory_data #15

Closed WMX567 closed 7 months ago

WMX567 commented 1 year ago
tmask = np.random.choice(self.mask, self.n_memories, replace=False)
tmask = np.array(tmask)
self.memory_data.append(tmask)

old_task_loss = loss[self.memory_data[old_task_i],old_task_i].mean()

However, the loss has shape (batch_size, 1). Why won't indexing with self.memory_data[old_task_i] go out of memory?

QueuQ commented 1 year ago

Hi,

I guess you are actually asking about why the index does not go out of range? If that is the case, the answer is that the batch_size used here equals the size of the train_set. In other words, we are actually selecting retraining data from the complete training set, so every index does not go out of range.

https://github.com/QueuQ/CGLB/blob/f6628290b34c958ceff347c1b52c5138f3f1ef23/GCGL/pipeline.py#L339
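To illustrate the point above, here is a minimal, self-contained sketch (not the repository's actual code; the variable names and sizes are assumptions for demonstration): because the per-sample loss is computed over the entire training set, every memory index sampled from that same set is guaranteed to be a valid row index into the loss vector.

```python
import numpy as np

n_train = 100        # assumed size of the full training set
n_memories = 10      # assumed number of replay samples stored per old task

rng = np.random.default_rng(0)

# Sample memory indices from the full training set, mirroring the
# issue's snippet (self.mask holds indices of all training samples).
mask = np.arange(n_train)
tmask = rng.choice(mask, n_memories, replace=False)
memory_data = [tmask]              # one index array per old task

# Per-sample loss computed over the whole training set, so its length
# equals n_train (i.e. batch_size == size of the train_set).
loss = rng.random(n_train)

# Every stored index is < n_train, so this fancy indexing
# can never go out of range.
old_task_i = 0
old_task_loss = loss[memory_data[old_task_i]].mean()
```

If the loss were instead computed over a smaller mini-batch, the same indexing could indeed raise an IndexError, which is why the full-training-set batch size matters here.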