shivamsaboo17 / Overcoming-Catastrophic-forgetting-in-Neural-Networks

Elastic weight consolidation technique for incremental learning.

Fisher Update causing errors #2

Open Sharut opened 4 years ago

Sharut commented 4 years ago

I am trying to run EWC on my dataset with a ResNet-50 model. While updating the Fisher matrix with your function, I get a CUDA out-of-memory error caused by "log_liklihoods.append(output[:, target])". I read https://stackoverflow.com/questions/59805901/unable-to-allocate-gpu-memory-when-there-is-enough-of-cached-memory and tried to work around the problem with detach(). After detaching, I get: RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior. To get past that, I set allow_unused=True in autograd, but then all my gradients come out as zero. Why is this happening?
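For reference, that second error and the zero gradients can be reproduced in isolation: torch.autograd.grad raises it whenever one of the requested tensors does not appear in the graph of the output, and allow_unused=True then simply returns None for those tensors, which downstream code typically treats as zero. A minimal standalone example, unrelated to the repo's code:

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

loss = (a * 2).sum()  # b never participates in this graph
# torch.autograd.grad(loss, (a, b))
# -> RuntimeError: One of the differentiated Tensors appears to not have been
#    used in the graph. Set allow_unused=True if this is the desired behavior.
grads = torch.autograd.grad(loss, (a, b), allow_unused=True)
print(grads[1])  # None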

shivamsaboo17 commented 4 years ago

detach() won't work, because then PyTorch cannot track gradients for that tensor. I think the problem is that I save the probabilities for the entire dataset and only then take the mean. A simple solution would be to keep a running mean of the likelihoods instead of storing them in a list as I do. Let me know if this works for you.

Sharut commented 4 years ago

Oh, will that be the correct implementation of the Fisher matrix? Could you please share a line of code showing what you mean, so I can make sure I am doing the right thing? I am new to PyTorch, so I am not very confident; it would be really helpful if you could clarify with code.

shivamsaboo17 commented 4 years ago

Instead of https://github.com/shivamsaboo17/Overcoming-Catastrophic-forgetting-in-Neural-Networks/blob/e056e6d9f07a1e866b190a053ec5ca1314b3eef5/elastic_weight_consolidation.py#L30, can you try:

dl = DataLoader(current_ds, batch_size, shuffle=True)
log_liklihoods = 0
for i, (input, target) in enumerate(dl):
    if i > num_batch:
        break
    output = F.log_softmax(self.model(input), dim=1)
    # running mean over batches: mean_i = (i * mean_{i-1} + x_i) / (i + 1)
    log_liklihoods = (i * log_liklihoods + output[:, target].mean()) / (i + 1)

On further thought, I am not sure this is the actual issue, because your computation graph will still keep growing (you would then have to flush the likelihood, using detach, after every batch in a similar way). Still, give it a try and let me know if it works.
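For what it is worth, one way to avoid the growing graph entirely is to backpropagate each batch's mean log-likelihood immediately and accumulate the squared parameter gradients, so nothing from earlier batches is retained. A rough standalone sketch, not the repo's ElasticWeightConsolidation API (the function name and arguments here are made up):

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

def fisher_diag_per_batch(model, current_ds, batch_size, num_batch, device="cpu"):
    # sketch only: one accumulator per trainable parameter
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()
              if p.requires_grad}
    dl = DataLoader(current_ds, batch_size, shuffle=True)
    batches = 0
    for i, (inputs, targets) in enumerate(dl):
        if i >= num_batch:
            break
        inputs, targets = inputs.to(device), targets.to(device)
        log_probs = F.log_softmax(model(inputs), dim=1)
        # mean log-likelihood of the observed labels for this batch only
        ll = log_probs.gather(1, targets.unsqueeze(1)).mean()
        model.zero_grad()
        ll.backward()  # the graph for this batch is freed here, so memory does not grow
        for n, p in model.named_parameters():
            if p.requires_grad and p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        batches += 1
    # average the squared per-batch gradients (same spirit as squaring the gradient
    # of the dataset-wide mean log-likelihood, but without keeping every batch's graph)
    return {n: f / max(batches, 1) for n, f in fisher.items()}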

Sharut commented 4 years ago

# (imports inferred from the snippets below)
import sys
import time
from typing import Dict

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import Tensor
from torch.distributions import Categorical
from torch.nn import Module
from torch.utils.data import DataLoader


def save_fisher(fim: Dict[str, Tensor], name, scale=3):
    for p_name, params in fim.items():
        print(p_name)
        print(params.ndimension())
        if params.ndimension() == 1:
            height, width = params.size(0), 1
        elif params.ndimension() == 2:
            height, width = params.size()
        else:
            raise NotImplementedError
        img = params.view(height, 1, width, 1) \
            .repeat(1, scale, 1, scale) \
            .view(height * scale, width * scale) \
            .to(torch.device("cpu")) \
            .numpy()
        plt.figure(figsize=(24, 24))
        sns.heatmap(img, cmap='Greys')
        plt.savefig(f"/data/2015P002510/Sharut/EWC_2/fisher_results/{p_name:s}{name:s}.png")
        plt.close()

def fim_diag(model: Module,
             data_loader: DataLoader,
             samples_no: int = None,
             empirical: bool = False,
             device: torch.device = None,
             verbose: bool = False,
             every_n: int = None) -> Dict[int, Dict[str, Tensor]]:
    fim = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            fim[name] = torch.zeros_like(param)

    seen_no = 0
    last = 0
    tic = time.time()
    all_fims = dict({})

    while samples_no is None or seen_no < samples_no:
        data_iterator = iter(data_loader)
        try:
            data, target = next(data_iterator)
        except StopIteration:
            if samples_no is None:
                break
            data_iterator = iter(data_loader)
            data, target = next(data_iterator)  # was next(data_loader), which raises TypeError

        if device is not None:
            data = data.to(device)
            if empirical:
                target = target.to(device)

        logits = model(data)
        if empirical:
            # empirical Fisher: use the observed labels
            outdx = target.unsqueeze(1)
        else:
            # otherwise sample labels from the model's own predictive distribution
            outdx = Categorical(logits=logits).sample().unsqueeze(1).detach()

        samples = logits.gather(1, outdx)

        idx, batch_size = 0, data.size(0)
        while idx < batch_size and (samples_no is None or seen_no < samples_no):
            model.zero_grad()
            torch.autograd.backward(samples[idx], retain_graph=True)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    fim[name] += (param.grad * param.grad)
                    fim[name].detach_()
            seen_no += 1
            idx += 1

            if verbose and seen_no % 100 == 0:
                toc = time.time()
                fps = float(seen_no - last) / (toc - tic)
                tic, last = toc, seen_no
                sys.stdout.write(f"\rSamples: {seen_no:5d}. Fps: {fps:2.4f} samples/s.")

            if every_n and seen_no % every_n == 0:
                all_fims[seen_no] = {n: f.clone().div_(seen_no).detach_()
                                     for (n, f) in fim.items()}

    if verbose:
        if seen_no > last:
            toc = time.time()
            fps = float(seen_no - last) / (toc - tic)
        sys.stdout.write(f"\rSamples: {seen_no:5d}. Fps: {fps:2.5f} samples/s.\n")

    # for name, grad2 in fim.items():
    #     grad2 /= float(seen_no)

    all_fims[seen_no] = fim
    print("Fisher Matrix found!")
    return all_fims

What do you think about this implementation?
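For completeness, a hypothetical call to the helpers above might look like the following; model and train_loader are placeholders, and save_fisher as written only handles 1-D and 2-D parameters:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

fims = fim_diag(model, train_loader, samples_no=1000,
                empirical=True, device=device, verbose=True, every_n=500)
fisher = fims[max(fims)]  # snapshot after the largest number of seen samples
save_fisher({n: f for n, f in fisher.items() if f.ndimension() <= 2}, name="task0")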

afshinrahimi commented 3 years ago

Instead of

log_liklihoods.append(output[:, target])

shouldn't we have

log_liklihoods.append(torch.gather(output, dim=1, index=target.unsqueeze(-1)))

?

Assuming the target (batch) size is 64 and output is 64x4 (4 classes), output[:, target] gives a 64x64 tensor, while the intention is to get a 64x1 tensor, right? The alternative line does that.

Great work BTW.
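A quick standalone check of the shapes confirms this:

import torch

output = torch.randn(64, 4)          # batch of 64, 4 classes
target = torch.randint(0, 4, (64,))

print(output[:, target].shape)                                         # torch.Size([64, 64])
print(torch.gather(output, dim=1, index=target.unsqueeze(-1)).shape)   # torch.Size([64, 1])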

ThomasAtlantis commented 10 months ago

(Quoting afshinrahimi's suggestion above to use torch.gather instead of output[:, target].)

Definitely right!

[screenshot attachment]