kohya-ss / sd-webui-additional-networks

GNU Affero General Public License v3.0
1.78k stars 296 forks source link

fix to the: AttributeError: 'JSON' object has no attribute 'value' #43

Closed ChinatsuHS closed 1 year ago

ChinatsuHS commented 1 year ago

I was getting this error after AUTOMATIC1111's web UI update, so I used ChatGPT to fix it...

I'm not sure how GitHub handles fixes for errors, so I'm going to post the code here:

import os
import glob
import zipfile
import json
import stat
import sys
import inspect
from collections import OrderedDict

import torch

import modules.scripts as scripts
from modules import shared, script_callbacks
import gradio as gr

from modules.processing import Processed, process_images
from modules import sd_models
import modules.ui

from scripts import lora_compvis

# Number of LoRA slots shown in the generation UI (and XY Grid axis pairs registered).
MAX_MODEL_COUNT = 5
# File extensions that traverse_all_files() accepts as model files.
LORA_MODEL_EXTS = [".pt", ".ckpt", ".safetensors"]
lora_models = {}      # "My_Lora(abcd1234)" -> C:/path/to/model.safetensors
lora_model_names = {} # "my_lora" -> "My_Lora(abcd1234)"
# Default scan directory "models/lora" under scripts.basedir() (presumably the
# extension's root directory); created at import time if it does not exist.
lora_models_dir = os.path.join(scripts.basedir(), "models/lora")
os.makedirs(lora_models_dir, exist_ok=True)

def traverse_all_files(curr_path, model_list, exts=None):
  """Recursively collect model files below ``curr_path``.

  Args:
    curr_path: directory to scan.
    model_list: list that matching ``(path, os.stat_result)`` tuples are
      appended to; the same list object is returned.
    exts: iterable of accepted file extensions (including the leading dot).
      Defaults to LORA_MODEL_EXTS. Matching is case-sensitive, consistent
      with the exact ``'.safetensors'`` comparison used when loading a model.

  Returns:
    ``model_list`` with every matching file below ``curr_path`` appended.
  """
  if exts is None:
    exts = LORA_MODEL_EXTS
  with os.scandir(curr_path) as it:
    entries = [(os.path.join(curr_path, entry.name), entry.stat()) for entry in it]
  for f_info in entries:
    fname, fstat = f_info
    # Test for directories first: a directory whose name happens to end in a
    # model extension (e.g. "backup.ckpt/") must be descended into, not
    # reported as a model file.
    if stat.S_ISDIR(fstat.st_mode):
      traverse_all_files(fname, model_list, exts)
    elif os.path.splitext(fname)[1] in exts:
      model_list.append(f_info)
  return model_list

def get_all_models(sort_by, filter_by, path):
  """Map display names to file paths for every model file under ``path``.

  Args:
    sort_by: "name" (file name), "date" (newest first) or "path name" (full
      path); any other value keeps the directory traversal order.
    filter_by: case-insensitive substring the file name must contain; blank
      (after stripping spaces) disables filtering.
    path: root directory to scan recursively.

  Returns:
    OrderedDict of ``"<stem>(<model hash>)" -> file path`` in the chosen
    order. Files whose stem is exactly "None" are skipped.
  """
  entries = traverse_all_files(path, [])

  needle = filter_by.strip(" ").lower()
  if needle:
    entries = [e for e in entries if needle in os.path.basename(e[0]).lower()]

  # sort_by -> sort key; None means "compare the (path, stat) tuples directly".
  sort_keys = {
    "name": lambda e: os.path.basename(e[0]),
    "date": lambda e: -e[1].st_mtime,
    "path name": None,
  }
  if sort_by in sort_keys:
    entries = sorted(entries, key=sort_keys[sort_by])

  models = OrderedDict()
  for filename, _ in entries:
    stem = os.path.splitext(os.path.basename(filename))[0]
    # A file literally named "None.*" would collide with the "None" placeholder.
    if stem == "None":
      continue
    models[f"{stem}({sd_models.model_hash(filename)})"] = filename
  return models

