turingmotors / heron

Apache License 2.0
163 stars 25 forks source link

zero3時の学習済みモデルの読み込み対応 #28

Open kotarotanahashi opened 11 months ago

kotarotanahashi commented 11 months ago

heronの今後の修正事項として、zero3の時にモデルの初期値として学習済みモデルを読み込む場合、

load_pretrained_weight(model, model_config["pretrained_path"])

の部分がエラーとなってしまう。load_modelの部分を以下のようにfrom_pretrainedを使って学習済みモデルから読み込むように変更する必要ありそうです。

model = VideoBlipForConditionalGeneration.from_pretrained(
   model_config["pretrained_path"], torch_dtype=torch.float16, ignore_mismatched_sizes=True
)