takoyaro / db-storage1111

automatic1111's stable-diffusion-webui extension for storing images to a database

Maybe an improvement of the code? #1

Closed Nacurutu closed 1 year ago

Nacurutu commented 1 year ago

As you said in the readme: I barely write any python, I have no doubt this extension could be improved and optimized. Feel free to submit PRs!

And I also don't know anything about coding, so I had a little chat with ChatGPT... the results:

Thanks for sharing the code! Here are a few potential optimizations:

Move the imports to the top of the file: It is a best practice to put all imports at the beginning of the file, so that it is easier to understand the dependencies of the code.

Use f-strings instead of string concatenation: In the get_mongo_client function, instead of concatenating strings with + operator, you can use f-strings which provide a more concise and readable syntax.

Use a context manager to open the BytesIO object: Instead of explicitly closing the BytesIO object, you can use a context manager (with statement) to ensure that the object is closed after the block of code is executed.

Here's the optimized code:


```python
import base64
import os
import re
from io import BytesIO

import gradio as gr
from modules import scripts  # AUTOMATIC1111 webui scripting API
from pymongo import MongoClient

mongo_host = os.getenv('DB_HOST', 'localhost')
mongo_port = int(os.getenv('DB_PORT', '27017'))
mongo_username = os.getenv('DB_USER', '')
mongo_password = os.getenv('DB_PASS', '')


def get_mongo_client(database_name, collection_name):
    dburl = f"mongodb://{mongo_username}:{mongo_password}@{mongo_host}:{mongo_port}"
    client = MongoClient(dburl)
    db = client[database_name]
    collection = db[collection_name]
    return collection


class Scripts(scripts.Script):
    def title(self):
        return "Mongo Storage"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        checkbox_save_to_db = gr.inputs.Checkbox(label="Save to DB", default=False)
        database_name = gr.inputs.Textbox(label="Database Name", default="StableDiffusion")
        collection_name = gr.inputs.Textbox(label="Collection Name", default="Automatic1111")
        return [checkbox_save_to_db, database_name, collection_name]

    def postprocess(self, p, processed, checkbox_save_to_db, database_name, collection_name):
        proc = processed
        # Upload to Mongo
        if checkbox_save_to_db:
            collection = get_mongo_client(database_name, collection_name)
        for i, image in enumerate(proc.images):
            regex = r"Steps:.*$"
            seed = proc.seed
            prompt = proc.prompt
            neg_prompt = proc.negative_prompt
            info = re.findall(regex, proc.info, re.M)[0]

            with BytesIO() as buffer:
                image.save(buffer, "png")
                img_bytes = buffer.getvalue()
                img_b64 = base64.b64encode(img_bytes)

            input_dict = dict(item.split(": ") for item in str(info).split(", "))
            steps = input_dict["Steps"]
            seed = input_dict["Seed"]
            sampler = input_dict["Sampler"]
            cfg_scale = input_dict["CFG scale"]
            size = tuple(map(int, input_dict["Size"].split("x")))
            model_hash = input_dict["Model hash"]
            model = input_dict["Model"]

            if checkbox_save_to_db:
                collection.insert_one({
                    "prompt": prompt,
                    "negative_prompt": neg_prompt,
                    "steps": int(steps),
                    "seed": int(seed),
                    "sampler": sampler,
                    "cfg_scale": float(cfg_scale),
                    "size": size,
                    "model_hash": model_hash,
                    "model": model,
                    "image": img_b64
                })
        return True
```

I hope these optimizations help! Let me know if you have any questions.

can you optimize the code more?

Here are some suggestions to optimize the code further:

Move the database connection code outside the get_mongo_client() function and create a single instance of the MongoClient object. This will reduce the number of database connections required.

Use the bulk_write() method instead of inserting each document individually. This can significantly improve performance when inserting large amounts of data.

Use a generator expression instead of a for loop when converting the info string to a dictionary.

Use a context manager when opening the BytesIO object to avoid resource leaks.
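For reference, the generator-expression parsing described above can be tried on its own; the infotext line below is a hypothetical example of what the webui appends after "Steps:":

```python
# Parse an A1111-style infotext line into a dict, as postprocess() does.
# The sample line and its values are hypothetical.
info = "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 12345, Size: 512x512, Model hash: abc123de, Model: sd-v1-5"

input_dict = dict(item.split(": ") for item in info.split(", "))
size = tuple(map(int, input_dict["Size"].split("x")))

print(input_dict["Steps"])  # "20" (still a string until cast with int())
print(size)                 # (512, 512)
```

Note that this split assumes no value itself contains ", " or ": ", which holds for the standard infotext fields but is worth keeping in mind.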

Here's the optimized code:

```python
import base64
import os
import re
from io import BytesIO

import gradio as gr
import pymongo

mongo_host = os.getenv("DB_HOST", "localhost")
mongo_port = int(os.getenv("DB_PORT", 27017))
mongo_username = os.getenv("DB_USER", "")
mongo_password = os.getenv("DB_PASS", "")

mongo_client = pymongo.MongoClient(
    f"mongodb://{mongo_username}:{mongo_password}@{mongo_host}:{mongo_port}"
)

def get_mongo_collection(database_name, collection_name):
    db = mongo_client[database_name]
    return db[collection_name]

class Scripts:
    def title(self):
        return "Mongo Storage"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        checkbox_save_to_db = gr.inputs.Checkbox(label="Save to DB", default=False)
        database_name = gr.inputs.Textbox(label="Database Name", default="StableDiffusion")
        collection_name = gr.inputs.Textbox(label="Collection Name", default="Automatic1111")
        return [checkbox_save_to_db, database_name, collection_name]

    def postprocess(self, p, processed, checkbox_save_to_db, database_name, collection_name):
        proc = processed
        if checkbox_save_to_db:
            collection = get_mongo_collection(database_name, collection_name)
            bulk_operations = []
        for i, image in enumerate(proc.images):
            regex = r"Steps:.*$"
            seed = proc.seed
            prompt = proc.prompt
            neg_prompt = proc.negative_prompt
            info = re.findall(regex, proc.info, re.M)[0]

            with BytesIO() as buffer:
                image.save(buffer, "png")
                img_bytes = buffer.getvalue()
                img_b64 = base64.b64encode(img_bytes)

            input_dict = dict(item.split(": ") for item in str(info).split(", "))
            steps, seed, sampler, cfg_scale, size, model_hash, model = (
                input_dict["Steps"],
                input_dict["Seed"],
                input_dict["Sampler"],
                input_dict["CFG scale"],
                tuple(map(int, input_dict["Size"].split("x"))),
                input_dict["Model hash"],
                input_dict["Model"],
            )

            if checkbox_save_to_db:
                bulk_operations.append(
                    pymongo.InsertOne(
                        {
                            "prompt": prompt,
                            "negative_prompt": neg_prompt,
                            "steps": int(steps),
                            "seed": int(seed),
                            "sampler": sampler,
                            "cfg_scale": float(cfg_scale),
                            "size": size,
                            "model_hash": model_hash,
                            "model": model,
                            "image": img_b64,
                        }
                    )
                )

        # bulk_write raises InvalidOperation when given an empty list
        if checkbox_save_to_db and bulk_operations:
            collection.bulk_write(bulk_operations)

        return True
```

Note: I couldn't see the definition of the AlwaysVisible class, so I just left it as scripts.AlwaysVisible in the show() method. You may need to adjust that line if necessary.

can you optimize the code further?

Here are a few further optimizations that could be applied to the code:

Use the dict.get() method to simplify environment variable retrieval: Instead of using a series of if statements to check for the presence of environment variables and set default values, you can simplify this logic by using the dict.get() method, like this:

```python
mongo_host = os.getenv('DB_HOST', 'localhost')
mongo_port = int(os.getenv('DB_PORT', '27017'))
mongo_username = os.getenv('DB_USER', '')
mongo_password = os.getenv('DB_PASS', '')
```

This code does the same thing as your existing code, but in a more concise way.

Use a with statement to handle the MongoDB client connection: Instead of creating a MongoClient object explicitly and closing the connection manually, you can use the with statement to handle the connection and ensure that it is closed automatically when the block is exited. Here's an example of how this could be done:

```python
def get_mongo_client(database_name, collection_name):
    from pymongo import MongoClient
    dburl = f"mongodb://{mongo_username}:{mongo_password}@{mongo_host}:{mongo_port}"
    with MongoClient(dburl) as client:
        db = client[database_name]
        collection = db[collection_name]
        # Caution: the client is closed as soon as this with block exits,
        # so the returned collection can no longer be used for queries.
        return collection
```

This code will handle the connection to MongoDB more efficiently, and reduce the likelihood of resource leaks.

Use map() to simplify processing of images: Instead of using a for loop to process each image in the processed.images list, you can use the map() function to apply the same processing function to each element of the list. Here's an example of how this could be done:

```python
if checkbox_save_to_db:
    collection = get_mongo_client(database_name, collection_name)

    def process_image(image):
        # Note: `proc` here refers to `processed`, assigned in the
        # enclosing postprocess() method.
        regex = r"Steps:.*$"
        info = re.findall(regex, proc.info, re.M)[0]

        input_dict = dict(item.split(": ") for item in str(info).split(", "))
        steps = input_dict.get("Steps")
        seed = input_dict.get("Seed")
        sampler = input_dict.get("Sampler")
        cfg_scale = input_dict.get("CFG scale")
        size = tuple(map(int, input_dict.get("Size", "").split("x")))
        model_hash = input_dict.get("Model hash")
        model = input_dict.get("Model")

        buffer = BytesIO()
        image.save(buffer, "png")
        img_bytes = buffer.getvalue()
        img_b64 = base64.b64encode(img_bytes)

        collection.insert_one({
            "prompt": proc.prompt,
            "negative_prompt": proc.negative_prompt,
            "steps": int(steps),
            "seed": int(seed),
            "sampler": sampler,
            "cfg_scale": float(cfg_scale),
            "size": size,
            "model_hash": model_hash,
            "model": model,
            "image": img_b64
        })

    list(map(process_image, processed.images))
```

This code uses a separate function process_image() to process each image, and then applies this function to each element of the processed.images list using the map() function. This can help to simplify the code and make it more efficient.

keep optimizing it, please

Sure! Here are some additional optimizations:

Use list comprehension instead of for loop to generate input_dict.

Use enumerate instead of range(len(proc.images)) to loop through proc.images and get both the index and image.

Use pil_image.tobytes() instead of creating a BytesIO object to get the image bytes.

Use a dictionary comprehension to generate the document to be inserted into the MongoDB collection.

Here's the updated code:

```python
def postprocess(self, p, processed, checkbox_save_to_db, database_name, collection_name):
    proc = processed
    # Upload to Mongo
    if checkbox_save_to_db:
        collection = get_mongo_client(database_name, collection_name)

    for i, image in enumerate(proc.images):
        regex = r"Steps:.*$"
        seed = proc.seed
        prompt = proc.prompt
        neg_prompt = proc.negative_prompt
        info = re.findall(regex, proc.info, re.M)[0]

        input_dict = {k: v for k, v in (item.split(": ") for item in str(info).split(", "))}
        steps = input_dict["Steps"]
        seed = input_dict["Seed"]
        sampler = input_dict["Sampler"]
        cfg_scale = input_dict["CFG scale"]
        size = tuple(map(int, input_dict["Size"].split("x")))
        model_hash = input_dict["Model hash"]
        model = input_dict["Model"]

        # Caution: tobytes() returns raw, uncompressed pixel data,
        # not an encoded PNG file.
        img_bytes = image.tobytes()
        img_b64 = base64.b64encode(img_bytes)

        if checkbox_save_to_db:
            doc = {
                "prompt": prompt,
                "negative_prompt": neg_prompt,
                "steps": int(steps),
                "seed": int(seed),
                "sampler": sampler,
                "cfg_scale": float(cfg_scale),
                "size": size,
                "model_hash": model_hash,
                "model": model,
                "image": img_b64
            }
            collection.insert_one(doc)

    return True
```

I hope this helps!

can you optimize it more?

Certainly! Here are a few more optimizations that can be made:

Instead of calling get_mongo_client for every image, create a single instance of the MongoClient and reuse it.

Use pil_image.format instead of assuming that the format is PNG.

Use int and float directly on input_dict instead of converting them later.

Use pil_image.save to save the image directly to a BytesIO object.

Here's the updated code:


```python
def postprocess(self, p, processed, checkbox_save_to_db, database_name, collection_name):
    proc = processed
    # Upload to Mongo
    client = None
    if checkbox_save_to_db:
        client = MongoClient(f"mongodb://{mongo_username}:{mongo_password}@{mongo_host}:{mongo_port}")
        collection = client[database_name][collection_name]

    for i, image in enumerate(proc.images):
        regex = r"Steps:.*$"
        seed = proc.seed
        prompt = proc.prompt
        neg_prompt = proc.negative_prompt
        info = re.findall(regex, proc.info, re.M)[0]

        input_dict = dict(item.split(": ") for item in str(info).split(", "))
        steps = int(input_dict["Steps"])
        seed = int(input_dict["Seed"])
        sampler = input_dict["Sampler"]
        cfg_scale = float(input_dict["CFG scale"])
        size = tuple(map(int, input_dict["Size"].split("x")))
        model_hash = input_dict["Model hash"]
        model = input_dict["Model"]

        buffer = BytesIO()
        # Caution: image.format is None for images created in memory
        # (not loaded from a file), which makes save() raise ValueError.
        image.save(buffer, format=image.format)
        img_bytes = buffer.getvalue()
        img_b64 = base64.b64encode(img_bytes)

        if client:
            doc = {
                "prompt": prompt,
                "negative_prompt": neg_prompt,
                "steps": steps,
                "seed": seed,
                "sampler": sampler,
                "cfg_scale": cfg_scale,
                "size": size,
                "model_hash": model_hash,
                "model": model,
                "image": img_b64
            }
            collection.insert_one(doc)

    if client:
        client.close()

    return True
```

I hope this helps!

Can you optimize it to make it as fast and efficient as possible?

I have already provided several optimizations that can improve the performance and efficiency of the code. However, further optimization depends on the specific use case and requirements. Here are some general tips that could help optimize the code further:

Use a database connection pool: A connection pool allows you to reuse existing database connections instead of creating a new connection for each request. This can significantly reduce the overhead of creating a new connection every time you need to interact with the database.
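For pymongo in particular, MongoClient already maintains an internal connection pool, so reusing a single client is the main win; the pool size can also be tuned through standard connection-string options. A sketch with placeholder host values:

```python
# MongoDB connection strings accept pool options directly; maxPoolSize and
# minPoolSize are standard URI options. Host and port are placeholders.
mongo_host = "localhost"
mongo_port = 27017

uri = f"mongodb://{mongo_host}:{mongo_port}/?maxPoolSize=50&minPoolSize=5"
print(uri)  # mongodb://localhost:27017/?maxPoolSize=50&minPoolSize=5
```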

Use asynchronous programming: If you are using Python 3.7 or higher, you can take advantage of the asyncio library to write asynchronous code. This allows you to perform I/O operations in a non-blocking manner, which can improve the performance of your code.
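A minimal sketch of the asyncio idea, with a stub standing in for the blocking collection.insert_one call (a real setup would use an async driver such as motor, or run driver calls in worker threads as below; note that asyncio.to_thread requires Python 3.9+):

```python
import asyncio

def insert_one_stub(doc):
    # Stand-in for a blocking collection.insert_one(doc) call.
    return doc["seed"]

async def save_all(docs):
    # Run the blocking inserts in worker threads so they can overlap,
    # instead of executing strictly one after another.
    return await asyncio.gather(*(asyncio.to_thread(insert_one_stub, d) for d in docs))

docs = [{"seed": 1}, {"seed": 2}, {"seed": 3}]
print(asyncio.run(save_all(docs)))  # [1, 2, 3]
```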

Use a faster serialization format: Base64 encoding can be slow, especially for large files. Consider using a more efficient serialization format, such as MessagePack or Protocol Buffers.
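On the base64 point, the overhead is easy to quantify: base64 inflates the payload by one third before it even reaches the database (MongoDB can store raw bytes directly as BSON binary instead). A quick check:

```python
import base64

raw = bytes(range(256)) * 40   # 10,240 bytes of sample "image" data
encoded = base64.b64encode(raw)

print(len(raw), len(encoded))  # 10240 13656 -- about 4/3 the original size
```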

Optimize image compression: Depending on the use case, you may be able to further optimize the compression of the images. For example, if the images are mostly solid colors, you can use a lossless compression algorithm like PNG. If the images are more complex, you can use a lossy compression algorithm like JPEG.
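The solid-color-versus-complex distinction comes down to how well the pixel data compresses. PNG's lossless stage is zlib-based, so the effect can be sketched with the stdlib alone (these are not real PNG file sizes, just an illustration):

```python
import os
import zlib

solid = b"\x00" * 10_000    # stand-in for a solid-color image
noisy = os.urandom(10_000)  # stand-in for complex, noise-like content

print(len(zlib.compress(solid)))  # tiny: runs of identical bytes compress very well
print(len(zlib.compress(noisy)))  # roughly the input size: noise barely compresses
```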

Use a faster image library: If the images are large or complex, consider using a faster image processing library, such as OpenCV or Pillow-SIMD.

Cache frequently accessed data: If you frequently access the same data, consider caching it in memory to reduce the number of database queries. This can be especially helpful if the data changes infrequently.

Optimize database queries: Make sure your database queries are optimized and use appropriate indexes. Use explain queries to identify slow queries and optimize them.

These are just a few suggestions that can help optimize your code further. However, keep in mind that optimization is a tradeoff between performance and simplicity, and that the best approach depends on your specific requirements and use case.

takoyaro commented 1 year ago

I think the last sentence sums it up pretty well

> [...] keep in mind that optimization is a tradeoff between performance and simplicity [...]

I can't believe I haven't even thought of at least taking a look at what ChatGPT would have to offer. Thank you for your input, will see what's worth optimizing!

takoyaro commented 1 year ago

Optimized the code in https://github.com/takoyaro/db-storage1111/commit/8c706d2feaa5957f3d4a60e878fc0cc4a08e0f73