PKUFlyingPig / CMU10-714

Learning material for CMU10-714: Deep Learning System

Could I ask about the Adam memory check failing? #3

Closed: wplf closed this issue 6 months ago

wplf commented 6 months ago

Hello. After the Adam memory check failed, I found that I couldn't debug the code to reduce its Tensor usage. Could you take a look for me?

My code is as follows:

class Adam(Optimizer):
    def __init__(
        self,
        params,
        lr=0.01,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.0,
    ):
        super().__init__(params)
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.weight_decay = weight_decay
        self.t = 0
        # Per-parameter moment estimates; step() reads them with .get(param, 0),
        # so plain dicts are sufficient.
        self.m = {}
        self.v = {}

    def step(self):
        ### BEGIN YOUR SOLUTION
        self.t += 1
        for param in self.params:
            grad = param.grad.detach() + self.weight_decay * param.detach()

            self.m[param] = self.beta1 * self.m.get(param, 0) + (1 - self.beta1) * grad
            self.v[param] = self.beta2 * self.v.get(param, 0) + (1 - self.beta2) * (grad ** 2)
            # breakpoint()
            m_t1_hat = self.m[param] / (1 - self.beta1 ** (self.t))
            v_t1_hat = self.v[param] / (1 - self.beta2 ** (self.t))

            param.cached_data -= (self.lr * m_t1_hat / ((v_t1_hat ** 0.5) + self.eps)).cached_data

The command-line output is as follows:

========================================================================= test session starts ==========================================================================
platform linux -- Python 3.12.2, pytest-8.1.1, pluggy-1.5.0 -- /home/wplf/miniconda3/bin/python3
cachedir: .pytest_cache
rootdir: /home/wplf/dl-sys/hw2
collected 93 items / 86 deselected / 7 selected                                                                                                                        

tests/hw2/test_nn_and_optim.py::test_optim_adam_1 PASSED                                                                                                         [ 14%]
tests/hw2/test_nn_and_optim.py::test_optim_adam_weight_decay_1 PASSED                                                                                            [ 28%]
tests/hw2/test_nn_and_optim.py::test_optim_adam_batchnorm_1 PASSED                                                                                               [ 42%]
tests/hw2/test_nn_and_optim.py::test_optim_adam_batchnorm_eval_mode_1 PASSED                                                                                     [ 57%]
tests/hw2/test_nn_and_optim.py::test_optim_adam_layernorm_1 PASSED                                                                                               [ 71%]
tests/hw2/test_nn_and_optim.py::test_optim_adam_weight_decay_bias_correction_1 PASSED                                                                            [ 85%]
tests/hw2/test_nn_and_optim.py::test_optim_adam_z_memory_check_1 FAILED                                                                                          [100%]

wplf commented 6 months ago

Thanks, I've solved this problem, and thank you for your code. The cause was that I had created a lot of extra variables in batchnorm!
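
For readers who hit the same failure, here is a minimal sketch of how extra variables in a batchnorm-style running-statistics update can inflate the number of live tensors. It does not use the course's library: `ToyTensor`, its `live` counter, and `run` are hypothetical names invented for illustration, and the assumption is that a memory check of this kind counts tensor objects that are still alive. Each toy tensor keeps references to its inputs, so a running statistic that is not detached keeps every earlier step's graph alive.

```python
class ToyTensor:
    """Hypothetical stand-in for a framework tensor (not the course's class)."""

    live = 0  # number of ToyTensor objects currently alive

    def __init__(self, value, inputs=()):
        self.value = value
        self.inputs = inputs  # references to parents keep them alive
        ToyTensor.live += 1

    def __del__(self):
        ToyTensor.live -= 1

    def detach(self):
        # Same value, but no references back into the graph.
        return ToyTensor(self.value)

    def __mul__(self, k):
        return ToyTensor(self.value * k, inputs=(self,))

    def __add__(self, other):
        return ToyTensor(self.value + other.value, inputs=(self, other))


def run(steps, detach):
    """Update a running mean `steps` times and report how many tensors are alive."""
    running = ToyTensor(0.0)
    for _ in range(steps):
        batch_mean = ToyTensor(1.0)
        new = running * 0.9 + batch_mean * 0.1
        # Without detach, `running` drags the whole history of earlier steps along.
        running = new.detach() if detach else new
    return ToyTensor.live


leaky = run(10, detach=False)  # grows with the number of steps
tidy = run(10, detach=True)    # stays bounded regardless of steps
print(leaky, tidy)
```

The sketch relies on CPython's reference counting to release tensors as soon as nothing refers to them. If a real implementation behaves similarly, storing detached values for running statistics (and avoiding unnecessary intermediate tensors) is one plausible way to keep the live-tensor count down.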