facebookresearch / higher

higher is a PyTorch library that lets users obtain higher-order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0

Does higher work with Hugging Face Adafactor? #124

Open brando90 opened 2 years ago

brando90 commented 2 years ago

When I try to run https://github.com/brando90/higher/blob/master/examples/maml-omniglot.py I get this error:

args.scheduler=None
--------------------- META-TRAIN ------------------------
Starting training!
Traceback (most recent call last):
  File "/home/miranda9/automl-meta-learning/automl-proj-src/experiments/meta_learning/main_metalearning.py", line 441, in <module>
    main_resume_from_checkpoint(args)
  File "/home/miranda9/automl-meta-learning/automl-proj-src/experiments/meta_learning/main_metalearning.py", line 403, in main_resume_from_checkpoint
    run_training(args)
  File "/home/miranda9/automl-meta-learning/automl-proj-src/experiments/meta_learning/main_metalearning.py", line 413, in run_training
    meta_train_fixed_iterations(args)
  File "/home/miranda9/automl-meta-learning/automl-proj-src/meta_learning/training/meta_training.py", line 233, in meta_train_fixed_iterations
    args.outer_opt.step()
  File "/home/miranda9/miniconda3/envs/metalearning_gpu/lib/python3.9/site-packages/torch/optim/optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "/home/miranda9/miniconda3/envs/metalearning_gpu/lib/python3.9/site-packages/transformers/optimization.py", line 577, in step
    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
  File "/home/miranda9/miniconda3/envs/metalearning_gpu/lib/python3.9/site-packages/transformers/optimization.py", line 508, in _approx_sq_grad
    return torch.mm(r_factor.unsqueeze(-1), c_factor.unsqueeze(0))
RuntimeError: mat1 must be a matrix, got 4-D tensor
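The last frame calls torch.mm, which only accepts 2-D inputs, so it fails for any parameter whose factored statistics have more than one dimension, such as a convolution weight. Below is a minimal sketch of that failure in isolation. It makes assumptions: the (8, 3, 3, 3) shape is hypothetical, the row/column statistics only approximate what Adafactor keeps, and the broadcasting product at the end is one possible workaround, not a confirmed fix in any transformers release.

```python
import torch

# Hypothetical gradient of a conv weight: (out_channels, in_channels, kH, kW).
grad = torch.randn(8, 3, 3, 3)
sq = grad ** 2

# Factored second-moment statistics (assumed to mirror Adafactor's layout):
# one mean over the last dim, one mean over the second-to-last dim.
row = sq.mean(dim=-1)  # shape (8, 3, 3)
col = sq.mean(dim=-2)  # shape (8, 3, 3)

r_factor = (row / row.mean(dim=-1, keepdim=True)).rsqrt()
c_factor = col.rsqrt()

# torch.mm requires both operands to be 2-D, so a 4-D operand is rejected.
try:
    torch.mm(r_factor.unsqueeze(-1), c_factor.unsqueeze(0))
    mm_failed = False
except RuntimeError as err:
    mm_failed = True
    print("torch.mm failed:", err)

# A broadcasting product works for any rank >= 2; for a 2-D weight it
# reduces to the same outer product that torch.mm would compute.
update = r_factor.unsqueeze(-1) * c_factor.unsqueeze(-2)
print(update.shape)
```

If this is the cause, the error would come from the optimizer's handling of 4-D parameters and not from higher itself, but I have not verified that.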

related:

brando90 commented 1 year ago

https://github.com/facebookresearch/higher/issues/139

sfxjh commented 1 year ago

I also encountered a similar problem. Did you solve it?