lapp0 / distily

Distily: Language Model Distillation Toolkit and Library
GNU Affero General Public License v3.0

Fix `torch.compile` #1

Open lapp0 opened 2 months ago

lapp0 commented 2 months ago

Reproducer

import distily

distily.run.benchmark(
    teacher_model_name_or_path="gpt2",
    output_dir="distily_verify_compile",
    hub_model_id="distily/distily_verify_compile",
    push_to_hub=True,
    report_to="tensorboard",
    dataset_sample_size=4000,
    gradient_accumulation_steps=1,
    harness_benchmarks=[],
    # Intended comparison: torch.compile enabled vs. disabled.
    params=[
        {"torch_compile": True},
        {"torch_compile": False},
    ],
)

Error

Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/opt/conda/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1337, in torch_dynamo_resume_in_forward_at_1315
    lm_logits = self.lm_head(hidden_states). To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/distily/run.py", line 66, in benchmark
    res = train(*parsed_args_tuple)
  File "/opt/conda/lib/python3.10/site-packages/distily/run.py", line 86, in train
    trainer.train()
  File "/opt/conda/lib/python3.10/site-packages/distily/distillation_trainer.py", line 92, in train
    train_output = super().train(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1929, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2205, in _inner_training_loop
    self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2761, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/opt/conda/lib/python3.10/site-packages/distily/distillation_trainer.py", line 135, in evaluate
    super().evaluate(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3666, in evaluate
    output = eval_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3857, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 4075, in prediction_step
    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
  File "/opt/conda/lib/python3.10/site-packages/distily/distillation_trainer.py", line 103, in compute_loss
    loss_dict = self.distillation_objective(self.teacher_model, model, inputs)
  File "/opt/conda/lib/python3.10/site-packages/distily/objectives/objectives.py", line 106, in __call__
    logits_loss = self._calc_loss(out_s.logits, out_t.logits, self.logits_loss_component, device)
  File "/opt/conda/lib/python3.10/site-packages/distily/objectives/objectives.py", line 135, in _calc_loss
    loss = loss_component.get_loss(feat_s, feat_t)
  File "/opt/conda/lib/python3.10/site-packages/distily/objectives/loss.py", line 47, in kl_divergence_loss
    teacher_prob = F.softmax(feat_t, dim=-1)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 1885, in softmax
    ret = input.softmax(dim)
RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/opt/conda/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1337, in torch_dynamo_resume_in_forward_at_1315
    lm_logits = self.lm_head(hidden_states). To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.

Implications

Completion of this issue allows us to benchmark and integrate

lapp0 commented 2 months ago

https://github.com/mobiusml/hqq/issues/108