Open kawase621 opened 3 years ago
transformersのtokenizerでは単一の文を符号化すると
return_tensors='pt'
のオプションをつけると、2次元のtorch.tensorを出力します。
そのため、今回のencode_plus_untaggedの関数でもその仕様に従ったものです。これは想像するに、return_tensors='pt'
のオプションをつける場合には、tokenizerの出力をそのままBERTに入力することを想定しており、そのため2次元のtorch.tensorを出力しているのだと思います。
ご回答ありがとうございます。 以下の理解で間違いありませんか?
・BERTは入力として2次元のtorch.Tensorしか受け付けない。つまり、与えるデータが単一でも2次元にする必要がある。 ・encode_plus_untagged関数の出力は、推論時にそのままBERTへ入力されることを想定しているため、2次元のtorch.Tensorになっている。複数のデータを扱う際は、一つずつ繰り返し与える。つまり、推論処理において、データローダを用いたバッチ処理は想定していない(2次元のtorch.Tensorを使うと3次元のバッチデータになるため)。 ・逆に、encode_plus_tagged関数は出力を一度データローダに通し、バッチデータ化することを想定しているため、出力は1次元のtorch.Tensorになっている。つまり、学習時に単一のデータをBERTへ入力することは想定していない。
BERTは入力として2次元のtorch.Tensorしか受け付けない。つまり、与えるデータが単一でも2次元にする必要がある。
こちらは正しいです。
・encode_plus_untagged関数の出力は、推論時にそのままBERTへ入力されることを想定しているため、2次元のtorch.Tensorになっている。複数のデータを扱う際は、一つずつ繰り返し与える。つまり、推論処理において、データローダを用いたバッチ処理は想定していない(2次元のtorch.Tensorを使うと3次元のバッチデータになるため)。 ・逆に、encode_plus_tagged関数は出力を一度データローダに通し、バッチデータ化することを想定しているため、出力は1次元のtorch.Tensorになっている。つまり、学習時に単一のデータをBERTへ入力することは想定していない。
encode_plus_tagged
の出力ではinput_idsは1次元の「torch.Tensor」ではなく1次元の「リスト」です(p116の8-7のコードセルの出力を参照のこと)。
encode_plus_untagged
の出力は、入力によって以下のように異なります。
return_tensors='pt'
を入力に含める場合:例えば、tokenizer.encode_plus_untagged(text, return_tensors='pt')
の出力ではinput_idsは2次元のtorch.Tensorになり、直接BERTに入力できます。return_tensors
を指定しない場合:例えば、tokenizer.encode_plus_untagged(text, max_length=20)
の出力ではinput_idsは1次元の「リスト」(長さは20)になります。(この用法は本書では触れられてはいませんが、動作としてはこうなります。推論でもデータをバッチ化する際にはこちらを用いる方が良いかと思います。)ありがとうございます、理解できました。
ちなみに、encode_plus_tagged関数でpytorchで出力するモードがない理由はあるのでしょうか?
本書ではencode_plus_tagged
は学習データを処理して、後にバッチ化することを想定しているためです。
よく分かりました、お忙しい中ご対応頂きありがとうございます。
encode_plus_untagged関数で、最後にtorch.Tensorに変換する部分ですが、以下のように修正すべきではありませんか?
↓修正後
torch.Tensorに変換する前にリスト化することで、encode_plus_untagged関数の返り値の要素であるinput_idsなどの値は「2次元Tensor」になっています。 その一方、encode_plus_tagged関数の返り値の要素であるinput_idsなどは値が「1次元Tensor」になっています。 性能評価もバッチ処理にて行うべくコードを書いている中で、encode_plus_untagged関数の返り値を使用してデータローダを作ろうとした際、input_idsなどの値が3次元のTensorになってしまい、BERTへ入力できないという問題が生じました。encode_plus_untagged関数のほうだけ、input_idsなどの値を2次元Tensorにしなくてはならない特別な理由がない限り、バッチ処理も想定して1次元Tensorにすべきと考えます。