comfyanonymous / ComfyUI

The most powerful and modular diffusion model GUI, api and backend with a graph/nodes interface.
https://www.comfy.org/
GNU General Public License v3.0

Always reload model after switching workflow, want a smarter model-loader-cache mechanism. #2192

Open wuutiing opened 8 months ago

wuutiing commented 8 months ago

As in execution.py https://github.com/comfyanonymous/ComfyUI/blob/ef29542030eefd1ebcd8c6c1da857ce72ea4427b/execution.py#L215-L266, each node is compared to the same node in the last prompt; if it hasn't changed, the cached result is reused. But "same" is decided by the node's id in the prompt, which is not controllable by the user. For example, when switching between two workflows that contain an identical CheckpointLoader (all inputs the same), the ids of the two CheckpointLoader nodes may differ, so ComfyUI discards the previous output and re-runs the CheckpointLoader, which costs a lot of IO time.
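The id-keyed comparison can be sketched like this (a hypothetical illustration; `can_reuse` and the prompt shapes are made up for this example, not ComfyUI's actual API):

```python
def can_reuse(old_prompt, new_prompt, node_id):
    """Return True if node_id's cached output from old_prompt is reusable.

    Reuse requires the node with the SAME id to exist in both prompts
    with identical class and inputs -- this is the id-keyed behavior.
    """
    old = old_prompt.get(node_id)
    new = new_prompt.get(node_id)
    if old is None or new is None:
        return False
    return old["class_type"] == new["class_type"] and old["inputs"] == new["inputs"]

# Two workflows with byte-identical CheckpointLoader nodes but different ids:
wf_a = {"4": {"class_type": "CheckpointLoaderSimple",
              "inputs": {"ckpt_name": "sd_xl_base_1.0.safetensors"}}}
wf_b = {"7": {"class_type": "CheckpointLoaderSimple",
              "inputs": {"ckpt_name": "sd_xl_base_1.0.safetensors"}}}

print(can_reuse(wf_a, wf_b, "4"))  # False: id "4" does not exist in wf_b
print(can_reuse(wf_a, wf_b, "7"))  # False: id "7" does not exist in wf_a
print(can_reuse(wf_a, wf_a, "4"))  # True: same id, same class, same inputs
```

So both cache lookups miss even though the loader's inputs are identical, and the checkpoint is reloaded.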

Considering how much IO time Loaders take, I think it would be nice to cache Loader results as much as possible whenever the prompt changes.
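A minimal sketch of what a content-keyed loader cache could look like (everything here is hypothetical, not existing ComfyUI code): results are keyed by `(class_type, inputs)` rather than node id, so identical Loader nodes across different workflows hit the same entry.

```python
import json

class LoaderCache:
    """Cache loader outputs by node content instead of node id."""

    def __init__(self):
        self._cache = {}

    @staticmethod
    def _key(class_type, inputs):
        # Inputs in a prompt are JSON-serializable; sort keys for a stable key.
        return (class_type, json.dumps(inputs, sort_keys=True))

    def get_or_load(self, class_type, inputs, load_fn):
        key = self._key(class_type, inputs)
        if key not in self._cache:
            self._cache[key] = load_fn()  # only runs on a cache miss
        return self._cache[key]

cache = LoaderCache()
calls = []
load = lambda: calls.append(1) or "model"

# The same loader reached from two workflows (different node ids): loaded once.
cache.get_or_load("CheckpointLoaderSimple",
                  {"ckpt_name": "sd_xl_base_1.0.safetensors"}, load)
cache.get_or_load("CheckpointLoaderSimple",
                  {"ckpt_name": "sd_xl_base_1.0.safetensors"}, load)
print(len(calls))  # 1
```

A real implementation would also need an eviction policy so cached checkpoints don't pin RAM/VRAM indefinitely.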

NeedsMoar commented 8 months ago

The loaders don't take much IO time at all with cold models (and none if the checkpoint is still in standby / RAM cache, which is a function of the OS and overall system memory). What load times are you seeing?

wuutiing commented 8 months ago

@NeedsMoar the cost may not be only IO. I printed the time for loading sd-xl-base-1.0: roughly 2~2.5s. I then used the line_profiler package to print a detailed cost table (below). The table shows that loading checkpoints is expensive even with the OS's caching. The code is below; you are welcome to test it.

# file COMFYUI/comfy/sd.py
from line_profiler import LineProfiler
lp = LineProfiler()

def print_lp(fn):
    def _f(*a, **kws):
        res = fn(*a, **kws)
        lp.print_stats()
        return res
    return _f

@print_lp
@lp
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
    ...  # original function body unchanged
FIRST TIME
File: comfy/sd.py
Function: load_checkpoint_guess_config at line 434
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   434                                           @print_lp
   435                                           @lp
   436                                           def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
   437         1  912175410.0    9e+08     34.3      sd = comfy.utils.load_torch_file(ckpt_path)
   438         1       1204.0   1204.0      0.0      sd_keys = sd.keys()
   439         1        215.0    215.0      0.0      clip = None
   440         1        113.0    113.0      0.0      clipvision = None
   441         1        252.0    252.0      0.0      vae = None
   442         1        101.0    101.0      0.0      model = None
   443         1        118.0    118.0      0.0      model_patcher = None
   444         1        133.0    133.0      0.0      clip_target = None
   445                                           
   446         1    1008467.0    1e+06      0.0      parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
   447         1      70523.0  70523.0      0.0      unet_dtype = model_management.unet_dtype(model_params=parameters)
   448                                           
   449         1      25990.0  25990.0      0.0      class WeightsLoader(torch.nn.Module):
   450                                                   pass
   451                                           
   452         1   99808908.0    1e+08      3.8      model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
   453         1        271.0    271.0      0.0      if model_config is None:
   454                                                   raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
   455                                           
   456         1        896.0    896.0      0.0      if model_config.clip_vision_prefix is not None:
   457                                                   if output_clipvision:
   458                                                       clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
   459                                           
   460         1        134.0    134.0      0.0      if output_model:
   461         1     408025.0 408025.0      0.0          inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
   462         1       4223.0   4223.0      0.0          offload_device = model_management.unet_offload_device()
   463         1   94454512.0    9e+07      3.6          model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
   464         1  374762542.0    4e+08     14.1          model.load_model_weights(sd, "model.diffusion_model.")
   465                                           
   466         1        558.0    558.0      0.0      if output_vae:
   467         1     580234.0 580234.0      0.0          vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
   468         1       2398.0   2398.0      0.0          vae_sd = model_config.process_vae_state_dict(vae_sd)
   469         1  285088097.0    3e+08     10.7          vae = VAE(sd=vae_sd)
   470                                           
   471         1        362.0    362.0      0.0      if output_clip:
   472         1      24951.0  24951.0      0.0          w = WeightsLoader()
   473         1       8699.0   8699.0      0.0          clip_target = model_config.clip_target()
   474         1        311.0    311.0      0.0          if clip_target is not None:
   475         1  744525506.0    7e+08     28.0              clip = CLIP(clip_target, embedding_directory=embedding_directory)
   476         1      21894.0  21894.0      0.0              w.cond_stage_model = clip.cond_stage_model
   477         1    1454694.0    1e+06      0.1              sd = model_config.process_clip_state_dict(sd)
   478         1  129492911.0    1e+08      4.9              load_model_weights(w, sd)
   479                                           
   480         1       1667.0   1667.0      0.0      left_over = sd.keys()
   481         1        558.0    558.0      0.0      if len(left_over) > 0:
   482                                                   print("left over keys:", left_over)
   483                                           
   484         1        138.0    138.0      0.0      if output_model:
   485         1   16043524.0    2e+07      0.6          model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
   486         1      10726.0  10726.0      0.0          if inital_load_device != torch.device("cpu"):
   487                                                       print("loaded straight to GPU")
   488                                                       model_management.load_model_gpu(model_patcher)
   489                                           
   490         1        691.0    691.0      0.0      return (model_patcher, clip, vae, clipvision)

And here is a second run loading the same model:

SECOND TIME (hits and times are doubled because line_profiler accumulates the stats of the two calls)
File: comfy/sd.py
Function: load_checkpoint_guess_config at line 434

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   434                                           @print_lp
   435                                           @lp
   436                                           def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
   437         2 1822391664.0    9e+08     36.9      sd = comfy.utils.load_torch_file(ckpt_path)
   438         2       2676.0   1338.0      0.0      sd_keys = sd.keys()
   439         2        382.0    191.0      0.0      clip = None
   440         2        239.0    119.5      0.0      clipvision = None
   441         2        517.0    258.5      0.0      vae = None
   442         2        215.0    107.5      0.0      model = None
   443         2        261.0    130.5      0.0      model_patcher = None
   444         2        288.0    144.0      0.0      clip_target = None
   445                                           
   446         2    2115932.0    1e+06      0.0      parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
   447         2     129393.0  64696.5      0.0      unet_dtype = model_management.unet_dtype(model_params=parameters)
   448                                           
   449         2      53681.0  26840.5      0.0      class WeightsLoader(torch.nn.Module):
   450                                                   pass
   451                                           
   452         2  205460258.0    1e+08      4.2      model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
   453         2       1082.0    541.0      0.0      if model_config is None:
   454                                                   raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
   455                                           
   456         2       1990.0    995.0      0.0      if model_config.clip_vision_prefix is not None:
   457                                                   if output_clipvision:
   458                                                       clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
   459                                           
   460         2        292.0    146.0      0.0      if output_model:
   461         2     892193.0 446096.5      0.0          inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
   462         2      10138.0   5069.0      0.0          offload_device = model_management.unet_offload_device()
   463         2  161632167.0    8e+07      3.3          model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
   464         2  683757956.0    3e+08     13.8          model.load_model_weights(sd, "model.diffusion_model.")
   465                                           
   466         2       1215.0    607.5      0.0      if output_vae:
   467         2    1192285.0 596142.5      0.0          vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
   468         2       5786.0   2893.0      0.0          vae_sd = model_config.process_vae_state_dict(vae_sd)
   469         2  342663373.0    2e+08      6.9          vae = VAE(sd=vae_sd)
   470                                           
   471         2        695.0    347.5      0.0      if output_clip:
   472         2      54521.0  27260.5      0.0          w = WeightsLoader()
   473         2      18410.0   9205.0      0.0          clip_target = model_config.clip_target()
   474         2        644.0    322.0      0.0          if clip_target is not None:
   475         2 1427404673.0    7e+08     28.9              clip = CLIP(clip_target, embedding_directory=embedding_directory)
   476         2      43228.0  21614.0      0.0              w.cond_stage_model = clip.cond_stage_model
   477         2    2989338.0    1e+06      0.1              sd = model_config.process_clip_state_dict(sd)
   478         2  257158381.0    1e+08      5.2              load_model_weights(w, sd)
   479                                           
   480         2       2637.0   1318.5      0.0      left_over = sd.keys()
   481         2       1187.0    593.5      0.0      if len(left_over) > 0:
   482                                                   print("left over keys:", left_over)
   483                                           
   484         2        278.0    139.0      0.0      if output_model:
   485         2   30467827.0    2e+07      0.6          model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
   486         2      21942.0  10971.0      0.0          if inital_load_device != torch.device("cpu"):
   487                                                       print("loaded straight to GPU")
   488                                                       model_management.load_model_gpu(model_patcher)
   489                                           
   490         2       1187.0    593.5      0.0      return (model_patcher, clip, vae, clipvision)
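Since `comfy.utils.load_torch_file` alone accounts for roughly a third of the profile above, even caching just the raw state dict would save the biggest slice. A hedged sketch, assuming a `(path, mtime)` cache key and at most one checkpoint held in RAM (`load_torch_file_cached` and `fake_loader` are made up for illustration, not ComfyUI code):

```python
import os
import tempfile

_sd_cache = {}

def load_torch_file_cached(path, loader):
    """Wrap a real loader (e.g. comfy.utils.load_torch_file) with a
    (path, mtime)-keyed cache holding at most one checkpoint in RAM."""
    key = (path, os.path.getmtime(path))
    if key not in _sd_cache:
        _sd_cache.clear()  # evict any previously cached checkpoint
        _sd_cache[key] = loader(path)
    return _sd_cache[key]

# Demo with a stand-in loader that counts how often it actually runs:
calls = []
def fake_loader(path):
    calls.append(path)
    return {"weights": "..."}

with tempfile.NamedTemporaryFile(delete=False) as f:
    ckpt = f.name
try:
    sd1 = load_torch_file_cached(ckpt, fake_loader)
    sd2 = load_torch_file_cached(ckpt, fake_loader)
    print(sd1 is sd2, len(calls))  # True 1
finally:
    os.remove(ckpt)
```

Keying on mtime means an edited checkpoint file is reloaded, while repeated prompts that reference the same file skip the disk read entirely.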

IARI commented 3 months ago

This caching mechanism is indeed very strange and counter-intuitive.

ltdrdata commented 3 months ago

If you are willing to share a checkpoint model across workflows, you can use this node: it lets you control caching and releasing explicitly.

(screenshot of the cache node)

And take a look at the other backend cache nodes in the Inspire Pack as well: https://github.com/ltdrdata/ComfyUI-Inspire-Pack