WebAssembly / wasi-nn

Neural Network proposal for WASI

Question about half/float16 format for tensors #3

Open MaxGraey opened 4 years ago

MaxGraey commented 4 years ago

Which 16-bit float format is preferable for the tensor type?

1. IEEE half-precision 16-bit float (5-bit exponent, 10-bit fraction)
2. bfloat16 (8-bit exponent, 7-bit fraction), supported by Google TPUs
3. depends on execution_target, e.g. bfloat16 for TPU and float16 for CPU / GPU?
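To make the tradeoff concrete, here is a small sketch (in Rust with the `half` crate, purely for illustration; neither the language nor the crate is implied by the proposal) of how the two encodings handle the same f32 values:

```rust
// Cargo.toml: half = "2"  (illustrative; any f16/bf16 library would do)
use half::{bf16, f16};

fn main() {
    // Large magnitude: f16's 5-bit exponent tops out at 65504, while bf16 keeps f32's range.
    let big = 70_000.0_f32;
    println!("f16:  {}", f16::from_f32(big));  // overflows to inf
    println!("bf16: {}", bf16::from_f32(big)); // stays finite (~7.0e4)

    // Near 1.0: f16's 10-bit fraction keeps more precision than bf16's 7 bits.
    let fine = 1.001_f32;
    println!("f16:  {}", f16::from_f32(fine).to_f32());  // ~1.00098
    println!("bf16: {}", bf16::from_f32(fine).to_f32()); // ~1.0
}
```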

abrown commented 4 years ago

I had been thinking in terms of option 3 since I assumed the implementations of wasi-nn could use quite different formats, not just for encoding 16-bit floats but even for the model data itself. What are your thoughts?

MaxGraey commented 4 years ago

I think it would be great to explicitly determine which format is used for interop, so perhaps add $bf16 as an additional option. Making the meaning of f16 depend on context could be tricky for user space, especially if execution targets are removed as suggested in #2. On the other hand, always implicitly converting (normalizing) to float16, for example, is also unreasonable because of the unnecessary overhead. So it would be great if more experienced people joined this discussion)

mingqiusun commented 4 years ago

bfloat16 is also supported on CPUs, for example via AVX-512 BF16 on Intel (https://en.wikichip.org/wiki/x86/avx512_bf16). So both floating-point formats need to be supported, in my opinion.

abrown commented 4 years ago

To support both formats and avoid implicit format selection (e.g. dependent on the execution target), we could expand the tensor types to include both $f16 and $bf16; currently that enum only contains $f16. How many other types does this apply to?
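For reference, a rough sketch of what the expanded variant set could look like, written here as a Rust enum mirroring the witx tensor types; the $bf16 variant and the discriminant values are illustrative, not spec text:

```rust
/// Tensor element types, with a hypothetical bf16 addition.
/// Discriminant values are illustrative only.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum TensorType {
    F16 = 0,  // IEEE 754 binary16: 5-bit exponent, 10-bit fraction
    F32 = 1,
    U8 = 2,
    I32 = 3,
    BF16 = 4, // bfloat16: 8-bit exponent, 7-bit fraction
}
```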

MaxGraey commented 4 years ago

However, it seems some TPUs can efficiently convert f32 to bf16 (I guess via fraction truncation): https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus. So in that case it's just a matter of bandwidth, not performance.
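For what it's worth, the truncating direction of that conversion really is just dropping the low 16 bits of the f32 encoding (proper rounding has to look at the discarded bits first); a quick sketch of that guess:

```rust
/// Convert f32 to bf16 by truncating the low 16 bits of its encoding.
/// (Round-to-nearest-even would inspect the discarded bits before dropping them.)
fn f32_to_bf16_truncate(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}

/// Widening back to f32 is exact: re-append 16 zero bits.
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}
```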

geekbeast commented 11 months ago

This is fine for loading model weights, but there is currently no guest-side mechanism for generating input tensors of these types in WASM, since neither of them is a valid primitive type in WASM or WIT.

An approach might be to require that all floating-point formats smaller than f32 be initialized from f32 and appropriately downsampled host-side, based on tensor metadata and host support.
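As a rough host-side sketch of that idea (assuming the host learns the element type from the tensor metadata and uses something like the `half` crate for the narrowing; none of this is existing wasi-nn API):

```rust
use half::{bf16, f16};

// Element types as the host sees them (echoing the enum sketch above).
pub enum TensorType { F16, BF16, F32 }

/// Hypothetical helper: the guest always supplies f32 data, and the host
/// narrows it to the element type recorded in the tensor metadata before
/// handing the bytes to the backend.
pub fn downsample(data: &[f32], ty: TensorType) -> Vec<u8> {
    match ty {
        TensorType::F32 => data.iter().flat_map(|v| v.to_le_bytes()).collect(),
        TensorType::F16 => data.iter().flat_map(|v| f16::from_f32(*v).to_le_bytes()).collect(),
        TensorType::BF16 => data.iter().flat_map(|v| bf16::from_f32(*v).to_le_bytes()).collect(),
    }
}
```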

That said, I have heard some rumors about Google pushing for bf16 and fp16 in core wasm, so I will follow up on that and see what I can find out.