ghmagazine / llm-book

「大規模言語モデル入門」(技術評論社, 2023)のGitHubリポジトリ
https://gihyo.jp/book/2023/978-4-297-13633-8
Apache License 2.0
272 stars 41 forks source link

第7章の見出し語生成ができない #33

Open SG2810 opened 1 week ago

SG2810 commented 1 week ago

7.4.2でファインチューニングを行った後に以下のコードを実行するとその下のようなエラーが発生しました。

from torch.utils.data import DataLoader
from transformers import PreTrainedModel

def convert_list_dict_to_dict_list(list_dict: dict[str, list]) -> list[dict[str, list]]:
    """ミニバッチのデータを事例単位のlistに変換"""
    dict_list = []
    #dictのキーのリストを取得する
    keys = list(list_dict.keys())
    for idx in range(len(list_dict[keys[0]])):  #各事例で処理する
        #dictの各キーからデータを取り出してlistに追加する
        dict_list.append({key: list_dict[key][idx] for key in keys})
    return dict_list

def run_generation(dataloader: DataLoader, model: PreTrainedModel) -> list[dict[str, Any]]:
    """見出しを生成"""
    generations = []
    for batch in tqdm(dataloader): #各ミニバッチを処理する
        #for k, v in batch.items():print(f"batch {batch}¥n key {k}¥n item{v}¥n¥n")
        batch = {k: v.to(model.device) for k, v in batch.items()}
        #見出しのトークンIDを生成する
        batch["generated_title_ids"] = model.generate(**batch)
        batch = {k: v.cpu().tolist() for k, v in batch.items()}
        #ミニバッチのデータ事例単位のlistに変換する
        generations += convert_list_dict_to_dict_list(batch)
    return generations

#テストセットに対して前処理を行う
test_dataset = dataset["test"].map(
    preprocess_data,
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=dataset["test"].column_names,
)

test_dataset = test_dataset.remove_columns(["labels"])

#ミニバッチの作成にDataLoaderを用いる
test_dataloader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=data_collator,
)

#見出しを生成する
generations = run_generation(test_dataloader, model)
AttributeError                            Traceback (most recent call last)
[<ipython-input-19-2945c01206a3>](https://localhost:8080/#) in <cell line: 45>()
     43 )
     44 # 見出しを生成する
---> 45 generations = run_generation(test_dataloader, model)

1 frames
[<ipython-input-19-2945c01206a3>](https://localhost:8080/#) in <dictcomp>(.0)
     20     generations = []
     21     for batch in tqdm(dataloader):  # 各ミニバッチを処理する
---> 22         batch = {k: v.to(model.device) for k, v in batch.items()}
     23         # 見出しのトークンのIDを生成する
     24         batch["generated_title_ids"] = model.generate(**batch)

AttributeError: 'NoneType' object has no attribute 'to'

DataLoaderの性質上、こうなってしまうようなのですが、何か解決策はありますでしょうか

Kosuke-Yamada commented 6 days ago

ご連絡ありがとうございます。 DataLoaderの仕様が変わり、Datasetには含まれていないlabelsが挿入されてしまい、batchの中に{"labels": None}が含まれるため、エラーを出力してしまっているみたいです。

エラーとなっている、下記の行を batch = {k: v.to(model.device) for k, v in batch.items()} このように変更してもらってもよろしいでしょうか? batch = {k: v.to(model.device) for k, v in batch.items() if k != "labels"}