Closed — Syncriix closed this issue 4 months ago.
At the bottom of this comment is a diff with the changes needed to add this functionality.
For these instructions, you'll need either a Linux terminal or, if you're on Windows, Git Bash. Of course, change the file paths to match your environment and OS.
You should have already downloaded `Z3D-E621-Convnext.onnx` and `tags-selected.csv` from the Discord server you linked.
Honestly, I don't know whether the `.csv` file is even required, but I included it anyway.
Copy the diff at the bottom of this comment into a file named `my_diff.patch` in your home directory (the `git apply` command below reads `~/my_diff.patch`).
Navigate to the extension's directory and apply the patch:
cd ~/stable-diffusion-webui/extensions/stable-diffusion-webui-dataset-tag-editor
git apply ~/my_diff.patch
Put the two files you downloaded from the e621 Discord into your SD WebUI models directory: `stable-diffusion-webui/models/TaggerOnnx/Z3D-E621-Convnext/`.
The `TaggerOnnx` directory didn't exist for me, so I had to create it. The files must also be placed in their own directory (`Z3D-E621-Convnext`), otherwise it won't work.
This location is hard-coded in `scripts/dataset_tag_editor/tagger.py` by the `DEFAULT_ONNX_PATH` variable and the `E621` class's `self.repo_name` variable. Change these if you want to put the files somewhere else.

