mit-quest / necstlab-damage-segmentation

MIT License
5 stars 6 forks source link

Allow models to be saved at a given frequency #104

Open CarolinaFurtado opened 3 years ago

CarolinaFurtado commented 3 years ago
   filepath = Path(model_dir, "model_epoch_{epoch:03d}.hdf5")
    model_checkpoint_callback_frequency = ModelCheckpoint(filepath,
                                                          monitor=metric_modelcheckpoint, verbose=1,
                                                          save_best_only=False, save_weights_only=False,
                                                          mode='auto', period=20)

    model_checkpoint_callback = ModelCheckpoint(Path(model_dir, 'model.hdf5').as_posix(),
                                                monitor=metric_modelcheckpoint, verbose=1, save_best_only=True)

    tensorboard_callback = TensorBoard(log_dir=logs_dir.as_posix(), write_graph=True,
                                       write_grads=False, write_images=True, update_freq='epoch', profile_batch=0)

    n_sample_images = 20
    train_image_and_mask_paths = sample_image_and_mask_paths(train_generator, n_sample_images)
    validation_image_and_mask_paths = sample_image_and_mask_paths(validation_generator, n_sample_images)

    csv_logger_callback = CSVLogger(Path(model_dir, 'metrics.csv').as_posix(), append=True)
    time_callback = timecallback()  # model_dir, plots_dir, 'metrics_epochtime.csv')

    results = compiled_model.fit(
        train_generator,
        steps_per_epoch=len(train_generator),
        epochs=epochs,
        validation_data=validation_generator,
        validation_steps=len(validation_generator),
        callbacks=[model_checkpoint_callback, model_checkpoint_callback_frequency, tensorboard_callback, time_callback, csv_logger_callback]
    )