def find_closest_lora_model_name(search: str):
    """Resolve a user-supplied model name to a key of ``lora_models``.

    Resolution order:
      1. ``search`` is already a full ``"name(hash)"`` key -> returned as-is.
      2. Case-insensitive exact match on the file stem.
      3. Case-insensitive substring match on the file stem; the shortest
         matching stem wins (first one in table order on ties).

    Returns None for an empty ``search`` or when nothing matches.
    """
    if not search:
        return None
    if search in lora_models:
        return search

    needle = search.lower()
    if needle in lora_model_names:
        return lora_model_names[needle]

    candidates = [stem for stem in lora_model_names if needle in stem.lower()]
    if not candidates:
        return None
    # min() returns the first minimal element, matching a stable sort by length.
    return lora_model_names[min(candidates, key=len)]

def update_lora_models():
  """Rescan the model directories and rebuild the module-level lookup tables.

  Scans ``lora_models_dir`` plus the optional directory from the
  ``additional_networks_extra_lora_path`` setting, then rebinds:

  * ``lora_models``: ``"None" -> None`` followed by every discovered model,
    keyed by display name ``"<stem>(<hash>)"`` -> file path.
  * ``lora_model_names``: lower-cased file stem -> display name, used by
    find_closest_lora_model_name().
  """
  global lora_models, lora_model_names

  paths = [lora_models_dir]
  extra_lora_path = shared.opts.data.get("additional_networks_extra_lora_path", None)
  # isdir rather than exists: a path naming a regular file would make
  # os.scandir() raise NotADirectoryError during the scan below.
  if extra_lora_path and os.path.isdir(extra_lora_path):
    paths.append(extra_lora_path)

  # The sort/filter settings are the same for every directory; read them once.
  sort_by = shared.opts.data.get("additional_networks_sort_models_by", "name")
  filter_by = shared.opts.data.get("additional_networks_model_name_filter", "")

  res = OrderedDict()
  for path in paths:
    found = get_all_models(sort_by, filter_by, path)
    # Models from later paths are placed first; on a duplicate display name
    # the entry found in the earlier path keeps its file path.
    res = {**found, **res}
  lora_models = OrderedDict([("None", None), *res.items()])

  lora_model_names = {}
  for name_and_hash, filename in lora_models.items():
    if filename is None:
      continue  # the "None" placeholder has no file
    name = os.path.splitext(os.path.basename(filename))[0].lower()
    lora_model_names[name] = name_and_hash

# Populate lora_models / lora_model_names at import time so the model
# dropdowns built in Script.ui() and on_ui_tabs() have choices.
update_lora_models()

