pystiche / papers

Reference implementation and replication of prominent NST papers
BSD 3-Clause "New" or "Revised" License

Partially wrong num_batches in johnson_alahi_li_2016 training #292

Closed jbueltemeier closed 1 year ago

jbueltemeier commented 2 years ago

A small question about training.py in johnson_alahi_li_2016. The original authors published hyperparameters for the different styles, which we also use in the replication. Among other things, the hyperparameter num_batches has to be adjusted per style. Doesn't the image_loader in main(args) therefore have to be created later, with the adapted hyperparameters? Or am I missing something?

Proposed change:

def main(args):
    dataset = paper.dataset(args.dataset_dir, impl_params=args.impl_params)

    for style in args.style:
        style_image = read_style_image(
            args.images_source_dir, style, device=args.device
        )

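        # Adapt the hyperparameters (including num_batches) for this style first ...
        hyper_parameters = adapted_hyper_parameters(
            args.impl_params, args.instance_norm, style
        )
        # ... and only then create the loader, so it picks up the adapted values.
        content_image_loader = paper.image_loader(
            dataset,
            hyper_parameters=hyper_parameters,
            pin_memory=str(args.device).startswith("cuda"),
        )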
...
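
To illustrate the concern independently of the library, here is a minimal sketch. It does not use the pystiche_papers API: `HyperParameters`, `adapted_hyper_parameters`, and `image_loader` below are hypothetical stand-ins, and the batch counts are made-up numbers, not the published ones. It only shows that a loader whose length is fixed at construction keeps the default `num_batches` if it is built before the per-style adaptation.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class HyperParameters:
    # Hypothetical stand-in; the default value is made up for illustration.
    num_batches: int = 40_000


def adapted_hyper_parameters(style: str) -> HyperParameters:
    # Hypothetical per-style override table with made-up numbers.
    overrides = {"style_a": 60_000, "style_b": 20_000}
    base = HyperParameters()
    return replace(base, num_batches=overrides.get(style, base.num_batches))


def image_loader(hyper_parameters: HyperParameters) -> range:
    # Stand-in for a data loader: its length is fixed when it is constructed.
    return range(hyper_parameters.num_batches)


styles = ["style_a", "style_b"]

# Loader created once, before the per-style adaptation: the length never changes.
stale_loader = image_loader(HyperParameters())
stale = {style: len(stale_loader) for style in styles}

# Loader created inside the loop, after the adaptation: the length follows the style.
fresh = {
    style: len(image_loader(adapted_hyper_parameters(style))) for style in styles
}

print(stale)
print(fresh)
```

In this sketch `stale` reports the default length for every style, while `fresh` reflects the per-style override, which is the behaviour the proposed change is after.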
pmeier commented 2 years ago

This also needs to be updated:

https://github.com/pystiche/papers/blob/c5e075231d58fa1556a47a50d851a72f21ded443/pystiche_papers/johnson_alahi_li_2016/_data.py#L227

pmeier commented 2 years ago

Makes sense, go for it!