diff --git a/scripts/dataset_tag_editor/dte_logic.py b/scripts/dataset_tag_editor/dte_logic.py
index b2d862c..597e349 100644
--- a/scripts/dataset_tag_editor/dte_logic.py
+++ b/scripts/dataset_tag_editor/dte_logic.py
@@ -15,14 +15,14 @@ from scripts.tokenizer import clip_tokenizer
WD_TAGGER_NAMES = ["wd-v1-4-vit-tagger", "wd-v1-4-convnext-tagger", "wd-v1-4-vit-tagger-v2", "wd-v1-4-convnext-tagger-v2", "wd-v1-4-swinv2-tagger-v2"]
WD_TAGGER_THRESHOLDS = [0.35, 0.35, 0.3537, 0.3685, 0.3771] # v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf
-INTERROGATORS = [captioning.BLIP(), tagger.DeepDanbooru()] + [tagger.WaifuDiffusion(name, WD_TAGGER_THRESHOLDS[i]) for i, name in enumerate(WD_TAGGER_NAMES)]
+INTERROGATORS = [captioning.BLIP(), tagger.DeepDanbooru(), tagger.E621()] + [tagger.WaifuDiffusion(name, WD_TAGGER_THRESHOLDS[i]) for i, name in enumerate(WD_TAGGER_NAMES)]
INTERROGATOR_NAMES = [it.name() for it in INTERROGATORS]
re_tags = re.compile(r'^([\s\S]+?)( \[\d+\])?$')
re_newlines = re.compile(r'[\r\n]+')
-def interrogate_image(path:str, interrogator_name:str, threshold_booru, threshold_wd):
+def interrogate_image(path:str, interrogator_name:str, threshold_booru, threshold_e621, threshold_wd):
try:
img = Image.open(path).convert('RGB')
except:
@@ -33,6 +33,9 @@ def interrogate_image(path:str, interrogator_name:str, threshold_booru, threshol
if isinstance(it, tagger.DeepDanbooru):
with it as tg:
res = tg.predict(img, threshold_booru)
+ elif isinstance(it, tagger.E621):
+ with it as tg:
+ res = tg.predict(img, threshold_e621)
elif isinstance(it, tagger.WaifuDiffusion):
with it as tg:
res = tg.predict(img, threshold_wd)
@@ -482,7 +485,22 @@ class DatasetTagEditor(Singleton):
print(e)
- def load_dataset(self, img_dir:str, caption_ext:str, recursive:bool, load_caption_from_filename:bool, replace_new_line:bool, interrogate_method:InterrogateMethod, interrogator_names:List[str], threshold_booru:float, threshold_waifu:float, use_temp_dir:bool, kohya_json_path:Optional[str], max_res:float):
+ def load_dataset(
+ self,
+ img_dir:str,
+ caption_ext:str,
+ recursive:bool,
+ load_caption_from_filename:bool,
+ replace_new_line:bool,
+ interrogate_method:InterrogateMethod,
+ interrogator_names:List[str],
+ threshold_booru:float,
+ threshold_e621: float,
+ threshold_waifu:float,
+ use_temp_dir:bool,
+ kohya_json_path:Optional[str],
+ max_res:float,
+ ):
self.clear()
img_dir_obj = Path(img_dir)
@@ -561,6 +579,8 @@ class DatasetTagEditor(Singleton):
if isinstance(it, tagger.Tagger):
if isinstance(it, tagger.DeepDanbooru):
taggers.append((it, threshold_booru))
+ if isinstance(it, tagger.E621):
+ taggers.append((it, threshold_e621))
if isinstance(it, tagger.WaifuDiffusion):
taggers.append((it, threshold_waifu))
elif isinstance(it, captioning.Captioning):
diff --git a/scripts/dataset_tag_editor/interrogators/__init__.py b/scripts/dataset_tag_editor/interrogators/__init__.py
index 726c896..2c98c03 100644
--- a/scripts/dataset_tag_editor/interrogators/__init__.py
+++ b/scripts/dataset_tag_editor/interrogators/__init__.py
@@ -1,6 +1,7 @@
from .git_large_captioning import GITLargeCaptioning
from .waifu_diffusion_tagger import WaifuDiffusionTagger
+from .e621_tagger import E621Tagger
__all__ = [
- 'GITLargeCaptioning', 'WaifuDiffusionTagger'
+ 'GITLargeCaptioning', "E621Tagger", 'WaifuDiffusionTagger'
]
\ No newline at end of file
diff --git a/scripts/dataset_tag_editor/tagger.py b/scripts/dataset_tag_editor/tagger.py
index 5ee520b..c4c5b28 100644
--- a/scripts/dataset_tag_editor/tagger.py
+++ b/scripts/dataset_tag_editor/tagger.py
@@ -5,10 +5,14 @@ import numpy as np
from typing import Optional, Dict
from modules import devices, shared
from modules import deepbooru as db
+from modules import shared
+from modules.shared import models_path
+from pathlib import Path
+import os
from .interrogator import Interrogator
from .interrogators import WaifuDiffusionTagger
-
+from .interrogators import E621Tagger
class Tagger(Interrogator):
def start(self):
@@ -23,7 +27,7 @@ class Tagger(Interrogator):
def get_replaced_tag(tag: str):
use_spaces = shared.opts.deepbooru_use_spaces
- use_escape = shared.opts.deepbooru_escape
+ use_escape = shared.opts.deepbooru_escape
if use_spaces:
tag = tag.replace('_', ' ')
if use_escape:
@@ -102,5 +106,41 @@ class WaifuDiffusion(Tagger):
return probability_dict
+ def name(self):
+ return self.repo_name
+
+DEFAULT_ONNX_PATH = Path(models_path, "TaggerOnnx")
+
+class E621(Tagger):
+ def __init__(self):
+ self.repo_name = "Z3D-E621-Convnext"
+ self.onnx_path = os.path.join(DEFAULT_ONNX_PATH, self.repo_name)
+ self.tagger_inst = E621Tagger(self.onnx_path)
+ self.threshold = 0.35
+
+ def start(self):
+ self.tagger_inst.load()
+ return self
+
+ def stop(self):
+ self.tagger_inst.unload()
+
+ # brought from https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py and modified
+ # set threshold<0 to use default value for now...
+ def predict(self, image: Image.Image, threshold: Optional[float] = None):
+ # may not use ratings
+ # rating = dict(labels[:4])
+
+ labels = self.tagger_inst.apply(image)
+
+ if threshold is not None:
+ if threshold < 0:
+ threshold = self.threshold
+ probability_dict = dict([(get_replaced_tag(x[0]), x[1]) for x in labels[4:] if x[1] > threshold])
+ else:
+ probability_dict = dict([(get_replaced_tag(x[0]), x[1]) for x in labels[4:]])
+
+ return probability_dict
+
def name(self):
return self.repo_name
\ No newline at end of file
diff --git a/scripts/main.py b/scripts/main.py
index 33c613b..27c213f 100644
--- a/scripts/main.py
+++ b/scripts/main.py
@@ -29,7 +29,9 @@ GeneralConfig = namedtuple('GeneralConfig', [
'use_interrogator',
'use_interrogator_names',
'use_custom_threshold_booru',
- 'custom_threshold_booru',
+ 'custom_threshold_booru',
+ 'use_custom_threshold_e621',
+ 'custom_threshold_e621',
'use_custom_threshold_waifu',
'custom_threshold_waifu',
'save_kohya_metadata',
@@ -44,7 +46,7 @@ BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend'
EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order'])
MoveDeleteConfig = namedtuple('MoveDeleteConfig', ['range', 'target', 'caption_ext', 'destination'])
-CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, False, 'No', [], False, 0.7, False, 0.35, False, '', '', True, False, False)
+CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, False, 'No', [], False, 0.7, False, 0.35, False, 0.35, False, '', '', True, False, False)
CFG_FILTER_P_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'AND')
CFG_FILTER_N_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'OR')
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value, 75)
@@ -116,6 +118,7 @@ def read_general_config():
('use_blip_to_prefill', 'BLIP'),
('use_git_to_prefill', 'GIT-large-COCO'),
('use_booru_to_prefill', 'DeepDanbooru'),
+ ('use_e621_to_prefill', 'E621'),
('use_waifu_to_prefill', 'wd-v1-4-vit-tagger')
]
use_interrogator_names = []
@@ -240,10 +243,26 @@ def on_ui_tabs():
# General
components_general = [
- ui.toprow.cb_backup, ui.load_dataset.tb_img_directory, ui.load_dataset.tb_caption_file_ext, ui.load_dataset.cb_load_recursive,
- ui.load_dataset.cb_load_caption_from_filename, ui.load_dataset.cb_replace_new_line_with_comma, ui.load_dataset.rb_use_interrogator, ui.load_dataset.dd_intterogator_names,
- ui.load_dataset.cb_use_custom_threshold_booru, ui.load_dataset.sl_custom_threshold_booru, ui.load_dataset.cb_use_custom_threshold_waifu, ui.load_dataset.sl_custom_threshold_waifu,
- ui.toprow.cb_save_kohya_metadata, ui.toprow.tb_metadata_output, ui.toprow.tb_metadata_input, ui.toprow.cb_metadata_overwrite, ui.toprow.cb_metadata_as_caption, ui.toprow.cb_metadata_use_fullpath
+ ui.toprow.cb_backup,
+ ui.load_dataset.tb_img_directory,
+ ui.load_dataset.tb_caption_file_ext,
+ ui.load_dataset.cb_load_recursive,
+ ui.load_dataset.cb_load_caption_from_filename,
+ ui.load_dataset.cb_replace_new_line_with_comma,
+ ui.load_dataset.rb_use_interrogator,
+ ui.load_dataset.dd_intterogator_names,
+ ui.load_dataset.cb_use_custom_threshold_booru,
+ ui.load_dataset.sl_custom_threshold_booru,
+ ui.load_dataset.cb_use_custom_threshold_e621,
+ ui.load_dataset.sl_custom_threshold_e621,
+ ui.load_dataset.cb_use_custom_threshold_waifu,
+ ui.load_dataset.sl_custom_threshold_waifu,
+ ui.toprow.cb_save_kohya_metadata,
+ ui.toprow.tb_metadata_output,
+ ui.toprow.tb_metadata_input,
+ ui.toprow.cb_metadata_overwrite,
+ ui.toprow.cb_metadata_as_caption,
+ ui.toprow.cb_metadata_use_fullpath
]
components_filter = \
[ui.filter_by_tags.tag_filter_ui.cb_prefix, ui.filter_by_tags.tag_filter_ui.cb_suffix, ui.filter_by_tags.tag_filter_ui.cb_regex, ui.filter_by_tags.tag_filter_ui.rb_sort_by, ui.filter_by_tags.tag_filter_ui.rb_sort_order, ui.filter_by_tags.tag_filter_ui.rb_logic] +\
diff --git a/scripts/tag_editor_ui/block_load_dataset.py b/scripts/tag_editor_ui/block_load_dataset.py
index 9437a4d..292d7e6 100644
--- a/scripts/tag_editor_ui/block_load_dataset.py
+++ b/scripts/tag_editor_ui/block_load_dataset.py
@@ -43,6 +43,9 @@ class LoadDatasetUI(UIBase):
with gr.Row():
self.cb_use_custom_threshold_booru = gr.Checkbox(value=cfg_general.use_custom_threshold_booru, label='Use Custom Threshold (Booru)', interactive=True)
self.sl_custom_threshold_booru = gr.Slider(minimum=0, maximum=1, value=cfg_general.custom_threshold_booru, step=0.01, interactive=True, label='Booru Score Threshold')
+ with gr.Row():
+ self.cb_use_custom_threshold_e621 = gr.Checkbox(value=cfg_general.use_custom_threshold_e621, label='Use Custom Threshold (E621)', interactive=True)
+ self.sl_custom_threshold_e621 = gr.Slider(minimum=0, maximum=1, value=cfg_general.custom_threshold_e621, step=0.01, interactive=True, label='E621 Score Threshold')
with gr.Row():
self.cb_use_custom_threshold_waifu = gr.Checkbox(value=cfg_general.use_custom_threshold_waifu, label='Use Custom Threshold (WDv1.4 Tagger)', interactive=True)
self.sl_custom_threshold_waifu = gr.Slider(minimum=0, maximum=1, value=cfg_general.custom_threshold_waifu, step=0.01, interactive=True, label='WDv1.4 Tagger Score Threshold')
@@ -58,6 +61,8 @@ class LoadDatasetUI(UIBase):
use_interrogator_names, #: List[str], : to avoid error on gradio v3.23.0
use_custom_threshold_booru: bool,
custom_threshold_booru: float,
+ use_custom_threshold_e621: bool,
+ custom_threshold_e621: float,
use_custom_threshold_waifu: bool,
custom_threshold_waifu: float,
use_kohya_metadata: bool,
@@ -75,9 +80,10 @@ class LoadDatasetUI(UIBase):
interrogate_method = InterrogateMethod.APPEND
threshold_booru = custom_threshold_booru if use_custom_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold
+ threshold_e621 = custom_threshold_e621 if use_custom_threshold_e621 else -1
threshold_waifu = custom_threshold_waifu if use_custom_threshold_waifu else -1
- dte_instance.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, replace_new_line, interrogate_method, use_interrogator_names, threshold_booru, threshold_waifu, opts.dataset_editor_use_temp_files, kohya_json_path if use_kohya_metadata else None, opts.dataset_editor_max_res)
+ dte_instance.load_dataset(dir, caption_file_ext, recursive, load_caption_from_filename, replace_new_line, interrogate_method, use_interrogator_names, threshold_booru, threshold_e621, threshold_waifu, opts.dataset_editor_use_temp_files, kohya_json_path if use_kohya_metadata else None, opts.dataset_editor_max_res)
imgs = dte_instance.get_filtered_imgs(filters=[])
img_indices = dte_instance.get_filtered_imgindices(filters=[])
return [
@@ -90,7 +96,23 @@ class LoadDatasetUI(UIBase):
self.btn_load_datasets.click(
fn=load_files_from_dir,
- inputs=[self.tb_img_directory, self.tb_caption_file_ext, self.cb_load_recursive, self.cb_load_caption_from_filename, self.cb_replace_new_line_with_comma, self.rb_use_interrogator, self.dd_intterogator_names, self.cb_use_custom_threshold_booru, self.sl_custom_threshold_booru, self.cb_use_custom_threshold_waifu, self.sl_custom_threshold_waifu, toprow.cb_save_kohya_metadata, toprow.tb_metadata_output],
+ inputs=[
+ self.tb_img_directory,
+ self.tb_caption_file_ext,
+ self.cb_load_recursive,
+ self.cb_load_caption_from_filename,
+ self.cb_replace_new_line_with_comma,
+ self.rb_use_interrogator,
+ self.dd_intterogator_names,
+ self.cb_use_custom_threshold_booru,
+ self.sl_custom_threshold_booru,
+ self.cb_use_custom_threshold_e621,
+ self.sl_custom_threshold_e621,
+ self.cb_use_custom_threshold_waifu,
+ self.sl_custom_threshold_waifu,
+ toprow.cb_save_kohya_metadata,
+ toprow.tb_metadata_output,
+ ],
outputs=
[dataset_gallery.gl_dataset_images, filter_by_selection.gl_filter_images] +
[dataset_gallery.cbg_hidden_dataset_filter, dataset_gallery.nb_hidden_dataset_filter_apply] +
diff --git a/scripts/tag_editor_ui/tab_edit_caption_of_selected_image.py b/scripts/tag_editor_ui/tab_edit_caption_of_selected_image.py
index 47d493d..e57d2cb 100644
--- a/scripts/tag_editor_ui/tab_edit_caption_of_selected_image.py
+++ b/scripts/tag_editor_ui/tab_edit_caption_of_selected_image.py
@@ -133,16 +133,33 @@ class EditCaptionOfSelectedImageUI(UIBase):
outputs=[self.tb_edit_caption]
)
- def interrogate_selected_image(interrogator_name: str, use_threshold_booru: bool, threshold_booru: float, use_threshold_waifu: bool, threshold_waifu: float):
+ def interrogate_selected_image(
+ interrogator_name: str,
+ use_threshold_booru: bool,
+ threshold_booru: float,
+ use_threshold_e621: bool,
+ threshold_e621: float,
+ use_threshold_waifu: bool,
+ threshold_waifu: float,
+ ):
if not interrogator_name:
return ''
threshold_booru = threshold_booru if use_threshold_booru else shared.opts.interrogate_deepbooru_score_threshold
+ threshold_e621 = threshold_e621 if use_threshold_e621 else -1
threshold_waifu = threshold_waifu if use_threshold_waifu else -1
- return dte_module.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_waifu)
+ return dte_module.interrogate_image(dataset_gallery.selected_path, interrogator_name, threshold_booru, threshold_e621, threshold_waifu)
self.btn_interrogate_si.click(
fn=interrogate_selected_image,
- inputs=[self.dd_intterogator_names_si, load_dataset.cb_use_custom_threshold_booru, load_dataset.sl_custom_threshold_booru, load_dataset.cb_use_custom_threshold_waifu, load_dataset.sl_custom_threshold_waifu],
+ inputs=[
+ self.dd_intterogator_names_si,
+ load_dataset.cb_use_custom_threshold_booru,
+ load_dataset.sl_custom_threshold_booru,
+ load_dataset.cb_use_custom_threshold_e621,
+ load_dataset.sl_custom_threshold_e621,
+ load_dataset.cb_use_custom_threshold_waifu,
+ load_dataset.sl_custom_threshold_waifu,
+ ],
outputs=[self.tb_interrogate]
)
Implemented in #93.
It would be amazing to have the E621 interrogator included in your dataset tag editor!
Source: https://discord.gg/8NfFcZhnWP (channel: AI > Other Models > [WIP] E621 Convnext Image tagger).
It's a beast, developed with the help of the original Waifu Upscaler dev.
I had a look at your integration of the WD tagger, but doing this myself is beyond my Python skills.