YutaroOgawa / pytorch_advanced

書籍「つくりながら学ぶ! PyTorchによる発展ディープラーニング」の実装コードを配置したリポジトリです
MIT License
848 stars 337 forks source link

[質問] 1-5のファインチューニングの実装の中の、「torch.save()」についてです。 #124

Open techan-yun opened 4 years ago

techan-yun commented 4 years ago

小川雄太郎様、ご多忙の中申し訳ございません。 私、「作りながら学ぶ! PyTorchによる発展ディ―プラーニング」の本を学習している者でございます。 実は、1-5のファインチューニングの実装を行う中で、エラーが発生いたしましたので質問をいたしました。 具体的には、1-5の最後に近い部分にあります、「学習したネットワークを保存する」部分なのですが、

save_path = './weights_fine_tuning.pth' torch.save(net.state_dict(), save_path)

というコードをジュピターノートブックで実行し、実際にweights_fine_tuning.pthが作られておりました。 しかし、ジュピターノートブック上でweights_fine_tuning.pthの中身を確認すると、

Error! /(weights_fine_tuning.pthまでのパスなので省略)/weights_fine_tuning.pth is not UTF-8 encoded. saving disabled. See Console for more details

というエラーになっておりました。 saving disabled.とあるので、保存できていないのかなと感じ、解決のために上記のような質問をいたしました。

長文の質問になってしまい、大変申し訳ございません。どうかよろしくお願いいたします。

YutaroOgawa commented 4 years ago

@techan-yun さま

回答が遅くなり申し訳ございません。

この保存ファイルを読み込むことを試した場合、どのようなエラーが発生しますか?


# PyTorchのネットワークパラメータの保存
save_path = './weights_fine_tuning.pth'
torch.save(net.state_dict(), save_path)

のあとに、以下を実行

# PyTorchのネットワークパラメータのロード
load_path = './weights_fine_tuning.pth'
load_weights = torch.load(load_path)
net.load_state_dict(load_weights)

# GPU上で保存された重みをCPU上でロードする場合
load_weights = torch.load(load_path, map_location={'cuda:0': 'cpu'})
net.load_state_dict(load_weights)

どうぞよろしくお願い致します。

techan-yun commented 4 years ago

小川雄太郎様、返信くださいまして感謝いたします。 雄太郎様が書かれました上記のコードを実行した結果、エラーは発生せず、ネットワーク内容も保存されていたので、私の推測ですが、jupyter notebookでweights_fine_tuning.pthを直接開こうとすると、テキストファイルではないのでエラーが出てしまったのかなと思っております。 実際、保存したネットワークをロードして推論もできましたので無事に問題は解決したと考えております。 小川雄太郎様、お忙しい中、質問に丁寧に答えてくださいまして感謝いたします。小川雄太郎様のこれからのご活躍をお祈り申し上げます。

YutaroOgawa commented 4 years ago

@techan-yun さま

ありがとうございます! 私も一つ勉強になりました。

今後ともどうぞ宜しくお願い致します。