stockmarkteam / bert-book

「BERTによる自然言語処理入門: Transformersを使った実践プログラミング」サポートページ
MIT License
259 stars 80 forks source link

【第8章】encode_plus_untagged関数のtorch.Tensor変換について #46

Open kawase621 opened 3 years ago

kawase621 commented 3 years ago

encode_plus_untagged関数で、最後にtorch.Tensorに変換する部分ですが、以下のように修正すべきではありませんか?

encoding = { k: torch.tensor([v]) for k, v in encoding.items() }

↓修正後

encoding = { k: torch.tensor(v) for k, v in encoding.items() }

torch.Tensorに変換する前にリスト化することで、encode_plus_untagged関数の返り値の要素であるinput_idsなどの値は「2次元Tensor」になっています。 その一方、encode_plus_tagged関数の返り値の要素であるinput_idsなどは値が「1次元Tensor」になっています。 性能評価もバッチ処理にて行うべくコードを書いている中で、encode_plus_untagged関数の返り値を使用してデータローダを作ろうとした際、input_idsなどの値が3次元のTensorになってしまい、BERTへ入力できないという問題が生じました。encode_plus_untagged関数のほうだけ、input_idsなどの値を2次元Tensorにしなくてはならない特別な理由がない限り、バッチ処理も想定して1次元Tensorにすべきと考えます。

omitakahiro commented 3 years ago

transformersのtokenizerでは単一の文を符号化すると

これは想像するに、return_tensors='pt'のオプションをつける場合には、tokenizerの出力をそのままBERTに入力することを想定しており、そのため2次元のtorch.tensorを出力しているのだと思います。

kawase621 commented 3 years ago

ご回答ありがとうございます。 以下の理解で間違いありませんか?

・BERTは入力として2次元のtorch.Tensorしか受け付けない。つまり、与えるデータが単一でも2次元にする必要がある。 ・encode_plus_untagged関数の出力は、推論時にそのままBERTへ入力されることを想定しているため、2次元のtorch.Tensorになっている。複数のデータを扱う際は、一つずつ繰り返し与える。つまり、推論処理において、データローダを用いたバッチ処理は想定していない(2次元のtorch.Tensorを使うと3次元のバッチデータになるため)。 ・逆に、encode_plus_tagged関数は出力を一度データローダに通し、バッチデータ化することを想定しているため、出力は1次元のtorch.Tensorになっている。つまり、学習時に単一のデータをBERTへ入力することは想定していない。

omitakahiro commented 3 years ago

BERTは入力として2次元のtorch.Tensorしか受け付けない。つまり、与えるデータが単一でも2次元にする必要がある。

こちらは正しいです。

・encode_plus_untagged関数の出力は、推論時にそのままBERTへ入力されることを想定しているため、2次元のtorch.Tensorになっている。複数のデータを扱う際は、一つずつ繰り返し与える。つまり、推論処理において、データローダを用いたバッチ処理は想定していない(2次元のtorch.Tensorを使うと3次元のバッチデータになるため)。 ・逆に、encode_plus_tagged関数は出力を一度データローダに通し、バッチデータ化することを想定しているため、出力は1次元のtorch.Tensorになっている。つまり、学習時に単一のデータをBERTへ入力することは想定していない。

kawase621 commented 3 years ago

ありがとうございます、理解できました。

kawase621 commented 3 years ago

ちなみに、encode_plus_tagged関数でpytorchで出力するモードがない理由はあるのでしょうか?

omitakahiro commented 3 years ago

本書ではencode_plus_taggedは学習データを処理して、後にバッチ化することを想定しているためです。

kawase621 commented 3 years ago

よく分かりました、お忙しい中ご対応頂きありがとうございます。