tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

f16 with WGPU #597

Status: Open. Gadersd opened this issue 1 year ago.

Gadersd commented 1 year ago

Feature description

It would be great if burn-wgpu supported f16. Is there a timeline for this?

Feature motivation

Large models such as Stable Diffusion exceed wgpu's maximum buffer size when using f32. f16 uses half as many bytes per element, so f16 support would let some of these models run with the wgpu backend.
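Here is a rough sketch of why halving the element size helps. It computes a tensor's buffer size in f32 and in f16 and compares each against a binding-size limit. Both the 128 MiB limit and the 6144 x 6144 shape are assumptions chosen for illustration. They are not values taken from Stable Diffusion or read from a particular GPU, so check `Limits` in the wgpu version you use.

```rust
/// Illustrative storage-buffer binding limit (128 MiB). This is an assumption,
/// not a value queried from any particular GPU or wgpu version.
const ASSUMED_BINDING_LIMIT: u64 = 128 << 20;

/// Bytes needed to store a dense tensor of `shape` with `bytes_per_elem` bytes per element.
fn tensor_bytes(shape: &[u64], bytes_per_elem: u64) -> u64 {
    shape.iter().product::<u64>() * bytes_per_elem
}

/// Whether a buffer of `bytes` fits under the assumed limit.
fn fits(bytes: u64) -> bool {
    bytes <= ASSUMED_BINDING_LIMIT
}

fn main() {
    // Hypothetical weight matrix shape, chosen only to straddle the limit.
    let shape: [u64; 2] = [6144, 6144];
    let f32_bytes = tensor_bytes(&shape, 4); // 4 bytes per f32 element
    let f16_bytes = tensor_bytes(&shape, 2); // 2 bytes per f16 element
    println!("f32: {} MiB, fits: {}", f32_bytes >> 20, fits(f32_bytes));
    println!("f16: {} MiB, fits: {}", f16_bytes >> 20, fits(f16_bytes));
}
```

With these numbers the f32 buffer needs 144 MiB and is rejected, while the f16 buffer needs 72 MiB and fits.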

nathanielsimard commented 1 year ago

Linked to https://github.com/gfx-rs/wgpu/issues/4384

Gadersd commented 1 year ago

Does anyone know whether wgpu's buffer size limits will eventually be relaxed? Even if f16 gets supported, the buffer size limits will still be a barrier to running large models.

nathanielsimard commented 1 year ago

You can manually override the limits when selecting the device: https://github.com/burn-rs/burn/blob/ed255c5561b85876cf02cbc4d48f35e1f0d29ac0/burn-wgpu/src/context/base.rs#L228

I think the limits are low for compatibility reasons, but on my RTX 3070 I can increase max_storage_buffer_binding_size to usize::pow(8, 10) (8^10 bytes, i.e. 1 GiB) and load bigger tensors. I think we should come up with a way to change the limits for specific devices, probably with a config file, environment variables, or both.
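The "config file or env variables (or both)" idea could be sketched as a simple precedence rule: environment variable first, then config file, then default. The variable name `BURN_WGPU_MAX_STORAGE_BUFFER_BINDING_SIZE` and the precedence order are hypothetical, made up for this illustration, and burn does not necessarily read such a variable. The value is a `u32` because wgpu declares `max_storage_buffer_binding_size` with that type.

```rust
use std::env;

/// Hypothetical environment variable name, for illustration only.
const ENV_VAR: &str = "BURN_WGPU_MAX_STORAGE_BUFFER_BINDING_SIZE";

/// Resolve a limit with the (assumed) precedence: env var > config file > default.
/// An unparseable env value falls through to the next source instead of panicking.
fn resolve_limit(env_value: Option<&str>, config_value: Option<u32>, default: u32) -> u32 {
    env_value
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .or(config_value)
        .unwrap_or(default)
}

fn main() {
    // The value from the comment above: 8^10 bytes = 2^30 bytes = 1 GiB.
    let one_gib = 8u32.pow(10);
    let from_env = env::var(ENV_VAR).ok();
    // Pretend the config file asked for 1 GiB; 128 MiB is an assumed default.
    let limit = resolve_limit(from_env.as_deref(), Some(one_gib), 128 << 20);
    println!("max_storage_buffer_binding_size = {limit} bytes");
}
```

Separating parsing from `env::var` keeps the precedence logic testable without touching the process environment.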

Gadersd commented 1 year ago

Would it be reasonable to use the adapter's `pub fn limits(&self) -> Limits` method to get the best limits the adapter offers instead of relying on the defaults? I think this would resolve the issue.
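Here is a minimal sketch of the selection policy this suggests. Start from whatever the adapter reports, and let an optional user override lower a limit but never push it above the adapter's maximum, because a device request cannot exceed what the adapter supports. `DeviceLimits` is a hypothetical stand-in for the two relevant fields of `wgpu::Limits`, so the sketch runs without a GPU. In real code the starting point would be the value returned by `adapter.limits()`.

```rust
/// Hypothetical stand-in for two fields of `wgpu::Limits`
/// (wgpu uses `u64` for `max_buffer_size` and `u32` for the binding size).
#[derive(Debug, Clone, Copy, PartialEq)]
struct DeviceLimits {
    max_buffer_size: u64,
    max_storage_buffer_binding_size: u32,
}

/// Use the adapter-reported limits; an optional user override may lower the
/// binding size but is clamped so it never exceeds the adapter's maximum.
fn choose_limits(adapter: DeviceLimits, user_binding_size: Option<u32>) -> DeviceLimits {
    let binding = user_binding_size
        .map(|requested| requested.min(adapter.max_storage_buffer_binding_size))
        .unwrap_or(adapter.max_storage_buffer_binding_size);
    DeviceLimits { max_storage_buffer_binding_size: binding, ..adapter }
}

fn main() {
    // Hypothetical adapter-reported values, not measured on real hardware.
    let adapter = DeviceLimits {
        max_buffer_size: 1 << 32,
        max_storage_buffer_binding_size: 1 << 30,
    };
    println!("{:?}", choose_limits(adapter, None)); // adapter maximum
    println!("{:?}", choose_limits(adapter, Some(128 << 20))); // user lowers it
    println!("{:?}", choose_limits(adapter, Some(u32::MAX))); // clamped down
}
```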

nathanielsimard commented 1 year ago

Oh yes, I didn't know about that. I'll make a PR soon.

nathanielsimard commented 1 year ago

@Gadersd The PR is done (https://github.com/burn-rs/burn/pull/601). Let me know if it helps with running your models.

Gadersd commented 1 year ago

My bad, I accidentally tested with tch.

Gadersd commented 1 year ago

I get the following panic when trying to run Stable Diffusion:

`thread panicked at 'Error in Queue::submit: Validation Error`

`Caused by: Parent device is lost ', /home/hermes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/wgpu-0.17.0/src/backend/direct.rs:2289:30`

`note: run with RUST_BACKTRACE=1 environment variable to display a backtrace`

`thread 'main' panicked at 'Unable to read buffer', /home/hermes/.cargo/git/checkouts/burn-acfbee6a141c1b41/22ab534/burn-wgpu/src/context/client.rs:120:17`

nathanielsimard commented 1 year ago

This can happen when the GPU runs out of memory. You can try lowering MAX_TASKS to 1 to reduce memory usage:

https://github.com/burn-rs/burn/blob/9361193b5d62065807fdb6721e95dca8bcf8bf74/burn-wgpu/src/context/server.rs#L65

It might increase computation time, but the overhead is probably negligible for a big model. Once again, this is a value I'm not sure how we should set 😅.
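The memory/time trade-off can be illustrated with a toy model. It assumes that MAX_TASKS caps how many compute tasks are recorded before the server submits them, and that each pending task keeps some intermediate memory alive until its batch is submitted. This is a hypothetical model written for illustration, not burn-wgpu's actual server code. `Task`, `BatchingQueue`, and `simulate` are invented names.

```rust
/// Hypothetical queued GPU task; `bytes` models the memory it keeps alive
/// until its batch is submitted.
struct Task {
    bytes: u64,
}

/// Toy model (not burn-wgpu's code): tasks accumulate until `max_tasks`
/// are pending, then the whole batch is "submitted" and its memory freed.
struct BatchingQueue {
    max_tasks: usize,
    pending: Vec<Task>,
    submissions: usize,
    peak_pending_bytes: u64,
}

impl BatchingQueue {
    fn new(max_tasks: usize) -> Self {
        assert!(max_tasks >= 1, "max_tasks must be at least 1");
        Self { max_tasks, pending: Vec::new(), submissions: 0, peak_pending_bytes: 0 }
    }

    fn register(&mut self, task: Task) {
        self.pending.push(task);
        // Track the most memory held by pending tasks at any point.
        let pending_bytes: u64 = self.pending.iter().map(|t| t.bytes).sum();
        self.peak_pending_bytes = self.peak_pending_bytes.max(pending_bytes);
        if self.pending.len() >= self.max_tasks {
            self.submit();
        }
    }

    fn submit(&mut self) {
        if !self.pending.is_empty() {
            self.pending.clear(); // batch done, its memory is released
            self.submissions += 1;
        }
    }
}

/// Returns (number of submissions, peak pending bytes).
fn simulate(max_tasks: usize, n_tasks: usize, bytes_per_task: u64) -> (usize, u64) {
    let mut queue = BatchingQueue::new(max_tasks);
    for _ in 0..n_tasks {
        queue.register(Task { bytes: bytes_per_task });
    }
    queue.submit(); // flush whatever is left at the end
    (queue.submissions, queue.peak_pending_bytes)
}

fn main() {
    for max_tasks in [1, 4, 16] {
        let (subs, peak) = simulate(max_tasks, 10, 100 << 20);
        println!("max_tasks={max_tasks}: {subs} submissions, peak pending {} MiB", peak >> 20);
    }
}
```

In this model, a smaller cap lowers peak memory but multiplies the number of submissions. That matches the trade-off described in the comment above.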

Gadersd commented 1 year ago

Setting MAX_TASKS to 1 made inference work, but it was very slow compared to the tch run: ~5 minutes for one image with wgpu vs. ~15 seconds for two images with tch. Perhaps the value should be configurable by the user when the default isn't viable?
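Normalizing these approximate timings per image, and taking ~5 minutes as 300 s, gives a rough sense of the gap:

$$
\frac{300\ \text{s} / 1\ \text{image}}{15\ \text{s} / 2\ \text{images}} = \frac{300\ \text{s}}{7.5\ \text{s}} = 40
$$

So the wgpu run was roughly 40 times slower per image in this one report.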

nathanielsimard commented 1 year ago

Yes, we could do that for now. There is an issue about optimizing the memory strategy: https://github.com/burn-rs/burn/issues/582.

nathanielsimard commented 1 year ago

@Gadersd I added a way to configure MAX_TASKS: https://github.com/burn-rs/burn/pull/603.

antimora commented 11 months ago

There is a tweet saying "float16 in webGPU finally works now"

https://twitter.com/nisten/status/1698796718840598850

It's worth looking into this to see if we need to update anything.

antimora commented 8 months ago

https://github.com/gfx-rs/wgpu/issues/4384