class Script(scripts.Script):
  """Always-visible WebUI script that applies up to MAX_MODEL_COUNT LoRA networks.

  ui() exposes an Enable checkbox followed by MAX_MODEL_COUNT
  (module, model, weight) control triples. process() loads the selected models
  and applies them to the current SD model's U-Net and text encoder, caching
  the applied networks so they are only rebuilt when the selection, a weight,
  or the SD model changes.
  """

  def __init__(self) -> None:
    super().__init__()
    # Last (module, model, weight) per slot; compared in process() to decide
    # whether the networks must be rebuilt.
    self.latest_params = [(None, None, None)] * MAX_MODEL_COUNT
    # (network, model_name) pairs currently applied, in application order.
    self.latest_networks = []
    # sd_model_hash of the SD model the current networks were applied to.
    self.latest_model_hash = ""

  def title(self):
    return "Additional networks for generating"

  def show(self, is_img2img):
    # Shown as an always-on script regardless of is_img2img.
    return scripts.AlwaysVisible

  def ui(self, is_img2img):
    """Build the "Additional Networks" accordion and return its controls.

    The returned list is ordered [enabled, module_1, model_1, weight_1, ...,
    module_N, model_N, weight_N, refresh_button]; process() and the XY Grid
    helpers address script args by this position.
    """
    # NOTE: Changing the contents of `ctrls` means the XY Grid support may need
    # to be updated, see end of file
    ctrls = []
    model_dropdowns = []
    # (component, infotext key) pairs used by the WebUI to restore values from
    # pasted generation parameters.
    self.infotext_fields = []
    with gr.Group():
      with gr.Accordion('Additional Networks', open=False):
        enabled = gr.Checkbox(label='Enable', value=False)
        ctrls.append(enabled)
        self.infotext_fields.append((enabled, "AddNet Enabled"))

        for i in range(MAX_MODEL_COUNT):
          with gr.Row():
            module = gr.Dropdown(["LoRA"], label=f"Network module {i+1}", value="LoRA")
            model = gr.Dropdown(list(lora_models.keys()),
                                label=f"Model {i+1}",
                                value="None")

            weight = gr.Slider(label=f"Weight {i+1}", value=1.0, minimum=-1.0, maximum=2.0, step=.05)
          ctrls.extend((module, model, weight))
          model_dropdowns.append(model)

          self.infotext_fields.extend([
              (module, f"AddNet Module {i+1}"),
              (model, f"AddNet Model {i+1}"),
              (weight, f"AddNet Weight {i+1}"),
          ])

        def refresh_all_models(*dropdowns):
          """Rescan models and return one Dropdown update per slot.

          Each dropdown keeps its current selection if that name still
          exists after the rescan; otherwise it is reset to "None".
          """
          update_lora_models()
          updates = []
          for dd in dropdowns:
            if dd in lora_models:
              selected = dd
            else:
              selected = "None"
            update = gr.Dropdown.update(value=selected, choices=list(lora_models.keys()))
            updates.append(update)
          return updates

        refresh_models = gr.Button(value='Refresh models')
        refresh_models.click(refresh_all_models, inputs=model_dropdowns, outputs=model_dropdowns)
        ctrls.append(refresh_models)

    return ctrls

  def set_infotext_fields(self, p, params):
    """Record active slots in ``p.extra_generation_params``.

    Slots with no model ("None"/empty/None) or a weight of 0 are skipped;
    each active slot adds its module, model and weight under 1-based keys.
    """
    for i, t in enumerate(params):
      module, model, weight = t
      if model is None or model == "None" or len(model) == 0 or weight == 0:
        continue
      p.extra_generation_params.update({
          "AddNet Enabled": True,
          f"AddNet Module {i+1}": module,
          f"AddNet Model {i+1}": model,
          f"AddNet Weight {i+1}": weight,
      })

  def process(self, p, *args):
    """Apply (or restore) the selected networks before generation.

    Args:
      p: processing object; ``p.sd_model`` supplies the U-Net
        (``model.diffusion_model``) and the text encoder (``cond_stage_model``).
      *args: values of the controls returned by ui(), in the same order.

    Raises:
      RuntimeError: a selected model name is not present in ``lora_models``.
    """
    unet = p.sd_model.model.diffusion_model
    text_encoder = p.sd_model.cond_stage_model

    def restore_networks():
      # Undo the applied networks in reverse order of application.
      if len(self.latest_networks) > 0:
        print("restoring last networks")
        for network, _ in self.latest_networks[::-1]:
          network.restore(text_encoder, unet)
        self.latest_networks.clear()

    # args[0] is the Enable checkbox.
    if not args[0]:
      restore_networks()
      return

    # Group the flat (module, model, weight, ...) args into triples. A trailing
    # element that does not complete a triple (the refresh button is the last
    # control returned by ui()) is never appended.
    params = []
    for i, ctrl in enumerate(args[1:]):
      if i % 3 == 0:
        param = [ctrl]
      else:
        param.append(ctrl)
        if i % 3 == 2:
          params.append(param)

    models_changed = (len(self.latest_networks) == 0)                   # no latest network (cleared by check-off)
    models_changed = models_changed or self.latest_model_hash != p.sd_model.sd_model_hash
    if not models_changed:
      for (l_module, l_model, l_weight), (module, model, weight) in zip(self.latest_params, params):
        if l_module != module or l_model != model or l_weight != weight:
          models_changed = True
          break

    if models_changed:
      restore_networks()
      # NOTE(review): latest_params is updated before loading; if a load below
      # raises after some slots were applied, a retry with the same params sees
      # a non-empty latest_networks and skips the rebuild — confirm intended.
      self.latest_params = params
      self.latest_model_hash = p.sd_model.sd_model_hash

      for module, model, weight in self.latest_params:
        if model is None or model == "None" or len(model) == 0:
          continue
        if weight == 0:
          print(f"ignore because weight is 0: {model}")
          continue

        model_path = lora_models.get(model, None)
        if model_path is None:
          raise RuntimeError(f"model not found: {model}")

        if model_path.startswith("\"") and model_path.endswith("\""):             # trim '"' at start/end
          model_path = model_path[1:-1]
        if not os.path.exists(model_path):
          print(f"file not found: {model_path}")
          continue

        print(f"{module} weight: {weight}, model: {model}")
        if module == "LoRA":
          if os.path.splitext(model_path)[1] == '.safetensors':
            from safetensors.torch import load_file
            du_state_dict = load_file(model_path)
          else:
            # NOTE(review): torch.load unpickles the file; .pt/.ckpt files
            # from untrusted sources can execute arbitrary code.
            du_state_dict = torch.load(model_path, map_location='cpu')

          network, info = lora_compvis.create_network_and_apply_compvis(du_state_dict, weight, text_encoder, unet)
          network.to(p.sd_model.device, dtype=p.sd_model.dtype)         # in medvram, device is different for u-net and sd_model, so use sd_model's

          print(f"LoRA model {model} loaded: {info}")
          self.latest_networks.append((network, model))
      if len(self.latest_networks) > 0:
        print("setting (or sd model) changed. new networks created.")

    self.set_infotext_fields(p, self.latest_params)

