Haidra-Org / AI-Horde-CLI

python script integrating with the AI Horde via a command line interface
GNU Affero General Public License v3.0
56 stars 15 forks source link

Add progress report #8

Closed mak448a closed 1 year ago

mak448a commented 1 year ago

Please add a percentage bar like Stable UI. Thanks!

scenaristeur commented 1 year ago

@mak448a i'm agree with you, it is frustrating as a newbie to not know the progress of the request . You can first set default verbosity to 3 in each script : at https://github.com/Haidra-Org/AI-Horde-CLI/blob/b6dcf009ca6a691a9757cba3a4571f8ce75e8dbd/cli_request_dream.py#L20 Perharps @db0 it could be nice to set it by default it gives us something like

INFO       | 2023-10-07 13:48:54 | __main__:generate:141 - {'finished': 0, 'processing': 0, 'restarted': 0, 'waiting': 5, 'done': False, 'faulted': False, 'wait_time': 192, 'queue_position': 355, 'kudos': 1.0, 'is_possible': True}
INFO       | 2023-10-07 13:48:55 | __main__:generate:141 - {'finished': 0, 'processing': 0, 'restarted': 0, 'waiting': 5, 'done': False, 'faulted': False, 'wait_time': 191, 'queue_position': 351, 'kudos': 1.0, 'is_possible': True}
INFO       | 2023-10-07 13:48:56 | __main__:generate:141 - {'finished': 0, 'processing': 0, 'restarted': 0, 'waiting': 5, 'done': False, 'faulted': False, 'wait_time': 190, 'queue_position': 350, 'kudos': 1.0, 'is_possible': True}
db0 commented 1 year ago

You can just call the cli with -vvv and it will show you this info. I'll see if I can give you a tqdm progress bar

scenaristeur commented 1 year ago

I've just worked on it ;-)

import requests
import json
import os
import time
import argparse
import base64
import yaml
import sys
from omegaconf import OmegaConf

from cli_logger import logger, set_logger_verbosity, quiesce_logger, test_logger
from PIL import Image
from io import BytesIO
from requests.exceptions import ConnectionError
from tqdm import tqdm

arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('-n', '--amount', action="store", required=False,
                        type=int, help="The amount of images to generate with this prompt")
arg_parser.add_argument("-m", '--model', action="store", default="stable_diffusion", required=False,
                        type=str, help="Generalist AI image generating model. The baseline for all finetuned models")
arg_parser.add_argument('-p', '--prompt', action="store", required=False,
                        type=str, help="The prompt with which to generate images")
arg_parser.add_argument('-w', '--width', action="store", required=False, type=int,
                        help="The width of the image to generate. Has to be a multiple of 64")
arg_parser.add_argument('-l', '--height', action="store", required=False, type=int,
                        help="The height of the image to generate. Has to be a multiple of 64")
arg_parser.add_argument('-s', '--steps', action="store", required=False,
                        type=int, help="The amount of steps to use for this generation")
arg_parser.add_argument('--api_key', type=str, action='store', required=False,
                        help="The API Key to use to authenticate on the Horde. Get one in https://aihorde.net/register")
arg_parser.add_argument('-f', '--filename', type=str, action='store', required=False,
                        help="The filename to use to save the images. If more than 1 image is generated, the number of generation will be prepended")
arg_parser.add_argument('-v', '--verbosity', action='count', default=0,
                        help="The default logging level is ERROR or higher. This value increases the amount of logging seen in your screen")
arg_parser.add_argument('-q', '--quiet', action='count', default=0,
                        help="The default logging level is ERROR or higher. This value decreases the amount of logging seen in your screen")
arg_parser.add_argument('--horde', action="store", required=False,
                        type=str, default="https://aihorde.net", help="Use a different horde")
arg_parser.add_argument('--nsfw', action="store_true", default=False, required=False,
                        help="Mark the request as NSFW. Only servers which allow NSFW will pick it up")
arg_parser.add_argument('--censor_nsfw', action="store_true", default=False, required=False,
                        help="If the request is SFW, and the worker accidentaly generates NSFW, it will send back a censored image.")
arg_parser.add_argument('--trusted_workers', action="store_true", default=False,
                        required=False, help="If true, the request will be sent only to trusted workers.")
arg_parser.add_argument('--source_image', action="store", required=False, type=str,
                        help="When a file path is provided, will be used as the source for img2img")
arg_parser.add_argument('--source_processing', action="store", required=False,
                        type=str, help="Can either be img2img, inpainting, or outpainting")
arg_parser.add_argument('--source_mask', action="store", required=False, type=str,
                        help="When a file path is provided, will be used as the mask source for inpainting/outpainting")
