MaxGraey opened this issue 4 years ago
I had been thinking in terms of option 3 since I assumed the implementations of wasi-nn could use quite different formats, not just for encoding 16-bit floats but even for the model data itself. What are your thoughts?
I think it would be great to explicitly determine which format is used for interop, so perhaps add `$bf16` as an additional option. Making the meaning of the `f16` format depend on context could be tricky for user space, especially if execution targets are removed as suggested in #2. But always implicitly converting (normalizing) to `float16`, for example, is also unreasonable due to the unnecessary overhead. So it would be great if more experienced people participated in this discussion)
bfloat16 is also supported on CPUs, for example via AVX-512 BF16 on Intel (https://en.wikichip.org/wiki/x86/avx512_bf16). So in my opinion both floating-point formats need to be supported.
To support both formats and avoid implicit conversion (e.g., dependent on execution target), we could expand the tensor types to include both `$f16` and `$bf16`; currently that enum only contains `$f16`. How many other types does this apply to?
However, it seems some TPUs can efficiently convert f32 to bf16 (I guess via fraction truncation): https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus So in this case it's just a matter of bandwidth, not performance.
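Since bf16 shares f32's sign bit and 8-bit exponent, the fraction truncation mentioned above really is just keeping the top 16 bits of the f32 bit pattern. A minimal sketch (the helper names are mine, not a wasi-nn API):

```rust
// Convert f32 -> bf16 by fraction truncation: bf16 has the same sign
// and exponent layout as f32, so the high 16 bits of the f32 bit
// pattern are a valid (round-toward-zero) bf16 value.
fn f32_to_bf16_trunc(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}

// Widening back is lossless: shift the 16 bits into the high half.
fn bf16_to_f32(b: u16) -> f32 {
    f32::from_bits((b as u32) << 16)
}
```

The round trip keeps only 7 fraction bits, so the relative error is bounded by about 2^-8, which is why this is cheap bandwidth-wise but lossy precision-wise.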
This is fine for loading model weights, but there is currently no guest-side mechanism for generating input tensors of these types in WASM, since neither type is a valid primitive type in WASM or WIT.
One approach might be to require that all floating-point formats smaller than f32 be initialized from f32, then appropriately downsampled host-side based on tensor metadata and host support.
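Unlike the bf16 case, host-side downsampling from f32 to IEEE f16 also has to rebias the exponent. A simplified sketch of what that host-side step could look like (truncating rather than rounding, and flushing would-be subnormals to zero to keep it short; helper name is an assumption):

```rust
// Simplified f32 -> IEEE binary16 conversion for host-side downsampling.
// Handles normals, infinities, and NaN; flushes would-be f16 subnormals
// to signed zero and truncates the fraction (no rounding).
fn f32_to_f16_trunc(x: f32) -> u16 {
    let bits = x.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let frac = bits & 0x007F_FFFF;

    if exp == 0xFF {
        // Inf or NaN: preserve NaN-ness with a nonzero payload bit.
        let payload = if frac != 0 { 0x0200 } else { 0 };
        return sign | 0x7C00 | payload;
    }
    let e16 = exp - 127 + 15; // rebias: f32 bias 127 -> f16 bias 15
    if e16 >= 0x1F {
        return sign | 0x7C00; // overflow -> infinity (f16 max is 65504)
    }
    if e16 <= 0 {
        return sign; // underflow -> signed zero (subnormals dropped)
    }
    sign | ((e16 as u16) << 10) | ((frac >> 13) as u16)
}
```

A production host would round to nearest-even and emit subnormals, but even this sketch shows why the conversion belongs host-side: the guest only ever sees f32.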
That said, I have heard some rumors about Google pushing for bf16 and fp16 in core wasm, so I will follow up on that and see what I can find out.
Which 16-bit float format is preferable for the tensor type?

1) IEEE half-precision 16-bit float (5 bits exponent, 10 bits fraction)
2) bfloat16 (8 bits exponent, 7 bits fraction), supported by Google TPU
3) depends on `execution_target`, like `bfloat16` for TPU and `float16` for CPU / GPU?