tzayuan opened this issue 5 months ago
Hi :) To run inference with only the stage 2 model, you can take the following steps:
First, implement a custom pipeline without any restoration module in `utils/helpers.py`:
```python
# In utils/helpers.py. torch, Optional, ControlLDM, Diffusion, Guidance, Pipeline,
# count_vram_usage and resize_short_edge_to are assumed to be available there already.
class CustomPipeline(Pipeline):

    def __init__(self, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        # There is no stage-1 restoration model, so pass None in its place.
        super().__init__(None, cldm, diffusion, cond_fn, device)

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        # In our experiments, the output of the restoration module (a.k.a. the condition for the subsequent
        # IRControlNet) is resized to a resolution >= 512, since both pretrained SD2 and IRControlNet are
        # trained on 512x512. Here we directly use lq as the condition.
        clean = lq
        if min(clean.shape[2:]) < 512:
            clean = resize_short_edge_to(clean, size=512)
        return clean
```
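To make the `< 512` rule concrete, here is a small pure-Python sketch of the condition size it produces. It assumes `resize_short_edge_to` scales the image so that the short edge becomes 512 while keeping the aspect ratio and truncating the long edge to an integer; `condition_size` is a hypothetical helper for illustration, not part of the repo.

```python
from typing import Tuple


def condition_size(h: int, w: int, min_size: int = 512) -> Tuple[int, int]:
    """Hypothetical helper: the (h, w) of the condition returned by run_stage1.

    Assumes resize_short_edge_to keeps the aspect ratio and truncates the
    long edge to an integer.
    """
    if min(h, w) >= min_size:
        # Already large enough: lq is used as the condition unchanged.
        return h, w
    if h <= w:
        # Height is the short edge: scale it up to min_size.
        return min_size, int(w * min_size / h)
    # Width is the short edge: scale it up to min_size.
    return int(h * min_size / w), min_size
```

For example, a 256x384 input becomes a 512x768 condition, while a 700x1000 input is passed through unchanged.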
Next, implement a custom inference loop in `utils/inference.py`:
```python
# In utils/inference.py. CustomPipeline must also be imported from utils/helpers.py.
class CustomInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        # Nothing to do: there is no stage-1 model.
        pass

    def init_pipeline(self) -> None:
        # Instantiate our custom pipeline.
        self.pipeline = CustomPipeline(self.cldm, self.diffusion, self.cond_fn, self.args.device)
```

Finally, import `CustomInferenceLoop` and register it as a new task in `main()` (if `parse_args` restricts `--task` with `choices`, add `"custom"` there as well):

```python
def main():
    args = parse_args()
    args.device = check_device(args.device)
    set_seed(args.seed)
    if args.version == "v1":
        V1InferenceLoop(args).run()
    else:
        supported_tasks = {
            "sr": BSRInferenceLoop,
            "dn": BIDInferenceLoop,
            "fr": BFRInferenceLoop,
            "fr_bg": UnAlignedBFRInferenceLoop,
            "custom": CustomInferenceLoop,
        }
        supported_tasks[args.task](args).run()
    print("done!")
```
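Because `main()` dispatches on `args.task` through a plain dict, the new `"custom"` key only works if argument parsing also accepts it. The sketch below uses hypothetical stub classes in place of the real inference loops and derives the `--task` choices from the dict so the two cannot drift apart; whether the real `parse_args` uses `choices` at all is an assumption you should check.

```python
from argparse import ArgumentParser, Namespace
from typing import Dict, List, Optional, Type


class LoopStub:
    """Hypothetical stand-in for an InferenceLoop subclass."""

    def __init__(self, args: Namespace) -> None:
        self.args = args

    def run(self) -> str:
        # Report which loop was selected instead of running a model.
        return type(self).__name__


class BSRStub(LoopStub):
    pass


class CustomStub(LoopStub):
    pass


# Task name -> loop class, mirroring the supported_tasks dict in main().
SUPPORTED_TASKS: Dict[str, Type[LoopStub]] = {
    "sr": BSRStub,
    "custom": CustomStub,
}


def parse_args(argv: Optional[List[str]] = None) -> Namespace:
    parser = ArgumentParser()
    # Choices come from the registry, so a new task only has to be added once.
    parser.add_argument("--task", type=str, required=True, choices=list(SUPPORTED_TASKS))
    return parser.parse_args(argv)


def dispatch(argv: Optional[List[str]] = None) -> str:
    args = parse_args(argv)
    return SUPPORTED_TASKS[args.task](args).run()
```

With this layout, `dispatch(["--task", "custom"])` selects `CustomStub`, and an unknown task is rejected by argparse before any loop is constructed.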
The code above was written directly in the GitHub comment box and has not been tested, but I think it should work.
As for fine-tuning the stage 2 model, there are two ways to do it:
…
… `model` directory, load it in `train_stage2.py`, and replace SwinIR with your restoration model. The second way is recommended.
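The second way boils down to swapping the module that produces the condition images. As a rough sketch (it does not reproduce `train_stage2.py`; `MyRestorer`, `freeze` and `make_condition` are hypothetical names), a replacement restoration model can be frozen and run without gradients so that only the stage 2 model receives updates:

```python
import torch
from torch import nn


class MyRestorer(nn.Module):
    """Hypothetical stand-in for your own stage-1 restoration model."""

    def __init__(self) -> None:
        super().__init__()
        self.body = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, lq: torch.Tensor) -> torch.Tensor:
        # Residual refinement, clamped to the [0, 1] image range.
        return (lq + self.body(lq)).clamp(0, 1)


def freeze(module: nn.Module) -> nn.Module:
    # Stage 1 is not trained while fine-tuning stage 2: eval mode, no gradients.
    module.eval()
    for p in module.parameters():
        p.requires_grad_(False)
    return module


@torch.no_grad()
def make_condition(restorer: nn.Module, lq: torch.Tensor) -> torch.Tensor:
    # The restorer's output is what the stage 2 model would receive as its condition.
    return restorer(lq)


restorer = freeze(MyRestorer())
lq = torch.rand(2, 3, 64, 64)
cond = make_condition(restorer, lq)
```

If your stage 1 outputs are precomputed offline, the same idea applies with the restorer replaced by loading those images directly.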
Hi @0x3f3f3f3fun, I want to train and run inference with only the stage 2 model (stage 1 will be handled separately by another model that generates the restored images). Could you provide the command structure, i.e. how to fine-tune only the stage 2 model and save it, and how to run inference with only the stage 2 model? Thanks!
BRs, tzayuan