pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

model.to(xla_device) increases the number of named_parameters #7042

Open shenh10 opened 4 months ago

shenh10 commented 4 months ago

🐛 Bug

Copying the model to the XLA device changes the number of the model's named parameters (see the attached screenshot).

To Reproduce

python xla/benchmarks/experiment_runner.py \
    --suite-name torchbench \
    --accelerator cuda \
    --dynamo openxla --dynamo None \
    --test train \
    --repeat 30 --iterations-per-run 5 \
    --print-subprocess \
    --no-resume \
    --model-config='{"model_name": "hf_Bart"}' \
    --experiment-config='{"accelerator": "cuda", "xla": "PJRT", "xla_flags": null, "dynamo": "openxla", "test": "train"}'

Steps to reproduce the behavior:

  1. Run the above command
  2. Insert a pdb breakpoint in xla/benchmarks/benchmark_model.py, inside prepare_for_experiment:
    110   def prepare_for_experiment(self, dynamo_compilation_opts):
    111     self.device = self.benchmark_experiment.get_device()
    112     self.dtype = self.conversion_dtype()
    113 
    114     if self.dtype is not None:
    115       self.module = self.module.to(self.dtype)
    116       self.example_inputs = cast_to_dtype(self.example_inputs, self.dtype)
    117 
    118     import pdb
    119     pdb.set_trace()
    120     self.module = self.module.to(self.device)
    121     self.example_inputs = move_to_device(self.example_inputs, self.device)
    122 
    123     if self.benchmark_experiment.test == "eval":
    124       self._prepare_for_eval()
    125     elif self.benchmark_experiment.test == "train":
    126       self._prepare_for_train()
    127     else:
    128       raise NotImplementedError
    129 
    130     if self.benchmark_experiment.dynamo:
    131       compilation_opts = dynamo_compilation_opts.copy()
    132       compilation_opts['backend'] = self.benchmark_experiment.dynamo
    133 
    134       logger.info(f"Running torch.compile with opts {compilation_opts}")
    135       self.model_iter_fn = torch.compile(self.model_iter_fn, **compilation_opts)
  3. Print the number of named_parameters of the model before and after the copy to the XLA device, as the screenshot above shows (a standalone sketch of the same check follows the pdb transcript).
    (Pdb) new_model = copy.deepcopy(self.module).to("cpu").to(self.device)
    (Pdb) len([param for param, value in new_model.named_parameters()])
    262
    (Pdb) len([param for param, value in self.module.named_parameters()])
    259
    (Pdb) len([param for param, value in self.module.named_buffers()])
    1
    (Pdb) len([param for param, value in new_model.named_buffers()])
    1
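
A standalone sketch of the same check outside the benchmark harness (assumptions: torch_xla is installed and torchbench's hf_Bart corresponds to transformers' BartForConditionalGeneration; the absolute counts depend on the config, the issue reports 259 before and 262 after):

import copy
import torch_xla.core.xla_model as xm
from transformers import BartConfig, BartForConditionalGeneration

# Randomly initialized BART; the encoder/decoder embeddings are tied to the
# shared embedding, so named_parameters() reports the shared tensor only once.
model = BartForConditionalGeneration(BartConfig())
before = len(list(model.named_parameters()))

# Mirror the benchmark's move: deepcopy, then move to the XLA device.
xla_model = copy.deepcopy(model).to("cpu").to(xm.xla_device())
after = len(list(xla_model.named_parameters()))

print(before, after)  # the issue reports 259 vs. 262 for the torchbench hf_Bart config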

Expected behavior

len([param for param, value in new_model.named_parameters()]) is expected to return 259, matching the original model.

Environment

JackCaoG commented 4 months ago

@qihqi since you are oncall this week, do you have time to follow up on this issue?

qihqi commented 3 months ago

This is caused by this line: https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/bart/modeling_bart.py#L1530C1-L1533C79

This line merges two parameters (the encoder and decoder embeddings) into one shared tensor, resulting in 2 fewer parameters. So there are parameters that share the same tensor but have different names.

You can print the state_dict length with len(model.state_dict()); it is the same (261) for both models.

The semantics of named_parameters() is to yield unique parameters (with their names), so parameters that reference the same tensor are only printed once.
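
A minimal illustration of that semantics with a toy tied pair (not BART itself):

import torch.nn as nn

emb = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = emb.weight            # tie both names to one Parameter
tied = nn.Sequential(emb, head)

print(len(list(tied.named_parameters())))  # 1 -- the shared tensor is reported once
print(len(tied.state_dict()))              # 2 -- both '0.weight' and '1.weight' are present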

When moving to the device, tensors are moved using the state_dict, so the _tie_weights effect is lost.

You can get it back by running the _tie_weights logic on new_model again.


In [28]: def _tie_weights(self):
    ...:     if self.config.tie_word_embeddings:
    ...:         self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
    ...:         self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
    ...:

In [29]: _tie_weights(newm)

In [30]: len(list(newm.state_dict()))
Out[30]: 261

In [31]: len(list(newm.named_parameters()))
Out[31]: 259
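
For Hugging Face models the same effect should also be achievable through the public PreTrainedModel.tie_weights() method instead of calling the private _tie_weights directly (a sketch, assuming new_model is the moved BartForConditionalGeneration and tie_word_embeddings is enabled in its config):

new_model.tie_weights()                         # re-ties the encoder/decoder embeddings to the shared one
print(len(list(new_model.named_parameters())))  # should drop back to 259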
shenh10 commented 3 months ago


Thank you for your response. I'm wondering why model.to() would undo the effect of _tie_weights?

qihqi commented 3 months ago

model.to()'s logic can be interpreted roughly as:

new_state_dict = {}
for k, v in model.state_dict().items():
    new_state_dict[k] = v.to(device)

model.load_state_dict(new_state_dict)

So it undoes what _tie_weights does.
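
A toy demonstration of that effect (a sketch with a hand-tied pair, not the actual torch.nn.Module.to internals): rebuilding every state_dict entry as its own Parameter leaves each name with a separate tensor, so the tie disappears.

import torch.nn as nn

emb = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = emb.weight                      # tied: one Parameter, two names
model = nn.Sequential(emb, head)
print(len(list(model.named_parameters())))    # 1

# Rebuild each state_dict entry as its own Parameter, as in the rough loop above.
for name, tensor in model.state_dict().items():
    owner, _, attr = name.rpartition(".")
    setattr(model.get_submodule(owner), attr, nn.Parameter(tensor.clone()))

print(len(list(model.named_parameters())))    # 2 -- the tie is gone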