raymond00000 opened this issue 3 years ago
Hi,
By "implicit parameter updates" I meant layer parameters (buffers) that are not trained with backpropagation but either updated during a forward pass or somehow else. E.g. Batch Normalization layers update average batch statistics at each forward pass. Gradient checkpointing performs two forward passes and hence updates batchnorm statistics twice per step that might be undesirable behavior. Therefore, one should be aware of such side effects and handle them. As I know, LSTM doesn't have any of such moments by default.
Thanks a lot for the explanation and reply.
I ran the example notebook and it worked! But when I switched to my model, there were some errors.
I froze most of my model's layers with param.requires_grad = False. Only these layers still had requires_grad = True:
decoder.gate_layer.linear_layer.weight torch.Size([1, 1536])
decoder.gate_layer.linear_layer.bias torch.Size([1])
When I ran this line:
updated_model, loss_history, _ = efficient_maml(batch, loss_kwargs={'device':device},
max_grad_grad_norm=max_grad_grad_norm)
I received the error below. I wonder why? Many thanks in advance for any feedback.
RuntimeError Traceback (most recent call last)
<ipython-input-36-316ef155b6cd> in <module>
1 updated_model, loss_history, _ = efficient_maml(batch, loss_kwargs={'device':device},
----> 2 max_grad_grad_norm=max_grad_grad_norm)
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
~/anaconda3/lib/python3.7/site-packages/torch_maml/maml.py in forward(self, inputs, opt_kwargs, loss_kwargs, optimizer_state, **kwargs)
134 for chunk_start in range(0, len(inputs), self.checkpoint_steps):
135 steps = min(self.checkpoint_steps, len(inputs) - chunk_start)
--> 136 inner_losses, *flat_maml_state = checkpoint(_maml_internal, torch.as_tensor(steps), *flat_maml_state)
137 loss_history.extend(inner_losses.split(1))
138
~/anaconda3/lib/python3.7/site-packages/torch/utils/checkpoint.py in checkpoint(function, *args, **kwargs)
161 raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
162
--> 163 return CheckpointFunction.apply(function, preserve, *args)
164
165
~/anaconda3/lib/python3.7/site-packages/torch/utils/checkpoint.py in forward(ctx, run_function, preserve_rng_state, *args)
72 ctx.save_for_backward(*args)
73 with torch.no_grad():
---> 74 outputs = run_function(*args)
75 return outputs
76
~/anaconda3/lib/python3.7/site-packages/torch_maml/maml.py in _maml_internal(steps, *flat_maml_state)
108 nested_pack(flat_maml_state, structure=initial_maml_state)
109 updated_model = copy_and_replace(
--> 110 self.model, dict(zip(parameters_to_copy, trainable_parameters)), parameters_not_to_copy)
111
112 is_first_pass = not torch.is_grad_enabled()
~/anaconda3/lib/python3.7/site-packages/torch_maml/utils.py in copy_and_replace(original, replace, do_not_copy)
32 memo[id(item)] = replacement
33
---> 34 return deepcopy(original, memo)
35
36
~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
178 y = x
179 else:
--> 180 y = _reconstruct(x, memo, *rv)
181
182 # If is its own copy, don't memoize.
~/anaconda3/lib/python3.7/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
278 if state is not None:
279 if deep:
--> 280 state = deepcopy(state, memo)
281 if hasattr(y, '__setstate__'):
282 y.__setstate__(state)
~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
148 copier = _deepcopy_dispatch.get(cls)
149 if copier:
--> 150 y = copier(x, memo)
151 else:
152 try:
~/anaconda3/lib/python3.7/copy.py in _deepcopy_dict(x, memo, deepcopy)
238 memo[id(x)] = y
239 for key, value in x.items():
--> 240 y[deepcopy(key, memo)] = deepcopy(value, memo)
241 return y
242 d[dict] = _deepcopy_dict
~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
178 y = x
179 else:
--> 180 y = _reconstruct(x, memo, *rv)
181
182 # If is its own copy, don't memoize.
~/anaconda3/lib/python3.7/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
304 for key, value in dictiter:
305 key = deepcopy(key, memo)
--> 306 value = deepcopy(value, memo)
307 y[key] = value
308 else:
~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
178 y = x
179 else:
--> 180 y = _reconstruct(x, memo, *rv)
181
182 # If is its own copy, don't memoize.
~/anaconda3/lib/python3.7/copy.py in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
278 if state is not None:
279 if deep:
--> 280 state = deepcopy(state, memo)
281 if hasattr(y, '__setstate__'):
282 y.__setstate__(state)
~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
148 copier = _deepcopy_dispatch.get(cls)
149 if copier:
--> 150 y = copier(x, memo)
151 else:
152 try:
~/anaconda3/lib/python3.7/copy.py in _deepcopy_dict(x, memo, deepcopy)
238 memo[id(x)] = y
239 for key, value in x.items():
--> 240 y[deepcopy(key, memo)] = deepcopy(value, memo)
241 return y
242 d[dict] = _deepcopy_dict
~/anaconda3/lib/python3.7/copy.py in deepcopy(x, memo, _nil)
159 copier = getattr(x, "__deepcopy__", None)
160 if copier:
--> 161 y = copier(memo)
162 else:
163 reductor = dispatch_table.get(cls)
~/anaconda3/lib/python3.7/site-packages/torch/tensor.py in __deepcopy__(self, memo)
36 def __deepcopy__(self, memo):
37 if not self.is_leaf:
---> 38 raise RuntimeError("Only Tensors created explicitly by the user "
39 "(graph leaves) support the deepcopy protocol at the moment")
40 if id(self) in memo:
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
Sorry, I think solving this problem requires more involvement. However, my guess is that you should specify a proper callable for the get_parameters argument of MAML that accounts for the frozen parameters. By default, it uses nn.Module.parameters, which returns all model parameters including the frozen ones, and hence MAML fails when it operates on them.
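For example, something along these lines (a hypothetical sketch; the constructor name and keyword arguments here are assumptions, so check them against the torch_maml version you have installed):

import torch_maml  # model, compute_loss, optimizer defined as in the example notebook

def get_trainable_parameters(module):
    # hypothetical helper: return only the parameters that are not frozen
    return [p for p in module.parameters() if p.requires_grad]

maml = torch_maml.GradientCheckpointMAML(
    model, compute_loss, optimizers=optimizer,  # assumed arguments
    checkpoint_steps=5,
    get_parameters=get_trainable_parameters)    # overrides nn.Module.parameters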
Also, to the best of my knowledge, PyTorch MAML (not only this implementation) doesn't work with nn.LSTM because it utilizes cudnn operations that don't support a second backward. This is unlikely to be the cause of the problem above, but you may face it at later steps if you use nn.LSTM in your model. FYI, when I applied MAML to language models with LSTM layers, I had to use nn.LSTMCell in a loop.
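For reference, a minimal sketch of that LSTMCell-in-a-loop pattern (the shapes here are illustrative):

import torch
import torch.nn as nn

# unrolling nn.LSTMCell by hand supports a second backward,
# unlike the fused cudnn kernels behind nn.LSTM
cell = nn.LSTMCell(input_size=32, hidden_size=64)
x = torch.randn(10, 8, 32)       # (seq_len, batch, input_size)
h = torch.zeros(8, 64)
c = torch.zeros(8, 64)
outputs = []
for t in range(x.size(0)):
    h, c = cell(x[t], (h, c))
    outputs.append(h)
outputs = torch.stack(outputs)   # (seq_len, batch, hidden_size)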
Many thanks for the feedback.
Actually, I am learning MAML, and I am confused about the "second backward".
According to this thread: https://discuss.pytorch.org/t/resolved-implementing-maml-in-pytorch/4053
(1) Inner loop: do a forward pass on the training example and take gradients with respect to the parameters. (2) Meta loop: do a forward pass with the updated parameters on a validation example, then take another gradient w.r.t. the original parameters and backprop through the first gradient (thus the second derivative).
My understanding so far: pass the training set to model0 and get loss1; applying loss1 turns it into model1; pass the validation set to model1 and obtain loss2.
I thought we just use loss2 to update model0, but I am confused by this statement: what are the exact steps in PyTorch to perform this operation? (See the sketch below.)
Many thanks if you could provide some hints on this.
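For concreteness, the two steps usually look like this minimal sketch (a hypothetical linear model with random data; the key detail is create_graph=True in the inner gradient):

import torch

# hypothetical linear "model0": parameters as explicit leaf tensors
w = torch.randn(3, 1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
x_train, y_train = torch.randn(8, 3), torch.randn(8, 1)
x_val, y_val = torch.randn(8, 3), torch.randn(8, 1)
lr_inner = 0.1

# (1) inner loop: loss1 on the training set; create_graph=True records
# the graph of this gradient so it can be differentiated through later
loss1 = ((x_train @ w + b - y_train) ** 2).mean()
gw, gb = torch.autograd.grad(loss1, (w, b), create_graph=True)
w1, b1 = w - lr_inner * gw, b - lr_inner * gb   # this is "model1"

# (2) meta loop: loss2 with the updated parameters on the validation set
loss2 = ((x_val @ w1 + b1 - y_val) ** 2).mean()

# the "second backward": backprop loss2 through the inner update, which
# yields gradients w.r.t. the ORIGINAL parameters (second derivatives)
loss2.backward()
print(w.grad, b.grad)   # these gradients are used to update model0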
Hi,
I am interested in trying your solution, but I am not familiar with the concept of "implicit parameter updates". Could you explain more about what this means and how to handle it? I am working on a model with an LSTM layer, and I wonder whether LSTM has implicit parameter updates. Thanks for the advice.