XPixelGroup / DiffBIR

Official codes of DiffBIR: Towards Blind Image Restoration with Generative Diffusion Prior
Apache License 2.0

only train and infer stage 2 model #119

Open tzayuan opened 5 months ago

tzayuan commented 5 months ago

Hi @0x3f3f3f3fun, I want to train and run inference with only the stage 2 model (stage 1 will be handled separately by another model that generates the restored image). Could you provide the commands for this: how to fine-tune only the stage 2 model and save it, and how to run inference with only the stage 2 model? Thanks!

BRs, tzayuan

0x3f3f3f3fun commented 5 months ago

Hi :) To run inference with only the stage 2 model, you can take the following steps:

  1. Implement a custom pipeline without any restoration module in utils/helpers.py.

    class CustomPipeline(Pipeline):

        def __init__(self, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
            # Pass None in place of the restoration (stage 1) module.
            super().__init__(None, cldm, diffusion, cond_fn, device)

        @count_vram_usage
        def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
            # In our experiments, the output of the restoration module (a.k.a. the condition for the subsequent IRControlNet)
            # is resized to a resolution >= 512, since both the pretrained SD2 and IRControlNet are trained on 512x512.
            # Here we directly use the lq as the condition.
            # Default to the input itself so that `clean` is always defined, even when no resize is needed.
            clean = lq
            if min(lq.shape[2:]) < 512:
                clean = resize_short_edge_to(lq, size=512)
            return clean
  2. Implement a custom inference loop in utils/inference.py.

    # CustomPipeline (defined in utils/helpers.py above) must be imported in this file.
    class CustomInferenceLoop(InferenceLoop):

        @count_vram_usage
        def init_stage1_model(self) -> None:
            # Nothing to do: there is no stage 1 model to load.
            pass

        def init_pipeline(self) -> None:
            # Instantiate our custom pipeline.
            self.pipeline = CustomPipeline(self.cldm, self.diffusion, self.cond_fn, self.args.device)
  3. Add the custom inference loop to inference.py as an option.

    # CustomInferenceLoop (defined in utils/inference.py above) must be imported in this file.
    def main():
        args = parse_args()
        args.device = check_device(args.device)
        set_seed(args.seed)
        if args.version == "v1":
            V1InferenceLoop(args).run()
        else:
            supported_tasks = {
                "sr": BSRInferenceLoop,
                "dn": BIDInferenceLoop,
                "fr": BFRInferenceLoop,
                "fr_bg": UnAlignedBFRInferenceLoop,
                "custom": CustomInferenceLoop
            }
            supported_tasks[args.task](args).run()
            print("done!")

The code above was written directly in the GitHub comment box and has not been tested, but I think it will work.
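To make the three steps easier to check without loading any models, here is a minimal, dependency-free sketch of the same pattern. It is an illustration under assumptions, not DiffBIR code: the `InferenceLoop` class below is a stand-in whose `run()` simply calls the two hooks in order, and `short_edge_target` and `run_task` are hypothetical names. It shows only two things: that overriding `init_stage1_model` with a no-op skips stage 1 loading, and that an input is left unchanged unless its short edge is below 512.

```python
from typing import Dict, List, Tuple, Type


def short_edge_target(h: int, w: int, size: int = 512) -> Tuple[int, int]:
    """Hypothetical helper: the (height, width) of the condition image.

    Inputs whose short edge is already >= size pass through unchanged;
    smaller inputs are scaled so the short edge equals size, keeping the
    aspect ratio. This mirrors the intent of run_stage1 above; it is not
    the repository's resize_short_edge_to.
    """
    short = min(h, w)
    if short >= size:
        return h, w
    scale = size / short
    return max(size, round(h * scale)), max(size, round(w * scale))


class InferenceLoop:
    """Stand-in base loop (assumption): run() calls the hooks in order."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def init_stage1_model(self) -> None:
        self.calls.append("load stage 1 weights")

    def init_pipeline(self) -> None:
        self.calls.append("build default pipeline")

    def run(self) -> List[str]:
        self.init_stage1_model()
        self.init_pipeline()
        return self.calls


class CustomInferenceLoop(InferenceLoop):
    def init_stage1_model(self) -> None:
        # Nothing to do: stage 1 is handled by an external model.
        pass

    def init_pipeline(self) -> None:
        self.calls.append("build custom pipeline")


# Task name -> loop class, as in step 3 above (only two entries shown).
SUPPORTED_TASKS: Dict[str, Type[InferenceLoop]] = {
    "sr": InferenceLoop,
    "custom": CustomInferenceLoop,
}


def run_task(task: str) -> List[str]:
    """Hypothetical dispatcher: look up the loop class and run it."""
    return SUPPORTED_TASKS[task]().run()


if __name__ == "__main__":
    print(run_task("custom"))
    print(short_edge_target(256, 384))
```

With the `custom` task, the stage 1 loading step never runs, so the image produced by the external restoration model is used directly as the condition.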

As for fine-tuning the stage 2 model, there are two ways to do it:

The second way is the recommended one.