def read_lora_metadata(model_path, module):
  """Return the metadata stored in a LoRA ``.safetensors`` file, if any.

  Args:
    model_path: path to the model file, optionally wrapped in double quotes.
      ``None`` or ``""`` (e.g. a name missing from ``lora_models``) is
      accepted and yields ``None`` instead of raising AttributeError.
    module: network module name; only ``"LoRA"`` is supported.

  Returns:
    The header metadata from ``safe_open(...).metadata()`` for an existing
    ``.safetensors`` LoRA file; ``None`` for a missing path, a missing file,
    a non-safetensors file, or an unsupported module.
  """
  if not model_path:
    return None
  if model_path.startswith("\"") and model_path.endswith("\""):             # trim '"' at start/end
    model_path = model_path[1:-1]
  if module != "LoRA" or not os.path.exists(model_path):
    return None
  # Only safetensors files carry a readable metadata header.
  if os.path.splitext(model_path)[1] != '.safetensors':
    return None

  from safetensors.torch import safe_open
  with safe_open(model_path, framework="pt") as f:
    return f.metadata()

def on_ui_tabs():
  """Build the "Additional Networks" tab: a metadata viewer and a model-list helper.

  Returns:
    A single-element list ``[(blocks, title, elem_id)]`` for
    script_callbacks.on_ui_tabs.
  """
  with gr.Blocks(analytics_enabled=False) as additional_networks_interface:
    with gr.Row().style(equal_height=False):
      with gr.Column(variant='panel'):
        with gr.Row():
          module = gr.Dropdown(["LoRA"], label="Network module (used throughout this tab)", value="LoRA", interactive=True)
          model = gr.Dropdown(list(lora_models.keys()), label="Model", value="None", interactive=True)
          modules.ui.create_refresh_button(model, update_lora_models, lambda: {"choices": list(lora_models.keys())}, "refresh_lora_models")

        with gr.Row():
          with gr.Column():
            gr.HTML(value="Get comma-separated list of models (for XY Grid)")
            model_dir = gr.Textbox("", label="Model directory", placeholder="Optional, uses selected model's directory if blank")
            model_sort_by = gr.Radio(label="Sort models by", choices=["name", "date", "path name"], value="name", type="value")
            get_list_button = gr.Button("Get List")
          with gr.Column():
            model_list = gr.Textbox(value="", label="Model list", placeholder="Model list will be output here")

      with gr.Column():
        # The JSON component needs a real JSON value: recent Gradio versions no
        # longer discard a plain string such as "test" and raise instead.
        metadata_view = gr.JSON(value={}, label="Network metadata")

    def update_metadata(module, model):
      """Return the selected model's metadata for the JSON viewer.

      Always returns a dict; status messages are wrapped in an object because
      the JSON component cannot take an arbitrary plain string.
      """
      if model == "None":
        return {}

      model_path = lora_models.get(model, None)
      if model_path is None:
        # Outputs are updated through the return value; assigning to the
        # component object inside a handler does not reach the page.
        return {"error": f"file not found: {model}"}

      metadata = read_lora_metadata(model_path, module)
      if metadata is None:
        return {"info": "No metadata found."}
      return metadata

    model.change(update_metadata, inputs=[module, model], outputs=[metadata_view])

    def output_model_list(module, model, model_dir, sort_by):
      """Return a comma-separated list of model names found in a directory.

      When ``model_dir`` is blank, the directory of the selected model is used.
      Error messages are returned as the text shown in the output box.
      """
      if model_dir == "":
        # Get list of models with same folder as this one
        model_path = lora_models.get(model, None)
        if model_path is None:
          # "None" (or a stale name) has no file to take the directory from.
          return f"model not found: {model}"
        model_dir = os.path.dirname(model_path)

      if not os.path.isdir(model_dir):
        return f"directory not found: {model_dir}"

      found = get_all_models(sort_by, "", model_dir)
      return ", ".join(found.keys())

    get_list_button.click(output_model_list, inputs=[module, model, model_dir, model_sort_by], outputs=[model_list])

  return [(additional_networks_interface, "Additional Networks", "additional_networks")]