arg_parser.add_argument('--dry_run', action="store_true", default=False, required=False,
                        help="If true, The request will only print the amount of kudos the payload would spend, and exit.")
arg_parser.add_argument('--yml_file', action="store", default="cliRequestsData_Dream.yml",
                        required=False, help="Overrides the default yml, CLI arguments still have priority.")
args = arg_parser.parse_args()

class RequestData(object):
    def __init__(self):
        self.client_agent = "cli_request_dream.py:1.1.0:(discord)db0#1625"
        self.api_key = "0000000000"
        self.filename = "witch_dream.png"
        self.imgen_params = {
            "n": 2,
            "width": 64*8,
            "height": 64*8,
            "steps": 20,
            "sampler_name": "k_euler_a",
            "cfg_scale": 7.5,
            "denoising_strength": 0.6,
        }
        self.submit_dict = {
            "prompt": "hyperrealistic HDR render, rpg evil witch, witch robes, pointy hat, wand, ",
            "nsfw": False,
            "censor_nsfw": False,
            "trusted_workers": False,
            "models": ["stable_diffusion"],
            "r2": True,
            "dry_run": False
        }
        self.source_image = None
        self.source_processing = "img2img"
        self.source_mask = None

    def get_submit_dict(self):
        submit_dict = self.submit_dict.copy()
        submit_dict["params"] = self.imgen_params
        submit_dict["source_processing"] = self.source_processing
        if self.source_image:
            final_src_img = Image.open(self.source_image)
            buffer = BytesIO()
            # We send as WebP to avoid using all the horde bandwidth
            final_src_img.save(buffer, format="Webp", quality=95, exact=True)
            submit_dict["source_image"] = base64.b64encode(
                buffer.getvalue()).decode("utf8")
        if self.source_mask:
            final_src_mask = Image.open(self.source_mask)
            buffer = BytesIO()
            # We send as WebP to avoid using all the horde bandwidth
            final_src_mask.save(buffer, format="Webp", quality=95, exact=True)
            submit_dict["source_mask"] = base64.b64encode(
                buffer.getvalue()).decode("utf8")
        return (submit_dict)

def load_request_data():
    request_data = RequestData()
    if os.path.exists(args.yml_file):
        with open(args.yml_file, "rt", encoding="utf-8", errors="ignore") as configfile:
            config = yaml.safe_load(configfile)
            for key, value in config.items():
                setattr(request_data, key, value)
    if os.path.exists("special.yml"):
        special = OmegaConf.load("special.yml")
        request_data.imgen_params["special"] = OmegaConf.to_container(
            special, resolve=True)
    if os.path.exists("special.json"):
        special = OmegaConf.load("special.json")
        request_data.imgen_params["special"] = OmegaConf.to_container(
            special, resolve=True)
    if args.api_key:
        request_data.api_key = args.api_key
    if args.filename:
        request_data.filename = args.filename
    if args.amount:
        request_data.imgen_params["n"] = args.amount
    if args.width:
        request_data.imgen_params["width"] = args.width
    if args.height:
        request_data.imgen_params["height"] = args.height
    if args.steps:
        request_data.imgen_params["steps"] = args.steps
    if args.prompt:
        request_data.submit_dict["prompt"] = args.prompt
    if args.nsfw:
        request_data.submit_dict["nsfw"] = args.nsfw
    if args.censor_nsfw:
        request_data.submit_dict["censor_nsfw"] = args.censor_nsfw
    if args.trusted_workers:
        request_data.submit_dict["trusted_workers"] = args.trusted_workers
    if args.source_image:
        request_data.source_image = args.source_image
    if args.source_processing:
        request_data.source_processing = args.source_processing
    if args.source_mask:
        request_data.source_mask = args.source_mask
    if args.dry_run:
        request_data.submit_dict["dry_run"] = args.dry_run
    return (request_data)

