jettify / pytorch-optimizer

torch-optimizer -- collection of optimizers for Pytorch
Apache License 2.0

adahessian bug: it does not support the 3D conv #500

Open ziming-liu opened 1 year ago

ziming-liu commented 1 year ago
            param_size = hv.size()
            if len(param_size) <= 2:  # for 0/1/2D tensor
                # Hessian diagonal block size is 1 here.
                # We use that torch.abs(hv * vi) = hv.abs()
                tmp_output = hv.abs()

            elif len(param_size) == 4:  # Conv kernel
                # Hessian diagonal block size is 9 here: torch.sum() reduces
                # the dim 2/3.
                # We use that torch.abs(hv * vi) = hv.abs()
                tmp_output = torch.mean(hv.abs(), dim=[2, 3], keepdim=True)
            hutchinson_trace.append(tmp_output)

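This results in an error for a 3D conv layer: its weight is a 5D tensor, so neither branch matches and tmp_output is never assigned: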

python3.10/site-packages/torch_optimizer/adahessian.py", line 128, in get_trace
    hutchinson_trace.append(tmp_output)
UnboundLocalError: local variable 'tmp_output' referenced before assignment
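
A possible workaround, as a minimal sketch and not the library's actual fix: average over every dimension after the first two, so the same branch covers 1D, 2D and 3D conv kernels (3D, 4D and 5D weights). The helper name `hutchinson_block_average` is hypothetical and does not exist in torch_optimizer; it only mirrors the branch logic quoted above.

```python
import torch


def hutchinson_block_average(hv: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: spatially averaged |hv| for a tensor of any rank."""
    if hv.dim() <= 2:
        # 0/1/2D tensor: Hessian diagonal block size is 1, keep |hv| as is.
        return hv.abs()
    # Conv kernel of any spatial rank: weights are (out, in, *kernel_dims),
    # so reduce over every dimension from index 2 onward.
    spatial_dims = list(range(2, hv.dim()))
    return torch.mean(hv.abs(), dim=spatial_dims, keepdim=True)


# Shape of a Conv3d weight: (out_channels, in_channels, kD, kH, kW).
hv_conv3d = torch.randn(8, 4, 3, 3, 3)
trace_conv3d = hutchinson_block_average(hv_conv3d)
print(trace_conv3d.shape)
```

For a 4D weight this reduces over dims 2 and 3, the same as the existing Conv kernel branch, so 2D conv behavior should be unchanged under this assumption.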