def update_script_args(p, value, arg_idx):
    """Overwrite one of this extension's script args on ``p``.

    ``arg_idx`` is relative to the first arg of this script (0 is the Enable
    checkbox, then the module/model/weight triples). ``p.script_args`` is
    replaced by a new tuple with that single value changed. Nothing happens if
    no Script instance is among the txt2img always-on scripts.

    NOTE(review): only the txt2img runner is searched, so the offset
    (args_from) comes from the txt2img instance even for img2img jobs —
    confirm this is intended.
    """
    ours = next(
        (s for s in scripts.scripts_txt2img.alwayson_scripts if isinstance(s, Script)),
        None,
    )
    if ours is None:
        return
    patched = list(p.script_args)
    patched[ours.args_from + arg_idx] = value
    p.script_args = tuple(patched)

def confirm_models(p, xs):
    """Validate XY Grid "AddNet Model" axis values.

    Blank entries and "None" are accepted as-is; every other value must
    resolve via find_closest_lora_model_name().

    Raises:
        RuntimeError: a value does not match any known LoRA model.
    """
    for candidate in xs:
        skip = candidate in ("", "None")
        if not skip and not find_closest_lora_model_name(candidate):
            raise RuntimeError(f"Unknown LoRA model: {candidate}")

def apply_module(p, x, xs, i):
    """XY Grid setter: enable the script and set the network module of slot ``i``."""
    # Script arg layout: enabled, ({module}, model, weight), ...
    module_offset = 3 * i + 1
    update_script_args(p, True, 0)
    update_script_args(p, x, module_offset)

