ltatzel / PyTorchHessianFree

PyTorch implementation of the Hessian-free optimizer
BSD 3-Clause "New" or "Revised" License
30 stars, 6 forks

Use the HF optimizer for Reinforcement Learning (TD3) #5

Closed: f-pfeiffer closed this issue 4 months ago

f-pfeiffer commented 4 months ago

Hi,

I first tried to use this optimizer only for the actor network of a TD3 agent. Here is some simplified code (actor and qf1 are my NNs):

actor_optimizer = HessianFree(params=self.actor.parameters())

actions = actor(observation)

# ---------------------------------------------
# Training logic for the Q-Values etc.
# ---------------------------------------------

actor_optimizer.zero_grad()
def forward_actor():
    loss = -self.qf1(torch.cat([observations, actor(observations)], 1)).mean()
    return loss, outputs
actor_optimizer.step(forward=forward_actor)

I get the following error message:

File ~/anaconda3/envs/ma310/lib/python3.10/site-packages/backpack/hessianfree/hvp.py:45, in hessian_vector_product(f, params, v, grad_params, detach)
     43     df_dx = tuple(grad_params)
     44 else:
---> 45     df_dx = torch.autograd.grad(f, params, create_graph=True, retain_graph=True)
     47 Hv = R_op(df_dx, params, v)
     49 if detach:

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

I guess it's because the loss is the output of a NN. Is there a correct way to define the forward function?

Best, Felix

ltatzel commented 4 months ago

Hi Felix,

Thanks for opening an issue!

I guess it's because the loss is the output of a NN.

In principle, the optimizer should work fine with a loss that depends on the output of a NN.

Some things in your example are unclear to me, e.g., how is outputs defined in forward_actor? Could you provide a minimal stand-alone example so that I can reproduce the error?

Best, Lukas

f-pfeiffer commented 4 months ago

Hi Lukas,

Thank you for your response. Sorry, my code was misleading, and the error may also have a different cause. Here is a minimal stand-alone example that reproduces the error:

import torch
from hessianfree.optimizer import HessianFree

BATCH_SIZE = 1
OBSERVATION_SPACE = 32
ACTION_SPACE = 4

observation = torch.randn(BATCH_SIZE, OBSERVATION_SPACE)

QNet = torch.nn.Sequential(
    torch.nn.Linear(OBSERVATION_SPACE + ACTION_SPACE, 1),
)

Actor = torch.nn.Sequential(
    torch.nn.Linear(OBSERVATION_SPACE, ACTION_SPACE),
)

opt = HessianFree(
    params=Actor.parameters(),
    verbose=True
)

def forward():
    actor_losses = -QNet(torch.cat([observation, Actor(observation)], 1))
    loss = actor_losses.mean()
    outputs = Actor(observation)
    return loss, outputs

loss = opt.step(forward=forward)

Best, Felix

f-pfeiffer commented 4 months ago

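I think recomputing the actions caused the problem: the second call to Actor(observation) creates a new output tensor that is not part of the loss's computation graph. Reusing the same actions tensor for the loss and as outputs avoids that: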

def forward():
    actions = Actor(observation)
    actor_losses = -QNet(torch.cat([observation, actions], 1))
    actor_loss = actor_losses.mean()
    return actor_loss, actions
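The difference between the two forward functions can be checked in plain PyTorch, without the optimizer. The sketch below (the helper outputs_in_graph is hypothetical, written only for illustration) asks autograd for the gradient of the loss with respect to the returned outputs; with allow_unused=True, autograd returns None for a tensor the loss never used, which is exactly the situation behind the "One of the differentiated Tensors appears to not have been used in the graph" error.

```python
import torch

torch.manual_seed(0)

BATCH_SIZE = 1
OBSERVATION_SPACE = 32
ACTION_SPACE = 4

observation = torch.randn(BATCH_SIZE, OBSERVATION_SPACE)
QNet = torch.nn.Sequential(torch.nn.Linear(OBSERVATION_SPACE + ACTION_SPACE, 1))
Actor = torch.nn.Sequential(torch.nn.Linear(OBSERVATION_SPACE, ACTION_SPACE))


def forward_recomputed():
    # Original MWE: `outputs` comes from a second, independent call to Actor.
    loss = -QNet(torch.cat([observation, Actor(observation)], 1)).mean()
    outputs = Actor(observation)
    return loss, outputs


def forward_reused():
    # Fixed version: the same `actions` tensor feeds the loss and is returned.
    actions = Actor(observation)
    loss = -QNet(torch.cat([observation, actions], 1)).mean()
    return loss, actions


def outputs_in_graph(forward):
    """Hypothetical helper: True if `loss` actually depends on `outputs`."""
    loss, outputs = forward()
    # allow_unused=True returns None for an unused tensor instead of raising
    # "One of the differentiated Tensors appears to not have been used in the graph".
    (grad,) = torch.autograd.grad(loss, outputs, allow_unused=True)
    return grad is not None


print(outputs_in_graph(forward_recomputed))  # False
print(outputs_in_graph(forward_reused))      # True
```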

However, now I get another error message:

File ~/PyTorchHessianFree/hessianfree/optimizer.py:258, in HessianFree.step(self, forward, grad, mvp, M_func, test_deterministic)
    255 store_x_at_iters = None if self.use_cg_backtracking else [0]
    257 # Apply cg
--> 258 x_iters, m_iters, cg_reason = cg(
    259     A=lambda x: mvp(x) + damping * x,  # Add damping
    260     b=-grad,
    261     x0=state["x0"],
    262     M=M_func,
    263     max_iter=cg_max_iter,
...
     76     # We need to do a regular size check, without going through
     77     # the operator, to be able to handle unbacked symints
     78     # (expect_true ensures we can deal with unbacked)

AttributeError: 'NoneType' object has no attribute 'is_nested'

ltatzel commented 4 months ago

Hi Felix,

The error originates from the implementation of the GGN/Hessian-vector products in backpack.hessianfree. Re-installing backpack from the ggn-materialize-grad branch, i.e.

pip install backpack-for-pytorch git+https://github.com/f-dangel/backpack.git@ggn-materialize-grad

seems to solve the issue. At least, this MWE works fine:

import torch

from hessianfree.optimizer import HessianFree

BATCH_SIZE = 1
OBSERVATION_SPACE = 32
ACTION_SPACE = 4

observation = torch.randn(BATCH_SIZE, OBSERVATION_SPACE)

QNet = torch.nn.Sequential(
    torch.nn.Linear(OBSERVATION_SPACE + ACTION_SPACE, 1),
)

Actor = torch.nn.Sequential(
    torch.nn.Linear(OBSERVATION_SPACE, ACTION_SPACE),
)

opt = HessianFree(
    params=Actor.parameters(),
    verbose=True
)

def forward():
    actions = Actor(observation)
    actor_losses = -QNet(torch.cat([observation, actions], 1))
    actor_loss = actor_losses.mean()
    return actor_loss, actions

loss = opt.step(forward=forward)
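The AttributeError ('NoneType' object has no attribute 'is_nested') indicates that some tensor in the curvature computation was None. What exactly the ggn-materialize-grad branch changes has not been checked here, so the following is only a plausible illustration, assuming the GGN path needs a Hessian-vector product of the loss with respect to the outputs (the later traceback in this thread shows backpack calling hessian_vector_product(loss, output, Jv)). In this MWE, QNet is a single Linear layer, so the loss is linear in actions; its second derivative with respect to actions is identically zero, and plain autograd reports that as None rather than as a zero tensor.

```python
import torch

torch.manual_seed(0)

OBSERVATION_SPACE, ACTION_SPACE = 32, 4
observation = torch.randn(1, OBSERVATION_SPACE)
QNet = torch.nn.Sequential(torch.nn.Linear(OBSERVATION_SPACE + ACTION_SPACE, 1))
Actor = torch.nn.Sequential(torch.nn.Linear(OBSERVATION_SPACE, ACTION_SPACE))

actions = Actor(observation)
loss = -QNet(torch.cat([observation, actions], 1)).mean()

# First derivative of the loss w.r.t. the network outputs, kept differentiable.
(dloss_dactions,) = torch.autograd.grad(loss, actions, create_graph=True)

# Hessian-vector product w.r.t. the outputs: differentiate dL/da once more along v.
v = torch.randn_like(actions)
(Hv_raw,) = torch.autograd.grad(
    dloss_dactions, actions, grad_outputs=v, allow_unused=True
)
print(Hv_raw)  # None: the loss is linear in `actions`, so dL/da does not depend on them

# One way to keep downstream code working: treat None as an explicit zero tensor.
Hv = torch.zeros_like(actions) if Hv_raw is None else Hv_raw
print(Hv)
```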

f-pfeiffer commented 4 months ago

Hello Lukas,

I really appreciate your help!

I've recently tried applying the optimizer to the critics, but unfortunately I'm running into the same error I encountered earlier. The issue seems to be related to using torch.min, because the function works fine if I return, e.g., qf1_a_values as outputs in forward_qf(). Do you have any suggestions on how to implement the forward function correctly, e.g., computing the minimum of the two tensors while keeping it in the computation graph? I'd be really grateful for any advice you can offer.

import torch
import torch.nn.functional as F
from hessianfree.optimizer import HessianFree

BATCH_SIZE = 4
OBSERVATION_SPACE = 32
ACTION_SPACE = 4

observation = torch.randn(BATCH_SIZE, OBSERVATION_SPACE)
next_q_value = torch.randn(BATCH_SIZE, 1)

QNet1 = torch.nn.Sequential(
    torch.nn.Linear(OBSERVATION_SPACE + ACTION_SPACE, 1),
)

QNet2 = torch.nn.Sequential(
    torch.nn.Linear(OBSERVATION_SPACE + ACTION_SPACE, 1),
)

Actor = torch.nn.Sequential(
    torch.nn.Linear(OBSERVATION_SPACE, ACTION_SPACE),
)

opt_qf = HessianFree(
    params=list(QNet1.parameters()) + list(QNet2.parameters()),
    verbose=False
)

def forward_qf():
    input = torch.cat([observation, Actor(observation)], 1)
    qf1_a_values = QNet1(input)
    qf2_a_values = QNet2(input)
    outputs = torch.min(qf1_a_values, qf2_a_values)

    qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
    qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
    qf_loss = qf1_loss + qf2_loss

    return qf_loss, outputs

loss_qf = opt_qf.step(forward=forward_qf)

Traceback (most recent call last):
  File "/home/felix/Masterthesis/code/second_order_methods_testing/hf_td3_mwe.py", line 74, in main
    loss_qf = opt_qf.step(forward=forward_qf)
  File "/home/felix/anaconda3/envs/ma310/lib/python3.10/site-packages/torch/optim/optimizer.py", line 385, in wrapper
    out = func(*args, **kwargs)
  File "/home/felix/Masterthesis/third_party/PyTorchHessianFree/hessianfree/optimizer.py", line 258, in step
    x_iters, m_iters, cg_reason = cg(
  File "/home/felix/Masterthesis/third_party/PyTorchHessianFree/hessianfree/cg.py", line 188, in cg
    r = A(x0) - b
  File "/home/felix/Masterthesis/third_party/PyTorchHessianFree/hessianfree/optimizer.py", line 259, in <lambda>
    A=lambda x: mvp(x) + damping * x,  # Add damping
  File "/home/felix/Masterthesis/third_party/PyTorchHessianFree/hessianfree/optimizer.py", line 240, in mvp
    return self._Gv(loss, outputs, self._params_list, x)
  File "/home/felix/Masterthesis/third_party/PyTorchHessianFree/hessianfree/optimizer.py", line 454, in _Gv
    Gv = ggn_vector_product_from_plist(loss, outputs, params_list, vec_list)
  File "/home/felix/anaconda3/envs/ma310/lib/python3.10/site-packages/backpack/hessianfree/ggnvp.py", line 55, in ggn_vector_product_from_plist
    HJv = hessian_vector_product(loss, output, Jv)
  File "/home/felix/anaconda3/envs/ma310/lib/python3.10/site-packages/backpack/hessianfree/hvp.py", line 45, in hessian_vector_product
    df_dx = torch.autograd.grad(f, params, create_graph=True, retain_graph=True)
  File "/home/felix/anaconda3/envs/ma310/lib/python3.10/site-packages/torch/autograd/__init__.py", line 411, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

ltatzel commented 4 months ago

Hi Felix,

Sure, no problem!

Maybe you can use the Hessian instead of the GGN (I am skeptical anyway about whether the GGN makes sense in your case, as it requires a very specific structure of the loss function):

opt_qf = HessianFree(
    params=list(QNet1.parameters()) + list(QNet2.parameters()),
    verbose=False,
    curvature_opt="hessian",
)

This seems to work fine.
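Why the Hessian option is unaffected can be sketched in plain PyTorch, assuming curvature_opt="hessian" boils down to a Hessian-vector product with respect to the parameters, like backpack's hessian_vector_product(f, params, v) in the first traceback. Such a product only needs the loss and the parameters (Pearlmutter's double-backward trick), so the outputs tensor, here torch.min(q1, q2), plays no role at all.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

BATCH_SIZE, OBSERVATION_SPACE, ACTION_SPACE = 4, 32, 4
observation = torch.randn(BATCH_SIZE, OBSERVATION_SPACE)
next_q_value = torch.randn(BATCH_SIZE, 1)
QNet1 = torch.nn.Sequential(torch.nn.Linear(OBSERVATION_SPACE + ACTION_SPACE, 1))
QNet2 = torch.nn.Sequential(torch.nn.Linear(OBSERVATION_SPACE + ACTION_SPACE, 1))
Actor = torch.nn.Sequential(torch.nn.Linear(OBSERVATION_SPACE, ACTION_SPACE))
params = list(QNet1.parameters()) + list(QNet2.parameters())

inp = torch.cat([observation, Actor(observation)], 1)
q1, q2 = QNet1(inp), QNet2(inp)
outputs = torch.min(q1, q2)  # computed but unused by the loss, as in forward_qf
qf_loss = F.mse_loss(q1, next_q_value) + F.mse_loss(q2, next_q_value)

# Hessian-vector product w.r.t. the parameters: differentiate the gradient
# (built with create_graph=True) once more along the direction `vec`.
grads = torch.autograd.grad(qf_loss, params, create_graph=True)
vec = [torch.randn_like(p) for p in params]
Hv = torch.autograd.grad(grads, params, grad_outputs=vec)

for p, hv in zip(params, Hv):
    print(tuple(p.shape), tuple(hv.shape))
```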

The problem with the GGN is that, internally, we compute HJv = hessian_vector_product(qf_loss, outputs, ...), i.e., we differentiate qf_loss with respect to outputs. Since outputs is not used to compute qf_loss, this results in the error above.
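The GGN-vector product has the form G v = Jᵀ H J v, where J is the Jacobian of the outputs with respect to the parameters and H is the Hessian of the loss with respect to the outputs, so the loss must be computed from the tensor that forward returns. A minimal sketch of one way to satisfy that, assuming both critics' predictions may be treated as a single output (forward_qf_ggn is a hypothetical name, not part of the package): stack them into one tensor and compute the loss from slices of it. The loss value is unchanged, and H is now well-defined; for the mean-reduced MSE it is simply (2 / BATCH_SIZE) times the identity.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

BATCH_SIZE, OBSERVATION_SPACE, ACTION_SPACE = 4, 32, 4
observation = torch.randn(BATCH_SIZE, OBSERVATION_SPACE)
next_q_value = torch.randn(BATCH_SIZE, 1)
QNet1 = torch.nn.Sequential(torch.nn.Linear(OBSERVATION_SPACE + ACTION_SPACE, 1))
QNet2 = torch.nn.Sequential(torch.nn.Linear(OBSERVATION_SPACE + ACTION_SPACE, 1))
Actor = torch.nn.Sequential(torch.nn.Linear(OBSERVATION_SPACE, ACTION_SPACE))


def forward_qf_ggn():
    """Hypothetical GGN-compatible variant of forward_qf: the loss uses `outputs`."""
    inp = torch.cat([observation, Actor(observation)], 1)
    # Both critics' predictions in one tensor of shape (BATCH_SIZE, 2).
    outputs = torch.cat([QNet1(inp), QNet2(inp)], 1)
    # Same value as mse(q1, y) + mse(q2, y), but computed *from* outputs.
    qf_loss = F.mse_loss(outputs[:, :1], next_q_value) + F.mse_loss(
        outputs[:, 1:], next_q_value
    )
    return qf_loss, outputs


loss, outputs = forward_qf_ggn()
# Hessian of the loss w.r.t. `outputs` applied to a vector (the H in J^T H J).
(g,) = torch.autograd.grad(loss, outputs, create_graph=True)
v = torch.randn_like(outputs)
(Hv,) = torch.autograd.grad(g, outputs, grad_outputs=v)
print(torch.allclose(Hv, 2 / BATCH_SIZE * v))  # True
```

Whether the GGN is a sensible curvature model for a TD3 critic is a separate question; the sketch only shows how to make the loss depend on outputs.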

f-pfeiffer commented 4 months ago

Hi Lukas,

Thank you very much. This really helped me a lot.

Best, Felix

ltatzel commented 4 months ago

You're very welcome 😉