Alvol opened this issue 7 years ago.
Try adding the flag -checkpoint_every 250 to save checkpoints more frequently.
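Here is a small plain-Lua sketch of what the flag changes. It assumes, as described further down this thread, that train.lua saves a checkpoint only at iterations where t % checkpoint_every == 0. The checkpoint_iterations helper is hypothetical and is not part of train.lua.

```lua
-- Hypothetical helper (not in train.lua): list the iterations t in
-- 1..num_iterations at which the test `t % checkpoint_every == 0` fires,
-- assuming that is the condition train.lua uses to save checkpoints.
local function checkpoint_iterations(num_iterations, checkpoint_every)
  local saved = {}
  if checkpoint_every <= 0 then
    return saved -- periodic checkpoints disabled
  end
  for t = 1, num_iterations do
    if t % checkpoint_every == 0 then
      saved[#saved + 1] = t
    end
  end
  return saved
end

print(table.concat(checkpoint_iterations(1000, 250), ", "))
-- 250, 500, 750, 1000
print(table.concat(checkpoint_iterations(1000, 300), ", "))
-- 300, 600, 900
```

With -num_iterations 1000, an interval of 250 gives four checkpoints, the last one at the final iteration, while an interval of 300 leaves the last 100 iterations unsaved.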
Just an FYI: I'm not sure you'll be able to get high-quality results when running for so few iterations.
Also, just to make sure: if the command is in a .sh file, you need to end each line with a backslash (with nothing after it, not even a space) so that the separate lines are not interpreted as different commands, like this:
th train.lua \
-h5_file path/to/dataset.h5 \
-style_image path/to/style/image.jpg \
-style_image_size 384 \
-content_weights 1.0 \
-style_weights 5.0 \
-checkpoint_name checkpoint \
-gpu 0 \
-num_iterations 1000 \
-checkpoint_every 250
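I inserted the following near the bottom of train.lua; the first lines are the end of the existing training loop. It lets you run without checkpoints, or with a checkpoint interval that does not evenly divide the number of iterations, and still get the final model. As currently written, the code only saves a model at iterations that are an exact multiple of the checkpoint interval, so if -num_iterations is not a multiple of -checkpoint_every, the final model is never saved.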
    if opt.lr_decay_every > 0 and t % opt.lr_decay_every == 0 then
      local new_lr = opt.lr_decay_factor * optim_state.learningRate
      optim_state = {learningRate = new_lr}
    end
  end

  -- Save the final model once the training loop has finished
  print("Saving final model")
  model:clearState()
  if use_cudnn then
    -- Convert cuDNN modules back to standard 'nn' modules so the
    -- saved model can be loaded without cuDNN
    cudnn.convert(model, nn)
  end
  model:float()
  local checkpoint = {
    opt = opt,
    train_loss_history = train_loss_history,
    val_loss_history = val_loss_history,
    val_loss_history_ts = val_loss_history_ts,
    style_loss_history = style_loss_history,
  }
  checkpoint.model = model
  -- Output file: models/instance_norm/<style image base name>_<style_image_size>.t7
  local out_dir = 'models/instance_norm'
  paths.mkdir(out_dir) -- torch.save fails if the directory does not exist
  local filename = string.format('%s/%s_%d.t7', out_dir,
      paths.basename(opt.style_image, 'jpg'), opt.style_image_size)
  torch.save(filename, checkpoint)
end

main()
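If you would rather not duplicate the save code after the loop, another option (a different approach from the snippet above, and not upstream code) is to make the checkpoint condition itself include the final iteration. Below is a minimal plain-Lua sketch with a hypothetical should_save helper; inside train.lua you would use it in place of whatever condition guards the checkpoint block, assuming that block is keyed on the iteration counter t.

```lua
-- Hypothetical helper (not in train.lua): decide whether to save at iteration t.
-- Saves at every multiple of checkpoint_every and always at the final
-- iteration, so the trained model is written even if the interval does not
-- divide num_iterations evenly.
local function should_save(t, checkpoint_every, num_iterations)
  if t == num_iterations then
    return true
  end
  -- Guard against checkpoint_every <= 0: `t % 0` is NaN in Lua 5.1 and
  -- raises an error for integer operands in Lua 5.3+.
  return checkpoint_every > 0 and t % checkpoint_every == 0
end

-- Usage: which iterations would be saved for 1000 iterations, interval 300?
local saved = {}
for t = 1, 1000 do
  if should_save(t, 300, 1000) then
    saved[#saved + 1] = t
  end
end
print(table.concat(saved, ", ")) -- 300, 600, 900, 1000
```

The checkpoint_every > 0 guard also covers running with periodic checkpoints disabled: only the final iteration is saved.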
I've been following the README to train a new model, but this command:
th train.lua \
-h5_file path/to/dataset.h5 \
-style_image path/to/style/image.jpg \
-style_image_size 384 \
-content_weights 1.0 \
-style_weights 5.0 \
-checkpoint_name checkpoint \
-gpu 0 -num_iterations 1000
runs successfully but does not create any model file (I expect a .t7 file). What might be the issue, and how can I solve it? Please help.