masa-su / pixyz

A library for developing deep generative models in a more concise, intuitive and extendable way
https://pixyz.io
MIT License
491 stars 41 forks source link

Add train_batch and train methods #33

Open masa-su opened 6 years ago

masa-su commented 6 years ago

現在のtrainはbatchごとに受け取って計算する形になっているが,毎回外側でepochのiterationを書くのが面倒.そこで,従来のtrainをtrain_batchと変更し,イテレーターを受け取って1epoch分学習するtrainメソッドを実装する.

for epoch in range(1, epochs + 1):
    train_loss = model.train(train_loader)

testも同じ(pytorchにあわせて,testもevalに変更するかも)

masa-su commented 6 years ago

ただしtrainに引数を明示的に与えられないのが問題点.