def apply_model(p, x, xs, i):
    """XY Grid setter: enable the script and put model ``x`` (fuzzy-matched) in slot ``i``."""
    resolved = find_closest_lora_model_name(x)
    model_offset = 3 * i + 2  # enabled, (module, {model}, weight), ...
    update_script_args(p, True, 0)
    update_script_args(p, resolved, model_offset)

def apply_weight(p, x, xs, i):
    """XY Grid setter: enable the script and set the weight of slot ``i``."""
    weight_offset = 3 * (i + 1)  # enabled, (module, model, {weight}), ...
    update_script_args(p, True, 0)
    update_script_args(p, x, weight_offset)

# Display labels for "ss_*" safetensors metadata keys (presumably written by the
# kohya-ss training scripts). format_lora_model() uses these for XY Grid labels
# and falls back to the raw key for anything not listed here.
LORA_METADATA_NAMES = dict(
    ss_learning_rate="Learning rate",
    ss_text_encoder_lr="Text encoder LR",
    ss_unet_lr="UNet LR",
    ss_num_train_images="# of training images",
    ss_num_reg_images="# of reg images",
    ss_num_batches_per_epoch="Batches per epoch",
    ss_num_epochs="Total epochs",
    ss_batch_size_per_device="Batch size/device",
    ss_total_batch_size="Total batch size",
    ss_gradient_accumulation_steps="Gradient accum. steps",
    ss_max_train_steps="Max train steps",
    ss_lr_warmup_steps="LR warmup steps",
    ss_lr_scheduler="LR scheduler",
    ss_network_module="Network module",
    ss_network_dim="Network dim",
    ss_mixed_precision="Mixed precision",
    ss_full_fp16="Full FP16",
    ss_v2="V2",
    ss_resolution="Resolution",
    ss_clip_skip="Clip skip",
    ss_max_token_length="Max token length",
    ss_color_aug="Color aug",
    ss_flip_aug="Flip aug",
    ss_random_crop="Random crop",
    ss_shuffle_caption="Shuffle caption",
    ss_cache_latents="Cache latents",
    ss_enable_bucket="Enable bucket",
    ss_min_bucket_reso="Min bucket reso.",
    ss_max_bucket_reso="Max bucket reso.",
    ss_seed="Seed",
    ss_sd_model_name="SD model name",
    ss_vae_name="VAE name",
)

def format_lora_model(p, opt, x):
  """XY Grid label formatter for "AddNet Model" axes.

  Returns the standard XY Grid label for the resolved model name, plus one
  line per metadata key that is listed (comma-separated) in the
  ``additional_networks_xy_grid_model_metadata`` setting and present in the
  model file. Models that don't resolve, or resolve to "None", are labelled
  "None".
  """
  model = find_closest_lora_model_name(x)
  if model is None or model.lower() in ["", "none"]:
    return "None"

  value = xy_grid.format_value(p, opt, model)

  model_path = lora_models.get(model)
  metadata = read_lora_metadata(model_path, "LoRA")
  if not metadata:
    return value

  # Drop blank entries: "".split(",") is [""], so without filtering the
  # (empty) default setting would never take the early return below.
  setting = shared.opts.data.get("additional_networks_xy_grid_model_metadata", "") or ""
  metadata_names = [name.strip() for name in setting.split(",") if name.strip()]
  if not metadata_names:
    return value

  for name in metadata_names:
    if name in metadata:
      formatted_name = LORA_METADATA_NAMES.get(name, name)
      value += f"\n{formatted_name}: {metadata[name]}, "

  return value.strip(" ").strip(",")

