Anime4KWebBoost / Anime4K-WebGPU

Implementation of Anime4K in WebGPU
MIT License

Could you provide a complete usage example? #15

Open AlmightyHD opened 1 month ago

AlmightyHD commented 1 month ago

Hi, first of all I'd like to thank you for creating this amazing project.

Right now, I'm trying to integrate the shaders into my web player app, but I'm confused about how to actually use them. I have a lot of experience with web development but have never used WebGPU, so unfortunately the "Usage" section of the README doesn't mean much to me. I've also looked into this example, but it's pretty complicated too.

So let's say I have an HTML video element and a canvas element, and I want to draw the video's frames onto the canvas, upscaled with a set of shaders. How do I do that? E.g. let's say I would like to apply the "upscale CNN (x2)", "upscale GAN (x3 & x4)", and "Real-ESRGAN" effects to the video. Could you provide a minimal example?

plasmas commented 1 month ago

Thanks for asking! This project was initially intended to provide low-level APIs that can be integrated into existing WebGPU pipelines, so we haven't provided any high-level abstracted components so far. Be aware that WebGPU itself is low-level, so it may require some extra effort to understand; if you're interested, this is a good tutorial. Also, our demo web page was adapted from this sample, which just renders a video to a canvas.

For better understanding, I created the following two examples in React:

* `VideoCanvas.tsx`: Rendering the video to the canvas directly, without Anime4K.
```tsx
import React, { useRef, useEffect } from 'react';

const fullscreenTexturedQuadWGSL = `
struct VertexOutput {
  @builtin(position) Position : vec4<f32>,
  @location(0) fragUV : vec2<f32>,
}

@vertex
fn vert_main(@builtin(vertex_index) VertexIndex : u32) -> VertexOutput {
  const pos = array(
    vec2( 1.0,  1.0),
    vec2( 1.0, -1.0),
    vec2(-1.0, -1.0),
    vec2( 1.0,  1.0),
    vec2(-1.0, -1.0),
    vec2(-1.0,  1.0),
  );

  const uv = array(
    vec2(1.0, 0.0),
    vec2(1.0, 1.0),
    vec2(0.0, 1.0),
    vec2(1.0, 0.0),
    vec2(0.0, 1.0),
    vec2(0.0, 0.0),
  );

  var output : VertexOutput;
  output.Position = vec4(pos[VertexIndex], 0.0, 1.0);
  output.fragUV = uv[VertexIndex];
  return output;
}
`;
const sampleExternalTextureWGSL = `
@group(0) @binding(1) var mySampler: sampler;
@group(0) @binding(2) var myTexture: texture_2d<f32>;

@fragment
fn main(@location(0) fragUV : vec2f) -> @location(0) vec4f {
  return textureSampleBaseClampToEdge(myTexture, mySampler, fragUV);
}
`;

const VideoCanvas: React.FC = () => {
const canvasRef = useRef<HTMLCanvasElement>(null);
const videoRef = useRef<HTMLVideoElement>(null);

useEffect(() => {
async function init() {
  // create video element
  const video = videoRef.current!;
  video.loop = true;
  video.autoplay = true;
  video.muted = true;
  video.src = '/OnePunchMan.mp4';
  await new Promise((resolve) => {
    video.onloadeddata = resolve;
  });
  await video.play();
  const WIDTH = video.videoWidth;
  const HEIGHT = video.videoHeight;

  // configure WebGPU
  const canvas = canvasRef.current!;
  const adapter = await navigator.gpu.requestAdapter();
  const device = await adapter!.requestDevice();
  const context = canvas.getContext('webgpu') as GPUCanvasContext;
  const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
  context.configure({
    device,
    format: presentationFormat,
    alphaMode: 'premultiplied',
  });

  // create texture of video
  const videoFrameTexture = device.createTexture({
    size: [WIDTH, HEIGHT, 1],
    format: 'rgba16float',
    usage: GPUTextureUsage.TEXTURE_BINDING
    | GPUTextureUsage.COPY_DST
    | GPUTextureUsage.RENDER_ATTACHMENT,
  });

  canvas.width = WIDTH;
  canvas.height = HEIGHT;

  // function to copy a new video frame into the texture
  function updateVideoFrameTexture() {
    device.queue.copyExternalImageToTexture(
      { source: video },
      { texture: videoFrameTexture },
      [WIDTH, HEIGHT],
    );
  }

  // render pipeline setups
  const renderBindGroupLayout = device.createBindGroupLayout({
    label: 'Render Bind Group Layout',
    entries: [
      {
        binding: 1,
        visibility: GPUShaderStage.FRAGMENT,
        sampler: {},
      },
      {
        binding: 2,
        visibility: GPUShaderStage.FRAGMENT,
        texture: {},
      }
    ],
  });

  const renderPipelineLayout = device.createPipelineLayout({
    label: 'Render Pipeline Layout',
    bindGroupLayouts: [renderBindGroupLayout],
  });

  const renderPipeline = device.createRenderPipeline({
    layout: renderPipelineLayout,
    vertex: {
      module: device.createShaderModule({
        code: fullscreenTexturedQuadWGSL,
      }),
      entryPoint: 'vert_main',
    },
    fragment: {
      module: device.createShaderModule({
        code: sampleExternalTextureWGSL,
      }),
      entryPoint: 'main',
      targets: [
        {
          format: presentationFormat,
        },
      ],
    },
    primitive: {
      topology: 'triangle-list',
    },
  });

  const sampler = device.createSampler({
    magFilter: 'linear',
    minFilter: 'linear',
  });

  const renderBindGroup = device.createBindGroup({
    layout: renderBindGroupLayout,
    entries: [
      {
        binding: 1,
        resource: sampler,
      },
      {
        binding: 2,
        resource: videoFrameTexture.createView(),
      }
    ],
  });

  // render loop
  function frame() {
    if (!video.paused) {
      updateVideoFrameTexture();
    }
    const commandEncoder = device.createCommandEncoder();
    const passEncoder = commandEncoder.beginRenderPass({
      colorAttachments: [
        {
          view: context.getCurrentTexture().createView(),
          clearValue: {
            r: 0.0, g: 0.0, b: 0.0, a: 1.0,
          },
          loadOp: 'clear' as GPULoadOp,
          storeOp: 'store' as GPUStoreOp,
        },
      ],
    });
    passEncoder.setPipeline(renderPipeline);
    passEncoder.setBindGroup(0, renderBindGroup);
    passEncoder.draw(6);
    passEncoder.end();
    device.queue.submit([commandEncoder.finish()]);
    video.requestVideoFrameCallback(frame);
  }

  // start render loop
  video.requestVideoFrameCallback(frame);
}

init();

}, [])

return (
  <div>
    <video hidden ref={videoRef} />
    <canvas
      ref={canvasRef}
      width="640"
      height="360"
      style={{ border: '1px solid black' }}
    />
  </div>
);
};

export default VideoCanvas;
```


* `VideoCanvasA4K.tsx`: Using CNNx2 UL for upscale and then CNN UL for restore.
```tsx
import React, { useRef, useEffect } from 'react';
import { CNNUL, CNNx2UL } from 'anime4k-webgpu';

const fullscreenTexturedQuadWGSL = `
struct VertexOutput {
  @builtin(position) Position : vec4<f32>,
  @location(0) fragUV : vec2<f32>,
}

@vertex
fn vert_main(@builtin(vertex_index) VertexIndex : u32) -> VertexOutput {
  const pos = array(
    vec2( 1.0,  1.0),
    vec2( 1.0, -1.0),
    vec2(-1.0, -1.0),
    vec2( 1.0,  1.0),
    vec2(-1.0, -1.0),
    vec2(-1.0,  1.0),
  );

  const uv = array(
    vec2(1.0, 0.0),
    vec2(1.0, 1.0),
    vec2(0.0, 1.0),
    vec2(1.0, 0.0),
    vec2(0.0, 1.0),
    vec2(0.0, 0.0),
  );

  var output : VertexOutput;
  output.Position = vec4(pos[VertexIndex], 0.0, 1.0);
  output.fragUV = uv[VertexIndex];
  return output;
}
`;
const sampleExternalTextureWGSL = `
@group(0) @binding(1) var mySampler: sampler;
@group(0) @binding(2) var myTexture: texture_2d<f32>;

@fragment
fn main(@location(0) fragUV : vec2f) -> @location(0) vec4f {
  return textureSampleBaseClampToEdge(myTexture, mySampler, fragUV);
}
`;

const VideoCanvasA4K: React.FC = () => {
  const canvasRef = useRef<HTMLCanvasElement>(null);
  const videoRef = useRef<HTMLVideoElement>(null);

  useEffect(() => {
    async function init() {
      // create video element
      const video = videoRef.current!;
      video.loop = true;
      video.autoplay = true;
      video.muted = true;
      video.src = '/OnePunchMan.mp4';
      await new Promise((resolve) => {
        video.onloadeddata = resolve;
      });
      await video.play();
      const WIDTH = video.videoWidth;
      const HEIGHT = video.videoHeight;

      // configure WebGPU
      const canvas = canvasRef.current!;
      const adapter = await navigator.gpu.requestAdapter();
      const device = await adapter!.requestDevice();
      const context = canvas.getContext('webgpu') as GPUCanvasContext;
      const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
      context.configure({
        device,
        format: presentationFormat,
        alphaMode: 'premultiplied',
      });

      // create texture of video
      const videoFrameTexture = device.createTexture({
        size: [WIDTH, HEIGHT, 1],
        format: 'rgba16float',
        usage: GPUTextureUsage.TEXTURE_BINDING
        | GPUTextureUsage.COPY_DST
        | GPUTextureUsage.RENDER_ATTACHMENT,
      });

      // ++++ Anime4K ++++
      const upscalePipeline = new CNNx2UL(device, videoFrameTexture);
      const restorePipeline = new CNNUL(device, upscalePipeline.getOutputTexture());
      canvas.width = restorePipeline.getOutputTexture().width;
      canvas.height = restorePipeline.getOutputTexture().height;
      // ++++ Anime4K ++++

      // function to copy a new video frame into the texture
      function updateVideoFrameTexture() {
        device.queue.copyExternalImageToTexture(
          { source: video },
          { texture: videoFrameTexture },
          [WIDTH, HEIGHT],
        );
      }

      // render pipeline setups
      const renderBindGroupLayout = device.createBindGroupLayout({
        label: 'Render Bind Group Layout',
        entries: [
          {
            binding: 1,
            visibility: GPUShaderStage.FRAGMENT,
            sampler: {},
          },
          {
            binding: 2,
            visibility: GPUShaderStage.FRAGMENT,
            texture: {},
          }
        ],
      });

      const renderPipelineLayout = device.createPipelineLayout({
        label: 'Render Pipeline Layout',
        bindGroupLayouts: [renderBindGroupLayout],
      });

      const renderPipeline = device.createRenderPipeline({
        layout: renderPipelineLayout,
        vertex: {
          module: device.createShaderModule({
            code: fullscreenTexturedQuadWGSL,
          }),
          entryPoint: 'vert_main',
        },
        fragment: {
          module: device.createShaderModule({
            code: sampleExternalTextureWGSL,
          }),
          entryPoint: 'main',
          targets: [
            {
              format: presentationFormat,
            },
          ],
        },
        primitive: {
          topology: 'triangle-list',
        },
      });

      const sampler = device.createSampler({
        magFilter: 'linear',
        minFilter: 'linear',
      });

      const renderBindGroup = device.createBindGroup({
        layout: renderBindGroupLayout,
        entries: [
          {
            binding: 1,
            resource: sampler,
          },
          {
            binding: 2,
            // +++ Anime4K +++
            resource: restorePipeline.getOutputTexture().createView(),
            // +++ Anime4K +++
          }
        ],
      });

      // render loop
      function frame() {
        if (!video.paused) {
          updateVideoFrameTexture();
        }
        const commandEncoder = device.createCommandEncoder();
        // +++ Anime4K +++
        upscalePipeline.pass(commandEncoder);
        restorePipeline.pass(commandEncoder);
        // +++ Anime4K +++
        const passEncoder = commandEncoder.beginRenderPass({
          colorAttachments: [
            {
              view: context.getCurrentTexture().createView(),
              clearValue: {
                r: 0.0, g: 0.0, b: 0.0, a: 1.0,
              },
              loadOp: 'clear' as GPULoadOp,
              storeOp: 'store' as GPUStoreOp,
            },
          ],
        });
        passEncoder.setPipeline(renderPipeline);
        passEncoder.setBindGroup(0, renderBindGroup);
        passEncoder.draw(6);
        passEncoder.end();
        device.queue.submit([commandEncoder.finish()]);
        video.requestVideoFrameCallback(frame);
      }

      // start render loop
      video.requestVideoFrameCallback(frame);
    }

    init();
  }, [])

  return (
    <div>
      <video hidden ref={videoRef} />
      <canvas
        ref={canvasRef}
        width="640"
        height="360"
        style={{ border: '1px solid black' }}
      />
    </div>
  );
};

export default VideoCanvasA4K;
```

Note that some pipelines, like GANx4UUL, are extremely GPU-intensive and may cause your browser to become unresponsive. You may need to experiment with different pipeline combinations.
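
For example, swapping in a single GAN-based pipeline instead of the CNNx2UL + CNNUL chain only changes the pipeline construction and the texture bound in the render bind group. The sketch below is illustrative only: it just reuses the constructor / `getOutputTexture()` / `pass()` pattern from the examples above, so check the package exports for the exact class names and arguments before relying on it.

```tsx
// Illustrative only: swaps the CNNx2UL + CNNUL chain in VideoCanvasA4K.tsx
// for a single (much heavier) GANx4UUL pipeline, following the same
// constructor and getOutputTexture() pattern shown above. Verify the
// exported class name and constructor arguments against anime4k-webgpu.
import { GANx4UUL } from 'anime4k-webgpu';

// ++++ Anime4K ++++
const upscalePipeline = new GANx4UUL(device, videoFrameTexture);
canvas.width = upscalePipeline.getOutputTexture().width;
canvas.height = upscalePipeline.getOutputTexture().height;
// ++++ Anime4K ++++

// In the render bind group, bind upscalePipeline.getOutputTexture().createView(),
// and in the frame loop run only this pipeline before the render pass:
//   upscalePipeline.pass(commandEncoder);
```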

plasmas commented 3 weeks ago

Examples have been added here. A more straightforward render function is now provided to render a video to a canvas without building the naive render pipeline yourself. See the README for more usage details.
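
For context, a usage sketch of that helper might look roughly like the following. The option names (`video`, `canvas`, `pipelineBuilder`) and the builder callback signature are assumptions made for illustration; see the README for the actual `render` signature.

```tsx
import { render, CNNx2UL, CNNUL } from 'anime4k-webgpu';

// Hypothetical sketch of the higher-level render helper mentioned above.
// The option names (video, canvas, pipelineBuilder) and callback signature
// are assumptions for illustration; consult the README for the real API.
async function playUpscaled(video: HTMLVideoElement, canvas: HTMLCanvasElement) {
  await render({
    video,
    canvas,
    // build the same upscale-then-restore chain used in VideoCanvasA4K.tsx
    pipelineBuilder: (device: GPUDevice, inputTexture: GPUTexture) => {
      const upscale = new CNNx2UL(device, inputTexture);
      const restore = new CNNUL(device, upscale.getOutputTexture());
      return [upscale, restore];
    },
  });
}
```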