@logger.catch(reraise=True)
def generate():
    request_data = load_request_data()
    # final_submit_dict["source_image"] = 'Test'
    pbar_queue_position = tqdm(total=1000, desc="queue position")
    pbar_wait_time = tqdm(total=100, desc="wait time")
    pbar_waiting = tqdm(
        total=request_data.imgen_params.get('n'), desc="waiting")
    pbar_restarted = tqdm(
        total=request_data.imgen_params.get('n'), desc="restarted")
    pbar_processing = tqdm(
        total=request_data.imgen_params.get('n'), desc="processing")
    pbar_finished = tqdm(
        total=request_data.imgen_params.get('n'), desc="finished")

    headers = {
        "apikey": request_data.api_key,
        "Client-Agent": request_data.client_agent,
    }
    # logger.debug(request_data.get_submit_dict())
    # logger.debug(json.dumps(request_data.get_submit_dict(), indent=4))
    submit_req = requests.post(f'{args.horde}/api/v2/generate/async',
                               json=request_data.get_submit_dict(), headers=headers)
    if submit_req.ok:
        submit_results = submit_req.json()
        logger.debug(submit_results)
        req_id = submit_results.get('id')
        if not req_id:
            logger.message(submit_results)
            return
        is_done = False
        retry = 0
        cancelled = False
        try:
            while not is_done:
                try:
                    chk_req = requests.get(
                        f'{args.horde}/api/v2/generate/check/{req_id}')
                    if not chk_req.ok:
                        logger.error(chk_req.text)
                        return
                    chk_results = chk_req.json()
                    logger.info(chk_results)

                    #print(chk_results)
                    pbar_queue_position.n = chk_results.get('queue_position')
                    pbar_wait_time.n = chk_results.get('wait_time')
                    pbar_finished.n = chk_results.get('finished')
                    pbar_processing.n = chk_results.get('processing')
                    pbar_restarted.n = chk_results.get('restarted')
                    pbar_waiting.n = chk_results.get('waiting')

                    pbar_queue_position.refresh()
                    pbar_wait_time.refresh()
                    pbar_finished.refresh()
                    pbar_processing.refresh()
                    pbar_restarted.refresh()
                    pbar_waiting.refresh()

                    is_done = chk_results['done']
                    time.sleep(0.8)
                except ConnectionError as e:
                    retry += 1
                    logger.error(
                        f"Error {e} when retrieving status. Retry {retry}/10")
                    if retry < 10:
                        time.sleep(1)
                        continue
                    raise
        except KeyboardInterrupt:
            logger.info(f"Cancelling {req_id}...")
            cancelled = True
            retrieve_req = requests.delete(
                f'{args.horde}/api/v2/generate/status/{req_id}')
        if not cancelled:
            retrieve_req = requests.get(
                f'{args.horde}/api/v2/generate/status/{req_id}')
        if not retrieve_req.ok:
            logger.error(retrieve_req.text)
            return
        results_json = retrieve_req.json()
        # logger.debug(results_json)
        if results_json['faulted']:
            final_submit_dict = request_data.get_submit_dict()
            if "source_image" in final_submit_dict:
                final_submit_dict[
                    "source_image"] = f"img2img request with size: {len(final_submit_dict['source_image'])}"
            logger.error(
                f"Something went wrong when generating the request. Please contact the horde administrator with your request details: {final_submit_dict}")
            return
        results = results_json['generations']
        for iter in range(len(results)):
            final_filename = request_data.filename
            if len(results) > 1:
                final_filename = f"{iter}_{request_data.filename}"
            if request_data.get_submit_dict()["r2"]:
                logger.debug(
                    f"Downloading '{results[iter]['id']}' from {results[iter]['img']}")
                try:
                    img_data = requests.get(results[iter]["img"]).content
                except:
                    logger.error("Received b64 again")
                with open(final_filename, 'wb') as handler:
                    handler.write(img_data)
            else:
                b64img = results[iter]["img"]
                base64_bytes = b64img.encode('utf-8')
                img_bytes = base64.b64decode(base64_bytes)
                img = Image.open(BytesIO(img_bytes))
                img.save(final_filename)
            censored = ''
            if results[iter]["censored"]:
                censored = " (censored)"
            logger.generation(
                f"Saved{censored} {final_filename} for {results_json['kudos']} kudos (via {results[iter]['worker_id']})")
    else:
        logger.error(submit_req.text)

set_logger_verbosity(args.verbosity)
quiesce_logger(args.quiet)

generate()
scenaristeur commented 1 year ago

Capture d’écran du 2023-10-07 14-45-48

scenaristeur commented 1 year ago

i don't know how to change % in s for wait time. i 've stated queue_position with a max of 1000 and wait time with a max of 100

db0 commented 1 year ago

Let's comment on the implementation in your PR

scenaristeur commented 1 year ago

i've added ne to pull request.

the messsage is not clear : as it is written we can think that each image génération consumed 23 kudos, , but i thought it was 23 for the total request , not for each


MESSAGE    | Saved (censored) 0_wom_dream.png for 23.0 kudos (via f6dd5f1c-e487-459f-9d54-cb6c45bc28b7)
MESSAGE    | Saved 1_wom_dream.png for 23.0 kudos (via f6dd5f1c-e487-459f-9d54-cb6c45bc28b7)
MESSAGE    | Saved 2_wom_dream.png for 23.0 kudos (via f6dd5f1c-e487-459f-9d54-cb6c45bc28b7)
MESSAGE    | Saved 3_wom_dream.png for 23.0 kudos (via f6dd5f1c-e487-459f-9d54-cb6c45bc28b7)
db0 commented 1 year ago

Please open a new issue instead of replying to the same one