pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Dataset packing does not work in knowledge distillation recipe #2019

Closed. joecummings closed this issue 1 week ago

joecummings commented 1 week ago

As the title says: dataset packing fails in the knowledge distillation recipe. Error log from a single-device run:

Traceback (most recent call last):
  File "/home/jrcummings/.conda/envs/joe-torchtune/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/jrcummings/projects/joe-torchtune/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/home/jrcummings/projects/joe-torchtune/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/home/jrcummings/projects/joe-torchtune/torchtune/_cli/run.py", line 214, in _run_cmd
    self._run_single_device(args, is_builtin=is_builtin)
  File "/home/jrcummings/projects/joe-torchtune/torchtune/_cli/run.py", line 108, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "<frozen runpy>", line 291, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jrcummings/projects/joe-torchtune/recipes/knowledge_distillation_single_device.py", line 804, in <module>
    sys.exit(recipe_main())
             ^^^^^^^^^^^^^
  File "/home/jrcummings/projects/joe-torchtune/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
             ^^^^^^^^^^^^^^^^^
  File "/home/jrcummings/projects/joe-torchtune/recipes/knowledge_distillation_single_device.py", line 799, in recipe_main
    recipe.train()
  File "/home/jrcummings/projects/joe-torchtune/recipes/knowledge_distillation_single_device.py", line 677, in train
    for idx, batch in enumerate(self._dataloader):
  File "/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 398, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 171, in collate
    {
  File "/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 172, in <dictcomp>
    key: collate(
         ^^^^^^^^
  File "/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 155, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 272, in collate_tensor_fn
    return torch.stack(batch, 0, out=out)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: stack expects each tensor to be equal size, but got [8] at entry 0 and [14] at entry 1
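
For readers who want to see the failure mode in isolation, here is a minimal sketch. It only reproduces the final frames of the traceback above (`default_collate` calling `torch.stack` on tensors of different lengths) and does not touch torchtune itself. The sample dicts, the `tokens`/`labels` keys, and the `pad_collate` helper are assumptions made up for illustration; `pad_collate` is not torchtune's collate function, and nothing in this issue confirms why the recipe fell back to the default collate.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import default_collate

# Two hypothetical samples whose lengths match the shapes in the error
# message above ([8] and [14]). The key names are assumptions.
samples = [
    {"tokens": torch.arange(8), "labels": torch.arange(8)},
    {"tokens": torch.arange(14), "labels": torch.arange(14)},
]


def try_default_collate(batch):
    """Return the message of the RuntimeError default_collate raises, or None."""
    try:
        default_collate(batch)
    except RuntimeError as err:
        return str(err)
    return None


def pad_collate(batch, padding_idx=0, ignore_idx=-100):
    """Illustrative padding collate (hypothetical, not torchtune's).

    Pads every sequence in the batch to the longest one so the tensors
    can be combined into a single [batch, seq_len] tensor.
    """
    tokens = pad_sequence(
        [sample["tokens"] for sample in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    labels = pad_sequence(
        [sample["labels"] for sample in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )
    return {"tokens": tokens, "labels": labels}


# default_collate cannot stack tensors of unequal length.
error_message = try_default_collate(samples)
print(error_message)

# A padding collate handles the same batch.
padded = pad_collate(samples)
print(padded["tokens"].shape)
```

The sketch suggests the dataloader in the failing run received samples of unequal length while using the default collate, which would happen if the samples were not actually packed to a fixed length before batching. That reading is an inference from the traceback, not something stated in the issue.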