pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Not reproducible example in documentation, typo. #1786

Closed: krammnic closed this issue 4 weeks ago.

krammnic commented 4 weeks ago

There is probably a typo in this documentation example:

from torchtune.models.llama3_2_vision import llama3_2_vision_transform
from torchtune.datasets.multimodal import multimodal_chat_dataset

transform = Llama3VisionTransform(
    path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
    prompt_template="torchtune.data.QuestionAnswerTemplate",
    max_seq_len=8192,
    image_size=560,
)
ds = multimodal_chat_dataset(
    model_transform=model_transform,
    source="json",
    data_files="data/my_data.json",
    column_map={
        "dialogue": "conversations",
        "image_path": "image",
    },
    image_dir="/home/user/dataset/",  # /home/user/dataset/images/clock.jpg
    image_tag="<image>",
    split="train",
)
tokenized_dict = ds[0]
print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
# '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nQuestion:<|image|>What time is it on the clock?Answer:<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nIt is 10:00AM.<|eot_id|>'
print(tokenized_dict["encoder_input"]["images"][0].shape)  # (num_tiles, num_channels, tile_height, tile_width)
# torch.Size([4, 3, 224, 224])

Shouldn't the argument be model_transform=transform? As written, model_transform is never defined, so the example fails before it builds the dataset.
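You can confirm that the example is not reproducible without installing torchtune by listing the names the snippet reads but never imports or assigns. The sketch below is not part of torchtune or its documentation tooling. undefined_names is a hypothetical helper written for illustration, and it uses only the standard-library ast and builtins modules. It is flow-insensitive and ignores nested scopes, which is usually enough for flat documentation snippets. The snippet is the one quoted above, with the expected-output comments removed.

```python
import ast
import builtins

# The documentation snippet quoted in the issue (output comments removed).
SNIPPET = '''
from torchtune.models.llama3_2_vision import llama3_2_vision_transform
from torchtune.datasets.multimodal import multimodal_chat_dataset

transform = Llama3VisionTransform(
    path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
    prompt_template="torchtune.data.QuestionAnswerTemplate",
    max_seq_len=8192,
    image_size=560,
)
ds = multimodal_chat_dataset(
    model_transform=model_transform,
    source="json",
    data_files="data/my_data.json",
    column_map={
        "dialogue": "conversations",
        "image_path": "image",
    },
    image_dir="/home/user/dataset/",  # /home/user/dataset/images/clock.jpg
    image_tag="<image>",
    split="train",
)
tokenized_dict = ds[0]
print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
print(tokenized_dict["encoder_input"]["images"][0].shape)
'''


def undefined_names(source: str) -> list[str]:
    """Hypothetical helper: return names that are read but never bound.

    Deliberately simple and flow-insensitive: a name counts as defined if it
    is imported, assigned anywhere, or is a builtin. Scopes and ordering are
    ignored, which is adequate for flat documentation snippets.
    """
    tree = ast.parse(source)
    defined = set(dir(builtins))
    loaded = []
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                # `import a.b` binds `a`; `from x import y as z` binds `z`.
                defined.add((alias.asname or alias.name).split(".")[0])
        elif isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Store):
                defined.add(node.id)
            elif isinstance(node.ctx, ast.Load):
                loaded.append(node.id)
    # Drop duplicates while keeping the order in which names were first seen.
    return [name for name in dict.fromkeys(loaded) if name not in defined]


if __name__ == "__main__":
    print(undefined_names(SNIPPET))
```

Running it prints ['Llama3VisionTransform', 'model_transform']. Besides model_transform, it flags Llama3VisionTransform, which the snippet calls even though it only imports llama3_2_vision_transform. A check like this only catches unresolved names; it cannot tell whether the keyword arguments match the actual torchtune signatures.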

RdoubleA commented 4 weeks ago

Ah yeah, good catch. We should use either transform or model_transform consistently throughout the example.

If you can fix it with a quick PR, I'm happy to stamp it. Otherwise, I can address it myself.