cumulo-autumn / StreamDiffusion

StreamDiffusion: A Pipeline-Level Solution for Real-Time Interactive Generation
Apache License 2.0

Loopback? #65

Open SMUsamaShah opened 8 months ago

SMUsamaShah commented 8 months ago

If I put the generated image in the capture area in the screen example, there is a noticeable gap of 3-4 frames between capture and output. How do I make sure the capture happens after the image is put on screen, so that it creates a loopback and iteratively improves the image?

frame_buffer_size is 1 already.

SMUsamaShah commented 8 months ago

Made a simple web server using the image-to-image example in the main README (https://github.com/cumulo-autumn/StreamDiffusion?tab=readme-ov-file#image-to-image). Sending the image via canvas is not fast, and all I get in return is a brown image.

from http.server import SimpleHTTPRequestHandler, HTTPServer
from PIL import Image
from io import BytesIO
import base64
import json

# Import the necessary classes and functions from the snippet
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline
from diffusers.utils import load_image
from streamdiffusion import StreamDiffusion
from streamdiffusion.image_utils import postprocess_image

class StreamDiffusionServer(SimpleHTTPRequestHandler):
    def do_GET(self):
        if self.path == '/':
            self.path = '/web/index.html'
        return super().do_GET()

    def do_POST(self):
        if self.path != '/process_image':
            self.send_response(404)
            return

        content_length = int(self.headers['Content-Length'])
        post_data = self.rfile.read(content_length)
        data = json.loads(post_data)

        img_data = base64.b64decode(data['image'].split(',')[1])
        image = Image.open(BytesIO(img_data)).convert('RGB')  # .convert('RGB') added here

        # Prepare the stream
        stream.prepare("1girl with dog hair, thick frame glasses")
        x_output = stream(image)
        proc_image = postprocess_image(x_output, output_type="pil")[0]

        buffer = BytesIO()
        proc_image.save(buffer, format="PNG")
        img_str = base64.b64encode(buffer.getvalue()).decode()

        self.send_response(200)
        self.send_header('Content-type', 'application/json')
        self.end_headers()
        self.wfile.write(json.dumps({'image': f'data:image/png;base64,{img_str}'}).encode())

if __name__ == '__main__':
    # Your model setup here
    pipe = StableDiffusionPipeline.from_pretrained("KBlueLeaf/kohaku-v2.1").to(
        device=torch.device("cuda"),
        dtype=torch.float16,
    )

    stream = StreamDiffusion(
        pipe,
        t_index_list=[32, 45],
        torch_dtype=torch.float16,
    )

    stream.load_lcm_lora()
    stream.fuse_lora()
    stream.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(device=pipe.device, dtype=pipe.dtype)
    pipe.enable_xformers_memory_efficient_attention()

    server_address = ('', 8000)
    httpd = HTTPServer(server_address, StreamDiffusionServer)
    print('running server...')
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        pass
    httpd.server_close()
    print('\nstopped server')

Client side view

<!DOCTYPE html>
<html>
  <body>
    <canvas id="myCanvas" width="512" height="512" style="border:1px solid #d3d3d3;">
      Your browser does not support the HTML5 canvas tag.
    </canvas>
    <button onclick="sendImage()">Process Image</button>

    <script>
      const canvas = document.getElementById("myCanvas");
      const context = canvas.getContext("2d");

      let isDrawing = false;

      canvas.addEventListener('mousedown', (event) => {
        isDrawing = true;
        drawLine(context, event.pageX - canvas.offsetLeft, event.pageY - canvas.offsetTop, false);
      });

      canvas.addEventListener('mousemove', (event) => {
        if (isDrawing) {
          drawLine(context, event.pageX - canvas.offsetLeft, event.pageY - canvas.offsetTop, true);
        }
      });

      canvas.addEventListener('mouseup', () => {
        isDrawing = false;
      });

      function drawLine(context, x, y, isDrawingLine) {
        if (isDrawingLine) {
          context.lineTo(x, y);
          context.stroke();
        } else {
          context.beginPath();
          context.moveTo(x, y);
        }
      };

      function sendImage() {
        const data = canvas.toDataURL();
        fetch("http://localhost:8000/process_image", {
            method: "POST",
            body: JSON.stringify({ image: data }),
            headers: { "Content-Type": "application/json" }
        })
        .then(res => res.json())
        .then(data => {
            const image = new Image();
            image.src = data.image;
            image.onload = function () {
              context.clearRect(0, 0, canvas.width, canvas.height);  // clear the current drawing
              context.drawImage(image, 0, 0, canvas.width, canvas.height); // draw the received image
            }
        });
      }
    </script>
  </body>
</html>

All I get back is this (attached image), and it is not as fast as the screen example or the other examples.

SMUsamaShah commented 8 months ago

I was calling stream.prepare() on every request instead of once.

Now the server returns the result of the previously sent input image instead of the current one. What am I doing wrong here? I made this web server to have more control over the input image, because I was seeing the same behaviour in the screen example.

It looks like the lag depends on the length of t_index_list.

cumulo-autumn commented 8 months ago

@SMUsamaShah Stream Batch creates a denoising batch for each number of denoising steps and processes them in batches, as shown in the figures below. This means an input image is output after n-1 frames, where n is the number of denoising steps. This is what gives Stream Batch its high throughput. Additionally, at the start of image generation the Stream Batch is filled with zero tensors, so the first n-1 generated images will be brown, as you observed. It is necessary to sweep all of these initial brown images out before getting the generated image.

The prepare() function also re-initializes all tensors in the Stream Batch to zero. Therefore, if you wish to update your prompt, it is recommended that you use the update_prompt() function instead of prepare().
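To illustrate the difference, here is a toy model of that behaviour (not the real StreamDiffusion API, just a sketch of the semantics described above): prepare() wipes the in-flight denoising batch back to zeros, while update_prompt() leaves it intact:

```python
class ToyStream:
    """Toy model of the behaviour described above (not the real API):
    prepare() resets the denoising batch to zeros, update_prompt() does not."""

    def __init__(self, n_steps):
        self.batch = [0] * (n_steps - 1)  # in-flight frames, zeros at start
        self.prompt = None

    def prepare(self, prompt):
        self.prompt = prompt
        self.batch = [0] * len(self.batch)  # wipes in-flight frames

    def update_prompt(self, prompt):
        self.prompt = prompt  # batch left untouched

s = ToyStream(n_steps=3)
s.prepare("initial prompt")
s.batch = [7, 8]                # pretend two frames are mid-denoise
s.update_prompt("new prompt")   # in-flight frames survive
```

Calling prepare() per request, as in the server above, restarts the batch from zeros every time, which is why only brown images came back.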

Given these properties, if you want to do a loopback, keep in mind that the output will be offset by n-1 frames. Additionally, setting use_denoising_batch to False allows image generation using the traditional method without Stream Batch, but this is not yet fully supported, so its functionality is not guaranteed.

[Figures: batch_denoising_concept, system_concept]
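The n-1 offset can be sketched as a FIFO of depth n-1 (a simplified model, not the actual implementation): the batch starts filled with zeros, so the first n-1 outputs are the brown frames, and each input only surfaces n-1 frames later:

```python
from collections import deque

def simulate_stream_batch(inputs, n_steps):
    """Model the Stream Batch pipeline as a FIFO of depth n-1.

    The batch starts filled with zeros ("brown" frames), so the
    output at frame i is the input from frame i - (n-1)."""
    pipeline = deque([0] * (n_steps - 1))  # zero tensors at start
    outputs = []
    for frame in inputs:
        pipeline.append(frame)             # new frame enters the batch
        outputs.append(pipeline.popleft()) # oldest frame finishes denoising
    return outputs

# With n=3 steps, the first 2 outputs are zeros, then inputs appear 2 frames late.
simulate_stream_batch([1, 2, 3, 4, 5], n_steps=3)  # → [0, 0, 1, 2, 3]
```

This also shows why stream.prepare() per request only ever returned brown images: each call refilled the FIFO with zeros before any real frame could reach the output.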

SMUsamaShah commented 8 months ago

Thanks for explaining. The diagram makes it even easier to understand what is going on. This means I can't do the loopback the way I imagined; doing it periodically, once every n images instead of on every single image, makes more sense.
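A rough sketch of what that periodic loopback costs (pure illustration; the hypothetical step_fn stands in for one round trip through the pipeline): each iteration must wait n-1 frames for its input to surface before the result can be fed back in:

```python
def loopback(start, n_steps, rounds, step_fn):
    """Each submitted image takes n-1 frames to come back out of the
    Stream Batch, so one loopback iteration costs n-1 frames, not one."""
    delay = n_steps - 1
    image, frames = start, 0
    for _ in range(rounds):
        frames += delay         # wait for the result to surface
        image = step_fn(image)  # then feed it back as the next input
    return image, frames

# 4 loopback rounds with n=3 denoising steps take 4 * (3-1) = 8 frames.
loopback(1, n_steps=3, rounds=4, step_fn=lambda x: x + 1)  # → (5, 8)
```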