logix-project / logix

AI Logging for Interpretability and Explainability🔬
Apache License 2.0

Set `requires_grad=False` for all parameters to save memory #33

Closed sangkeun00 closed 7 months ago

sangkeun00 commented 9 months ago

In AnaLog, we are mostly interested in per-sample gradients rather than the mini-batch gradient. Therefore, we don't necessarily need autograd to populate `param.grad` for each parameter. By setting `requires_grad=False` for all parameters, we can potentially save a lot of GPU memory.

Proposed design

class AnaLog:
    ...

    def watch(self, model, populate_grad=False, ...):
        if not populate_grad:
            for p in model.parameters():
                p.requires_grad = False

Caution

We should make sure that setting requires_grad=False doesn't change the behavior of forward/backward/grad_hooks of the module these parameters belong to.
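A minimal sanity check for this caution might look like the following. It assumes the hooks of interest are standard `nn.Module` forward/backward hooks, and that the module's *input* still requires grad so autograd traverses the module even though its parameters do not:

```python
import torch
import torch.nn as nn

# Sketch: verify forward/backward hooks still fire when all parameters
# have requires_grad=False, as long as the input stays in the autograd graph.
linear = nn.Linear(4, 2)
for p in linear.parameters():
    p.requires_grad = False

events = []
linear.register_forward_hook(lambda m, inp, out: events.append("fwd"))
linear.register_full_backward_hook(lambda m, gi, go: events.append("bwd"))

x = torch.randn(3, 4, requires_grad=True)  # keep the input in the graph
linear(x).sum().backward()

assert events == ["fwd", "bwd"]           # both hooks fired
assert linear.weight.grad is None         # no .grad populated -> memory saved
```

Note the flip side: if neither the parameters nor the input require grad, backward never runs through the module and the backward hook is silently skipped, which is exactly the behavior change to guard against.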

YoungseogChung commented 9 months ago

@hwijeen @sangkeun00 I'll take care of this!

sangkeun00 commented 9 months ago

@YoungseogChung Another option you can consider is deleting `param.grad` as soon as it gets populated (probably in a backward hook), especially if setting `requires_grad=False` changes the hook behavior.

sangkeun00 commented 9 months ago

I actually tried the `del module.weight.grad` / `del module.bias.grad` idea in a recent commit.

Code: https://github.com/sangkeun00/analog/blob/main/analog/logging.py#L134-L136

However, it doesn't seem to reduce GPU memory usage. If @YoungseogChung can debug the issue, that would be great! Also, feel free to propose a different strategy for reducing GPU memory usage.

YoungseogChung commented 9 months ago

Got it, thanks for the heads up, and too bad it's not that simple lol. Yes, I'll tinker with it and search for solutions.

sangkeun00 commented 7 months ago

This issue is resolved in d2e2fea0becc0632cb03a008fac9454f36b957d7.