MorvanZhou / pytorch-A3C

Simple A3C implementation with pytorch + multiprocessing
https://mofanpy.com
MIT License

Are local gradients accumulated and never reset? #21

Open jakkarn opened 3 years ago

jakkarn commented 3 years ago

I can't see where the local gradients are ever reset. The local parameter values are overwritten by the global weights, but the optimizer opt is built over the global parameters, so won't gradients keep accumulating in the local network?
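The concern rests on how PyTorch treats .grad: each backward() call adds into a parameter's existing .grad instead of overwriting it, so gradients pile up unless something zeroes them. A minimal sketch with one toy parameter (w and the factor 3.0 are illustrative, not taken from the repo):

```python
import torch

# Toy parameter standing in for one weight tensor of lnet (illustrative only)
w = torch.nn.Parameter(torch.tensor([1.0, 2.0]))

for _ in range(3):
    loss = (w * 3.0).sum()  # d(loss)/dw = 3 for every element
    loss.backward()         # adds 3 to w.grad on each call; nothing resets it

print(w.grad)  # tensor([9., 9.])
```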

https://github.com/MorvanZhou/pytorch-A3C/blob/5ab27abee2c3ac3ca921ac393bfcbda4e0a91745/utils.py#L41

jakkarn commented 3 years ago

I just found out that the state_dict contains the gradients. So they should at least be partly reset when the global state_dict (with new gradients) is loaded into the local nn.

From the pytorch documentation: "torch.nn.Module.load_state_dict: Loads a model’s parameter dictionary using a deserialized state_dict.".

To me, that sounds like it loads a copy of the global parameters, meaning that the gradients will be added to the previous global gradients.

Bear-kai commented 2 years ago

I have a similar question about the gradient.

Actually, after lnet.load_state_dict(gnet.state_dict()) is executed, all the parameters in both lnet and gnet are shared. That is to say, opt.zero_grad() will set the gradients in both lnet and gnet to zero, and loss.backward() will give lnet and gnet the same gradients! So after the 1st iteration, gp._grad = lp.grad is useless because they are already the same! I found another implementation here that uses an if-return check (I guess it corresponds to my claim that the grad assignment is useless after the 1st iteration).

# copy from continuous A3C, consider the cases after the 1st iteration
opt.zero_grad()         # zero gradient in both lnet and gnet
loss.backward()         # parameters in both lnet and gnet have the same gradients
for lp, gp in zip(lnet.parameters(), gnet.parameters()):     # the for loop is useless
    # if gp.grad is not None:
    #     return                           # this "if-return" check is copied from the link above
    gp._grad = lp.grad   
opt.step()                                 # update gnet parameters (parameters in lnet will not change!)
lnet.load_state_dict(gnet.state_dict())    # update lnet parameters
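The "same gradients" observation can be reproduced in a single process. A sketch, assuming two nn.Linear(2, 1) layers as stand-ins for gnet and lnet and plain SGD in place of the repo's shared optimizer: after the first gp._grad = lp.grad, each global parameter's .grad is the very same tensor object as the local one. The parameters themselves stay separate; only the gradient tensor is aliased, so an in-place opt.zero_grad() also clears the local gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gnet = nn.Linear(2, 1)  # stand-in for the global network (single process here)
lnet = nn.Linear(2, 1)  # stand-in for a worker's local network
lnet.load_state_dict(gnet.state_dict())
opt = torch.optim.SGD(gnet.parameters(), lr=0.1)  # optimizer only knows gnet

x = torch.ones(1, 2)
for _ in range(2):
    opt.zero_grad(set_to_none=False)  # zero gnet's grads in place
    lnet(x).sum().backward()          # write this step's grads into lnet
    for lp, gp in zip(lnet.parameters(), gnet.parameters()):
        gp._grad = lp.grad            # gnet now refers to lnet's grad tensor
    opt.step()                        # update gnet's weights
    lnet.load_state_dict(gnet.state_dict())  # copy weights back into lnet

pairs = list(zip(lnet.parameters(), gnet.parameters()))
print(all(gp.grad is lp.grad for lp, gp in pairs))  # True: same grad tensors
print(any(gp is lp for lp, gp in pairs))            # False: separate parameters
print(lnet.weight.grad)  # tensor([[1., 1.]]): reset each step, not accumulated
```

set_to_none=False is passed explicitly so that zero_grad() zeroes the existing tensors in place; the default for that argument differs across PyTorch versions.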

This is confusing to me, and it might be a (serious) bug. What if worker A is updating gnet with opt.step() while worker B clears or modifies the gradients with opt.zero_grad()/loss.backward()? Still, the code just works (look at the episode reward curve and the visualization)!

BTW, the state_dict does not contain any gradient info! It is an OrderedDict of the parameter values (weights and biases).
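This is easy to check. A sketch, with an nn.Linear(2, 1) layer as an illustrative stand-in for gnet: state_dict() returns detached tensors keyed by parameter name, and none of them carries a .grad, even right after a backward pass.

```python
import torch
import torch.nn as nn

net = nn.Linear(2, 1)  # illustrative stand-in for gnet
net(torch.ones(1, 2)).sum().backward()
print(net.weight.grad)  # tensor([[1., 1.]]): the parameters do have gradients

sd = net.state_dict()
print(list(sd.keys()))                         # ['weight', 'bias']
print([t.grad for t in sd.values()])           # [None, None]: no gradients stored
print([t.requires_grad for t in sd.values()])  # [False, False]: detached values
```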

MorvanZhou commented 2 years ago

The lnet.load_state_dict() function looks like this:

def load_state_dict(self, state_dict):
    # deepcopy, to be consistent with module API
    state_dict = deepcopy(state_dict)
    # Validate the state_dict
    groups = self.param_groups
    saved_groups = state_dict['param_groups']

It uses deepcopy to isolate the parameters from gnet, so no memory is shared here.
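One hedge on the excerpt above: it references param_groups, so it appears to be torch.optim.Optimizer.load_state_dict rather than nn.Module.load_state_dict. The no-sharing conclusion still holds for the module version, which copies the incoming values into the existing parameters in place. A sketch with illustrative nn.Linear stand-ins: after loading, lnet's weights equal gnet's but live in separate storage, lnet keeps its own Parameter objects, and a later change to gnet does not reach lnet.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gnet = nn.Linear(2, 1)  # illustrative stand-in for the global network
lnet = nn.Linear(2, 1)  # illustrative stand-in for a local network
w_obj = lnet.weight     # keep a handle on lnet's Parameter object

lnet.load_state_dict(gnet.state_dict())
print(torch.equal(lnet.weight, gnet.weight))             # True: values copied
print(lnet.weight is w_obj)                              # True: same Parameter, updated in place
print(lnet.weight.data_ptr() == gnet.weight.data_ptr())  # False: no shared memory

with torch.no_grad():
    gnet.weight.add_(1.0)                                # change gnet afterwards
print(torch.equal(lnet.weight, gnet.weight))             # False: lnet unaffected
```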

So after the 1st iteration, gp._grad = lp.grad is useless because they are already the same!

Once the update has moved on to another worker, gp._grad = lp.grad is necessary to switch to that worker's gradients.

Bear-kai commented 2 years ago

Thanks for your reply! @MorvanZhou

  1. Yes, load_state_dict() does not make the parameters shared. I noticed that earlier and had already struck out that sentence.
  2. It would be very kind of you to explain whether there might be conflicts between workers when the shared model is not locked.

    What if worker A is updating gnet with opt.step() while worker B clears or modifies the gradients with opt.zero_grad()/loss.backward()?

  3. Note that I wrote the following comments while debugging step by step.
    # copy from continuous A3C, consider the cases after the 1st iteration
    opt.zero_grad()         # zero gradient in both lnet and gnet
    loss.backward()         # parameters in both lnet and gnet have the same gradients
    for lp, gp in zip(lnet.parameters(), gnet.parameters()):     # the for loop is useless after the 1st iteration ??
        # if gp.grad is not None:
        #     return                           # this "if-return" check is copied from the link above
        gp._grad = lp.grad
    opt.step()                                 # update gnet parameters (parameters in lnet will not change!)
    lnet.load_state_dict(gnet.state_dict())    # update lnet parameters

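Whether the loop is redundant after the first iteration depends on how opt.zero_grad() clears gradients. A sketch under single-process assumptions (nn.Linear stand-ins, plain SGD, and a hypothetical helper named run): with set_to_none=False the aliased gradient tensor is zeroed in place, so the local gradients are reset every step; with set_to_none=True, zero_grad() only drops gnet's reference and lnet's .grad keeps accumulating across iterations. To our understanding, PyTorch 2.0 and later default to set_to_none=True, so the unmodified loop may behave differently depending on the installed version.

```python
import torch
import torch.nn as nn


def run(set_to_none, iters=3):
    """Hypothetical helper: repeat the push/pull loop discussed in this thread."""
    torch.manual_seed(0)
    gnet = nn.Linear(2, 1)  # stand-in for the global network
    lnet = nn.Linear(2, 1)  # stand-in for a local network
    lnet.load_state_dict(gnet.state_dict())
    opt = torch.optim.SGD(gnet.parameters(), lr=0.1)
    x = torch.ones(1, 2)
    for _ in range(iters):
        opt.zero_grad(set_to_none=set_to_none)
        lnet(x).sum().backward()  # d(loss)/d(weight) is [[1, 1]] each step
        for lp, gp in zip(lnet.parameters(), gnet.parameters()):
            gp._grad = lp.grad
        opt.step()
        lnet.load_state_dict(gnet.state_dict())
    return lnet.weight.grad


print(run(set_to_none=False))  # tensor([[1., 1.]]): reset every iteration
print(run(set_to_none=True))   # tensor([[3., 3.]]): accumulated over 3 iterations
```

One option that avoids depending on that default is to call lnet.zero_grad() right before loss.backward(), which makes the local reset explicit either way.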
MorvanZhou commented 2 years ago

It would be very kind of you to explain whether there might be conflicts between workers when the shared model is not locked.

A lock could be applied in this case, but take a look at HOGWILD for an analysis of running these updates without locking.
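If one does want the lock, a minimal sketch might look like the following. push_and_pull_locked is a hypothetical helper, not the repo's code; in a real multi-process run the lock would be assumed to be created once in the parent process (for example with torch.multiprocessing.Lock()) and passed to every worker. The helper also zeroes lnet's gradients explicitly before backward(), so the result does not depend on how opt.zero_grad() treats aliased gradient tensors. The demo runs in one process, so a threading.Lock stands in for the multiprocessing lock.

```python
import threading

import torch
import torch.nn as nn


def push_and_pull_locked(opt, lnet, gnet, loss, lock):
    """Hypothetical push/pull step that serializes access to the shared gnet."""
    lnet.zero_grad()  # start from fresh local gradients
    loss.backward()   # local work, no lock needed: lnet is private to this worker
    with lock:        # one worker at a time touches the shared weights
        for lp, gp in zip(lnet.parameters(), gnet.parameters()):
            gp.grad = lp.grad                    # hand local grads to the global params
        opt.step()                               # update the shared weights
        lnet.load_state_dict(gnet.state_dict())  # pull the updated weights


# Single-process usage with illustrative stand-ins
torch.manual_seed(0)
gnet = nn.Linear(2, 1)
lnet = nn.Linear(2, 1)
lnet.load_state_dict(gnet.state_dict())
opt = torch.optim.SGD(gnet.parameters(), lr=0.1)
lock = threading.Lock()  # stand-in for a multiprocessing lock in this demo

w_before = gnet.weight.detach().clone()
loss = lnet(torch.ones(1, 2)).sum()
push_and_pull_locked(opt, lnet, gnet, loss, lock)
print(gnet.weight.detach() - w_before)  # each weight moved by about -0.1 (lr * grad of 1)
```

The trade-off is throughput: every worker waits for the lock during its update, which is exactly the cost the HOGWILD analysis argues can often be skipped.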