dbaranchuk / memory-efficient-maml

Memory efficient MAML using gradient checkpointing
MIT License

what is meant by implicit parameter updates? #2

raymond00000 opened this issue 3 years ago

raymond00000 commented 3 years ago

Hi,

I am interested in trying your solution, but I am not familiar with the concept of "implicit parameter updates". Could you explain what this means and how to handle it? I am working on a model with an LSTM layer, and I wonder whether LSTM has implicit parameter updates. Thanks for any advice.

dbaranchuk commented 3 years ago

Hi,

By "implicit parameter updates" I meant layer parameters (buffers) that are not trained with backpropagation but either updated during a forward pass or somehow else. E.g. Batch Normalization layers update average batch statistics at each forward pass. Gradient checkpointing performs two forward passes and hence updates batchnorm statistics twice per step that might be undesirable behavior. Therefore, one should be aware of such side effects and handle them. As I know, LSTM doesn't have any of such moments by default.

raymond00000 commented 3 years ago

Thanks a lot for the explanation and reply.

I ran the example notebook and it did work! But when I switched to my model, there were some errors.

I froze most of my model's layers with param.requires_grad = False. My model contained these layers that still required grad = True:

   decoder.gate_layer.linear_layer.weight torch.Size([1, 1536])
   decoder.gate_layer.linear_layer.bias torch.Size([1])

When I ran this line:

updated_model, loss_history, _ = efficient_maml(batch, loss_kwargs={'device':device},
                                            max_grad_grad_norm=max_grad_grad_norm)

I received the error below. I wonder why? Many thanks in advance for any feedback.

      RuntimeError                              Traceback (most recent call last)
      <ipython-input-36-316ef155b6cd> in <module>
            1 updated_model, loss_history, _ = efficient_maml(batch, loss_kwargs={'device':device},
      ----> 2                                                 max_grad_grad_norm=max_grad_grad_norm)

      ~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
          720             result = self._slow_forward(*input, **kwargs)
          721         else:
      --> 722             result = self.forward(*input, **kwargs)
          723         for hook in itertools.chain(
          724                 _global_forward_hooks.values(),

      ~/anaconda3/lib/python3.7/site-packages/torch_maml/maml.py in forward(self, inputs, opt_kwargs, loss_kwargs, optimizer_state, **kwargs)
          134         for chunk_start in range(0, len(inputs), self.checkpoint_steps):
          135             steps = min(self.checkpoint_steps, len(inputs) - chunk_start)
      --> 136             inner_losses, *flat_maml_state = checkpoint(_maml_internal, torch.as_tensor(steps), *flat_maml_state)
          137             loss_history.extend(inner_losses.split(1))
          138 

      ~/anaconda3/lib/python3.7/site-packages/torch/utils/checkpoint.py in checkpoint(function, *args, **kwargs)
          161         raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
          162 
      --> 163     return CheckpointFunction.apply(function, preserve, *args)
          164 
          165 

      ~/anaconda3/lib/python3.7/site-packages/torch/utils/checkpoint.py in forward(ctx, run_function, preserve_rng_state, *args)
           72         ctx.save_for_backward(*args)
           73         with torch.no_grad():
      ---> 74             outputs = run_function(*args)
           75         return outputs
           76 

      ~/anaconda3/lib/python3.7/site-packages/torch_maml/maml.py in _maml_internal(steps, *flat_maml_state)
          108                 nested_pack(flat_maml_state, structure=initial_maml_state)
          109             updated_model = copy_and_replace(
      --> 110                 self.model, dict(zip(parameters_to_copy, trainable_parameters)), parameters_not_to_copy)
          111 
          112             is_first_pass = not torch.is_grad_enabled()

      ~/anaconda3/lib/python3.7/site-packages/torch_maml/utils.py in copy_and_replace(original, replace, do_not_copy)
           32         memo[id(item)] = replacement
           33 
      ---> 34     return deepcopy(original, memo)
           35 
           36 

      ~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
          178                     y = x
          179                 else:
      --> 180                     y = _reconstruct(x, memo, *rv)
          181 
          182     # If is its own copy, don't memoize.

      ~/anaconda3/lib/python3.7/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
          278     if state is not None:
          279         if deep:
      --> 280             state = deepcopy(state, memo)
          281         if hasattr(y, '__setstate__'):
          282             y.__setstate__(state)

      ~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
          148     copier = _deepcopy_dispatch.get(cls)
          149     if copier:
      --> 150         y = copier(x, memo)
          151     else:
          152         try:

      ~/anaconda3/lib/python3.7/copy.py in _deepcopy_dict(x, memo, deepcopy)
          238     memo[id(x)] = y
          239     for key, value in x.items():
      --> 240         y[deepcopy(key, memo)] = deepcopy(value, memo)
          241     return y
          242 d[dict] = _deepcopy_dict

      ~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
          178                     y = x
          179                 else:
      --> 180                     y = _reconstruct(x, memo, *rv)
          181 
          182     # If is its own copy, don't memoize.

      ~/anaconda3/lib/python3.7/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
          304             for key, value in dictiter:
          305                 key = deepcopy(key, memo)
      --> 306                 value = deepcopy(value, memo)
          307                 y[key] = value
          308         else:

      ~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
          178                     y = x
          179                 else:
      --> 180                     y = _reconstruct(x, memo, *rv)
          181 
          182     # If is its own copy, don't memoize.

      ~/anaconda3/lib/python3.7/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
          278     if state is not None:
          279         if deep:
      --> 280             state = deepcopy(state, memo)
          281         if hasattr(y, '__setstate__'):
          282             y.__setstate__(state)

      ~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
          148     copier = _deepcopy_dispatch.get(cls)
          149     if copier:
      --> 150         y = copier(x, memo)
          151     else:
          152         try:

      ~/anaconda3/lib/python3.7/copy.py in _deepcopy_dict(x, memo, deepcopy)
          238     memo[id(x)] = y
          239     for key, value in x.items():
      --> 240         y[deepcopy(key, memo)] = deepcopy(value, memo)
          241     return y
          242 d[dict] = _deepcopy_dict

      ~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
          159             copier = getattr(x, "__deepcopy__", None)
          160             if copier:
      --> 161                 y = copier(memo)
          162             else:
          163                 reductor = dispatch_table.get(cls)

      ~/anaconda3/lib/python3.7/site-packages/torch/tensor.py in __deepcopy__(self, memo)
           36     def __deepcopy__(self, memo):
           37         if not self.is_leaf:
      ---> 38             raise RuntimeError("Only Tensors created explicitly by the user "
           39                                "(graph leaves) support the deepcopy protocol at the moment")
           40         if id(self) in memo:

      RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
dbaranchuk commented 3 years ago

Sorry, I think solving this requires more involvement. However, my guess is that you should specify a proper callable get_parameters argument for maml that accounts for the frozen parameters. By default, it uses nn.Module.parameters, which returns all model parameters including the frozen ones, and hence maml fails when it operates on them.
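
The idea is roughly the following sketch: a callable that yields only the non-frozen parameters, passed via the get_parameters argument (the surrounding constructor call is omitted here, so check the README for the exact signature):

    # Yield only parameters that were not frozen with requires_grad = False.
    def get_trainable_parameters(module):
        return [param for param in module.parameters() if param.requires_grad]

    # ... then pass get_parameters=get_trainable_parameters when constructing
    # the MAML object (other constructor arguments omitted).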

Also, to the best of my knowledge, PyTorch MAML (not only this implementation) doesn't work with nn.LSTM, because it relies on cuDNN operations that don't support a second backward. This is unlikely to be the cause of the problem above, but you may face it at later steps if you use nn.LSTM in your model. FYI, when I applied MAML to language models with LSTM layers, I had to use nn.LSTMCell in a loop.
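
For example, roughly like this (a minimal sketch, not code from this repo; sizes are made up):

    import torch
    import torch.nn as nn

    # Replace the fused cuDNN nn.LSTM with an explicit nn.LSTMCell loop built
    # from ordinary differentiable ops, which does support a second backward.
    cell = nn.LSTMCell(input_size=16, hidden_size=32)

    def run_lstm(inputs):
        # inputs: (seq_len, batch, input_size) -> (seq_len, batch, hidden_size)
        batch_size = inputs.shape[1]
        h = inputs.new_zeros(batch_size, 32)
        c = inputs.new_zeros(batch_size, 32)
        outputs = []
        for t in range(inputs.shape[0]):
            h, c = cell(inputs[t], (h, c))
            outputs.append(h)
        return torch.stack(outputs)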

raymond00000 commented 3 years ago

Many thanks for the feedback.

Actually, I am still learning MAML, and I am confused about the "second backward".

According to this post: https://discuss.pytorch.org/t/resolved-implementing-maml-in-pytorch/4053

  (1) inner loop: a forward pass on the training example, take gradients with respect to the parameters (2) meta loop: do a forward pass with the updated parameters on a validation example, 

So: I pass the training set to model0, get loss1, apply loss1 so it becomes model1, then pass the validation set to model1 and obtain loss2.

  then take another gradient wrt the original parameters and backprop through the first gradient (thus the second derivative).

I thought we just use loss2 to update model0. But I am confused by this statement: what are the exact steps in PyTorch to perform this operation?
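
My rough understanding so far, as a plain-PyTorch sketch (a toy linear model with made-up shapes and learning rates, so probably not exactly what this repo does):

    import torch
    import torch.nn.functional as F

    w = torch.randn(1, 10, requires_grad=True)               # model0 parameters
    x_train, y_train = torch.randn(4, 10), torch.randn(4, 1)
    x_val, y_val = torch.randn(4, 10), torch.randn(4, 1)

    # Inner loop: loss1 on the training data, gradient w.r.t. model0.
    loss1 = F.mse_loss(x_train @ w.t(), y_train)
    (grad_w,) = torch.autograd.grad(loss1, w, create_graph=True)  # keep the graph

    # "Apply loss1": model1 is a function of model0, not a detached copy.
    w_updated = w - 0.01 * grad_w

    # Meta loop: loss2 on the validation data with the updated parameters.
    loss2 = F.mse_loss(x_val @ w_updated.t(), y_val)

    # Backprop loss2 all the way to model0; because of create_graph=True above,
    # this differentiates through grad_w, i.e. takes the second derivative.
    loss2.backward()
    with torch.no_grad():
        w -= 0.001 * w.grad                                   # outer-loop SGD step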

Many thanks if you could provide some hints on this.