pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning

Chat tutorial doesn't work as is #1002

Closed (christobill closed 1 month ago)

christobill commented 1 month ago

Hello 👋 I get an error when following the instructions on this page: https://pytorch.org/torchtune/main/tutorials/chat.html

When I run:

tune run --nproc_per_node 2 lora_finetune_distributed --config custom_config_lora.yaml

I get:

INFO:torchtune.utils.logging:Dataset and Sampler are initialized.
INFO:torchtune.utils.logging:Learning rate scheduler is initialized.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/recipes/lora_finetune_distributed.py", line 615, in <module>
    sys.exit(recipe_main())
  File "/opt/conda/lib/python3.10/site-packages/torchtune/config/_parse.py", line 50, in wrapper
    sys.exit(recipe_main(conf))
  File "/opt/conda/lib/python3.10/site-packages/recipes/lora_finetune_distributed.py", line 610, in recipe_main
    recipe.train()
  File "/opt/conda/lib/python3.10/site-packages/recipes/lora_finetune_distributed.py", line 523, in train
    for idx, batch in enumerate(
  File "/opt/conda/lib/python3.10/site-packages/tqdm/std.py", line 1166, in __iter__
    for obj in iterable:
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.10/site-packages/torchtune/datasets/_chat.py", line 81, in __getitem__
    sample = self._data[index]
  File "/opt/conda/lib/python3.10/site-packages/datasets/dataset_dict.py", line 81, in __getitem__
    raise KeyError(
KeyError: "Invalid key: 0. Please first select a split. For example: `my_dataset_dictionary['train'][0]`. Available splits: ['train']"

My custom config:

# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
# using a Llama3 8B model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token <HF_TOKEN>
#
# To launch on 2 devices, run the following command from root:
#   tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single device LoRA finetuning please use 8B_lora_single_device.yaml
# or 8B_qlora_single_device.yaml

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model

# Model Arguments
model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
  checkpoint_files: [
    consolidated.00.pth
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3-8B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: my_module.custom_dataset
  max_seq_len: 2048
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 32

# Logging
output_dir: /tmp/lora_finetune_output
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: null

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False

My custom_dataset function:

from torchtune.modules.tokenizers import Tokenizer
from torchtune.datasets import ChatDataset
from torchtune.data import Message
from typing import Mapping, Any, List

def message_converter(sample: Mapping[str, Any]) -> List[Message]:
    print(sample)
    input_msg = sample["input"]
    output_msg = sample["output"]

    user_message = Message(
        role="user",
        content=input_msg,
        masked=True,  # Mask if not training on prompt
    )
    assistant_message = Message(
        role="assistant",
        content=output_msg,
        masked=False,
    )
    # A single turn conversation
    messages = [user_message, assistant_message]

    return messages

def custom_dataset(
    *,
    tokenizer: Tokenizer,
    max_seq_len: int = 2048,  # You can expose this if you want to experiment
) -> ChatDataset:

    return ChatDataset(
        tokenizer=tokenizer,
        # For local csv files, we specify "csv" as the source, just like in
        # load_dataset
        source="csv",
        convert_to_messages=message_converter,
        # Llama3 does not need a chat format
        chat_format=None,
        max_seq_len=max_seq_len,
        # To load a local file we specify it as data_files just like in
        # load_dataset
        data_files="your_file.csv",
    )
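
To make it easier to see what the converter above receives and returns, here is a minimal sketch that runs the same mapping over two CSV rows using only the standard library. The `Message` dataclass below is a hypothetical stand-in for `torchtune.data.Message` with just the three fields used in this issue, and the assumption that each row arrives as a dict keyed by the CSV header names ("input", "output") is mine, not confirmed by the thread.

```python
import csv
import io
from dataclasses import dataclass
from typing import Any, List, Mapping


@dataclass
class Message:
    """Hypothetical stand-in for torchtune.data.Message (only the fields used here)."""
    role: str
    content: str
    masked: bool = False


def message_converter(sample: Mapping[str, Any]) -> List[Message]:
    # Same mapping as the converter in the issue, minus the debug print.
    return [
        Message(role="user", content=sample["input"], masked=True),
        Message(role="assistant", content=sample["output"], masked=False),
    ]


# Two rows in the same shape as the CSV from the issue.
CSV_TEXT = '''"input","output"
"What is DNA?","DNA, or deoxyribonucleic acid, is a molecule..."
"How do magnets work?","Magnets produce magnetic fields..."
'''

# DictReader yields one dict per row, keyed by the header names.
rows = list(csv.DictReader(io.StringIO(CSV_TEXT)))
conversations = [message_converter(row) for row in rows]

for user, assistant in conversations:
    print(f"{user.role} (masked={user.masked}): {user.content}")
    print(f"{assistant.role} (masked={assistant.masked}): {assistant.content}")
```

Each row becomes a single-turn conversation: a masked user message followed by an unmasked assistant message.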

My CSV:

"input","output"
"How do GPS receivers communicate with satellites?","The first thing to know is the communication is one-way..."
"What are the main components of a computer motherboard?","The motherboard is the main circuit board of a computer..."
"How does a microwave oven work?","Microwave ovens cook food using electromagnetic radiation..."
"What are the symptoms of a computer virus infection?","Symptoms of a computer virus infection can vary depending on the type of virus..."
"What is machine learning?","Machine learning is a subset of artificial intelligence..."
"How do airplanes fly?","Airplanes fly using the principles of aerodynamics..."
"What causes earthquakes?","Earthquakes are caused by the sudden release of energy..."
"How does the internet work?","The internet is a global network of interconnected computers..."
"What are the stages of the water cycle?","The water cycle consists of several stages..."
"What is the greenhouse effect?","The greenhouse effect is a natural process that warms the Earth's surface..."
"How does the human brain process information?","The human brain processes information through a complex network of neurons..."
"What is DNA?","DNA, or deoxyribonucleic acid, is a molecule that contains the genetic instructions for life..."
"How do plants make food?","Plants make food through a process called photosynthesis..."
"What is the theory of evolution?","The theory of evolution proposes that species change over time..."
"What causes tides?","Tides are primarily caused by the gravitational pull of the Moon and the Sun..."
"How does a camera work?","Cameras capture images by focusing light onto a photosensitive surface..."
"What is a black hole?","A black hole is a region of space where gravity is so strong that nothing, not even light, can escape..."
"How do vaccines work?","Vaccines work by stimulating the immune system to produce antibodies..."
"What is the difference between weather and climate?","Weather refers to short-term atmospheric conditions..."
"How do touchscreens work?","Touchscreens detect touch input using electrical signals..."
"What causes thunderstorms?","Thunderstorms are caused by the rapid upward movement of warm, moist air..."
"What is the purpose of the circulatory system?","The circulatory system transports oxygen, nutrients, and hormones throughout the body..."
"What are the properties of acids and bases?","Acids have a sour taste and turn blue litmus paper red..."
"How do cell phones work?","Cell phones work by transmitting and receiving radio signals..."
"What causes the seasons?","The tilt of the Earth's axis causes the seasons..."
"How do solar panels work?","Solar panels convert sunlight into electricity using photovoltaic cells..."
"What is the difference between a hurricane and a tornado?","Hurricanes are large, rotating storms that form over warm ocean waters..."
"How does the human digestive system work?","The human digestive system breaks down food into nutrients..."
"What is a tsunami?","A tsunami is a series of large ocean waves caused by underwater earthquakes or volcanic eruptions..."
"What are the different types of clouds?","Clouds are categorized based on their shape and altitude..."
"What is the structure of an atom?","Atoms consist of a nucleus surrounded by electrons..."
"How does the immune system work?","The immune system protects the body from harmful pathogens..."
"What is the Big Bang theory?","The Big Bang theory is the prevailing cosmological model..."
"What causes the phases of the moon?","The phases of the moon are caused by the relative positions of the Earth, moon, and Sun..."
"How does a refrigerator work?","Refrigerators work by removing heat from the interior..."
"What is the water table?","The water table is the level below which the ground is saturated with water..."
"What is the difference between speed and velocity?","Speed is a scalar quantity that measures how fast an object is moving..."
"How do 3D printers work?","3D printers create three-dimensional objects by laying down successive layers of material..."
"What are the primary colors of light?","The primary colors of light are red, green, and blue..."
"What is the difference between a comet and an asteroid?","Comets are icy bodies that orbit the Sun..."
"How do magnets work?","Magnets produce magnetic fields that exert forces on other magnets and magnetic materials..."
"What causes lightning?","Lightning is caused by the buildup of static electricity in clouds..."
"What is the difference between erosion and weathering?","Erosion is the process of transporting weathered material..."
"How does the human respiratory system work?","The human respiratory system delivers oxygen to the body and removes carbon dioxide..."
"What is the difference between a herbivore and a carnivore?","Herbivores primarily eat plants..."
"How do elevators work?","Elevators use electric motors to move between floors..."
"What causes the color of the sky?","The color of the sky is primarily due to scattering of sunlight by particles in the atmosphere..."
"What is the difference between mass and weight?","Mass is a measure of the amount of matter in an object..."
"How do lasers work?","Lasers emit coherent light through a process called stimulated emission..."
"What is the difference between an ecosystem and a habitat?","An ecosystem consists of all the living and nonliving components of a particular environment..."
"How do digital cameras work?","Digital cameras capture and store images as digital data..."
"What is the difference between weathering and erosion?","Weathering is the process of breaking down rocks..."
"How does a nuclear reactor work?","Nuclear reactors produce electricity by harnessing the heat released from nuclear reactions..."
"What causes the Northern Lights?","The Northern Lights, or auroras, are caused by charged particles from the Sun interacting with Earth's magnetic field..."
"How does a battery work?","Batteries convert chemical energy into electrical energy..."
"What are the layers of the Earth's atmosphere?","The Earth's atmosphere consists of several layers..."
"How do submarines work?","Submarines are underwater vessels that use ballast tanks to control buoyancy..."
"What is the difference between a vertebrate and an invertebrate?","Vertebrates have a backbone or spinal column..."
"How do rainbows form?","Rainbows form when sunlight is refracted, reflected, and dispersed by water droplets in the atmosphere..."
"What causes the Coriolis effect?","The Coriolis effect is caused by the rotation of the Earth..."
"How does a telescope work?","Telescopes gather and focus light to produce magnified images of distant objects..."
"What is the difference between a solution and a suspension?","A solution is a homogeneous mixture..."
"How does a car engine work?","Car engines convert chemical energy from fuel into mechanical energy..."
"What are the different types of energy?","Energy exists in various forms, including kinetic, potential, thermal, and electromagnetic..."
"How do radios work?","Radios receive and transmit radio waves to communicate over long distances..."
"What is the difference between an element and a compound?","An element consists of atoms with the same number of protons..."
"How does a wind turbine work?","Wind turbines convert kinetic energy from the wind into electrical energy..."
"What causes the formation of clouds?","Clouds form when warm, moist air rises and cools..."
"How does a toilet work?","Toilets use water and gravity to remove waste from the bowl..."
"What is the difference between velocity and acceleration?","Velocity is a vector quantity that includes both speed and direction..."

RdoubleA commented 1 month ago

Looks like this was resolved on Discord by reinstalling the nightly version of torchtune. Feel free to reopen this issue if you're still running into problems and we can take another look! @christobill

christobill commented 1 month ago

@RdoubleA There was also a line to add to the code provided by the tutorial.

I opened this PR to update the documentation: https://github.com/pytorch/torchtune/pull/1004
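
For readers hitting the same KeyError: the message in the traceback suggests that an integer index was applied to a mapping of split names rather than to a single split. Hugging Face `load_dataset` returns such a mapping when no `split` is given. The sketch below reproduces that failure mode with the standard library only; `load_rows` and `fetch` are hypothetical stand-ins, not torchtune or `datasets` APIs.

```python
from typing import Dict, List, Optional, Union

Row = Dict[str, str]


def load_rows(split: Optional[str] = None) -> Union[Dict[str, List[Row]], List[Row]]:
    """Hypothetical stand-in for a loader that groups rows by split name."""
    data = {"train": [{"input": "What is DNA?", "output": "DNA is a molecule..."}]}
    # Without a split, return the whole split-name -> rows mapping.
    # With a split, return only that split's rows.
    return data if split is None else data[split]


def fetch(dataset, index: int) -> Row:
    # Mirrors `sample = self._data[index]` from the traceback.
    return dataset[index]


try:
    fetch(load_rows(), 0)  # keys are split names, so 0 is not a valid key
    failed = False
except KeyError:
    failed = True

first = fetch(load_rows(split="train"), 0)  # a plain list of rows, so 0 works
print(failed, first["input"])
```

For the `custom_dataset` function in this issue, the analogous change would presumably be passing `split="train"` in the `ChatDataset(...)` call, on the assumption that it is forwarded to `load_dataset` the same way `data_files` is. The thread does not quote the exact line that was added, so treat this as a guess consistent with the error message and check the linked PR.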