# Register "AddNet Model N" / "AddNet Weight N" axes with the built-in XY Grid
# script when it is loaded. `xy_grid` stays bound at module level because
# format_lora_model() reads it.
for script_data in scripts.scripts_data:
    if os.path.basename(script_data.path) != "xy_grid.py":
        continue
    xy_grid = script_data.module
    for i in range(MAX_MODEL_COUNT):
        # i=i binds the current slot index into each lambda.
        model_axis = xy_grid.AxisOption(
            f"AddNet Model {i+1}", str,
            lambda p, x, xs, i=i: apply_model(p, x, xs, i),
            format_lora_model, confirm_models)
        weight_axis = xy_grid.AxisOption(
            f"AddNet Weight {i+1}", float,
            lambda p, x, xs, i=i: apply_weight(p, x, xs, i),
            xy_grid.format_value_add_label, None)
        xy_grid.axis_options.extend([model_axis, weight_axis])

def on_ui_settings():
    """Register this extension's options in the "Additional Networks" settings section."""
    section = ('additional_networks', "Additional Networks")
    options = [
        ("additional_networks_extra_lora_path",
         shared.OptionInfo("", "Extra path to scan for LoRA models (e.g. training output directory)", section=section)),
        ("additional_networks_sort_models_by",
         shared.OptionInfo("name", "Sort LoRA models by", gr.Radio, {"choices": ["name", "date", "path name"]}, section=section)),
        ("additional_networks_model_name_filter",
         shared.OptionInfo("", "LoRA model name filter", section=section)),
        ("additional_networks_xy_grid_model_metadata",
         shared.OptionInfo("", "Metadata to show in XY-Grid label for Model axes, comma-separated (example: \"ss_learning_rate, ss_num_epochs\")", section=section)),
    ]
    for key, info in options:
        shared.opts.add_option(key, info)

def on_infotext_pasted(infotext, params):
    """Fill in defaults for AddNet slots missing from pasted infotext.

    Every slot absent from ``params`` gets module "LoRA", model "None" and
    weight "0" (presumably so unused slots are reset rather than keeping stale
    UI values). Keys already present are left untouched; ``params`` is
    modified in place.
    """
    for slot in range(1, MAX_MODEL_COUNT + 1):
        params.setdefault(f"AddNet Module {slot}", "LoRA")
        params.setdefault(f"AddNet Model {slot}", "None")
        params.setdefault(f"AddNet Weight {slot}", "0")

# Hook into the WebUI: the extra tab, the settings section, and default-filling
# of AddNet fields when generation parameters are pasted.
script_callbacks.on_ui_tabs(on_ui_tabs)
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_infotext_pasted(on_infotext_pasted)
MegaScience commented 1 year ago

I am getting the error, though do not trust ChatGPT to fix complex code depending on multiple files. ChatGPT has no idea of the scope nor the current state of the codebase, so it will not be able to give a valid answer unless you can elaborate specifically what's wrong.

Edit: The issue is that the value specified in this line isn't actual JSON but the string "test". The WebUI bumped its Gradio version, and Gradio's built-in JSON component no longer silently discards string values, so an exception is thrown. Changing "test" to an empty object ({}) fixes the issue, since it can be interpreted as an empty JSON object. https://github.com/kohya-ss/sd-webui-additional-networks/blob/cd156db79b7527a4f4f737d28d5771bb0126787f/scripts/additional_networks.py#L283-L284

toyxyz commented 1 year ago

I am getting the error, though do not trust ChatGPT to fix complex code depending on multiple files. ChatGPT has no idea of the scope nor the current state of the codebase, so it will not be able to give a valid answer unless you can elaborate specifically what's wrong.

Edit: The issue is that the value specified in this line isn't actual JSON but the string "test". The WebUI bumped its Gradio version, and Gradio's built-in JSON component no longer silently discards string values, so an exception is thrown. Changing "test" to an empty object ({}) fixes the issue, since it can be interpreted as an empty JSON object.

https://github.com/kohya-ss/sd-webui-additional-networks/blob/cd156db79b7527a4f4f737d28d5771bb0126787f/scripts/additional_networks.py#L283-L284

Now it works! thanks!

kohya-ss commented 1 year ago

Thank you for this! I've updated the script to fix the error. This seems to be caused by Gradio updating in the latest web UI.