hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

Zero_grad() #997

Closed amirhosseinzlf closed 4 years ago

amirhosseinzlf commented 4 years ago

I'm trying to implement an FGSM attack on TRPO, but I don't know how to use zero_grad() with the TRPO model. Here is part of the code:

loss2 = F.nll_loss(model.action_probability(state), model.predict(state))
model.action_probability.zero_grad()
# Calculate gradients of model in backward pass
loss2.backward()
# Collect the gradient w.r.t. the input state
data_grad = state_batch.grad.data

The problem: action_probability is obviously a function, so how can I access the policy network? I need the network's output and its gradient.

Running the example raises:

'function' object has no attribute 'zero_grad'

System Info: Colab

Miffyli commented 4 years ago

zero_grad() is a PyTorch method. This repository's code is based on TensorFlow, so PyTorch calls like zero_grad() will not work here. You might want to check out stable-baselines3, which is the PyTorch implementation.
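For reference, in PyTorch the FGSM pattern looks roughly like the sketch below. This is a minimal, hedged example: `policy_net` is a toy stand-in for a real policy network (in stable-baselines3 the actual network would be reached through the model's policy object, not through `action_probability`, which is a method), and the loss construction is only illustrative. The key points are that `zero_grad()` lives on the `nn.Module`, not on a bound method, and that the input tensor must have `requires_grad=True` for `state.grad` to be populated.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-in for a policy network (an assumption, not the SB3 API):
# 4-dim observation -> logits over 2 discrete actions.
policy_net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))

def fgsm_attack(state, epsilon=0.1):
    """Return `state` perturbed by epsilon * sign(grad of loss w.r.t. state)."""
    # The input must track gradients so state.grad gets filled in backward().
    state = state.clone().detach().requires_grad_(True)
    logits = policy_net(state)
    # Use the policy's own greedy action as the label to attack (illustrative).
    target = logits.argmax(dim=-1).detach()
    loss = F.cross_entropy(logits, target)
    policy_net.zero_grad()  # zero_grad() is called on the nn.Module itself
    loss.backward()         # gradients flow back to the input tensor
    return (state + epsilon * state.grad.sign()).detach()

adv_state = fgsm_attack(torch.randn(1, 4))
```

The same structure applies with a stable-baselines3 model once you extract the underlying `nn.Module`; only the loss and the way observations are fed in would change.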