pixelspark closed this 2 years ago.
@haixuanTao I rebased this on top of your workspace PR; this one should be ready to go now.
Ok I think we can merge this branch.
Yes, agreed. We might improve the API in a later PR (e.g. with the builder pattern or something else).
This PR (based on my earlier PR #45, which implemented the DAG) adds support for data types other than `f32`. Right now I have implemented support for `i32` inputs, but others (such as `f16`) should be fairly easy to add. The current version always casts outputs to `Vec<f32>`, but of course the output type should depend on the actual output type in the model.

Notably, BERT and some other models appear to use `int64` inputs; however, working with `i64` in WGSL is not (yet?) possible. We should therefore probably re-interpret these as `int32`.

Some implementation notes:
- `Shape` now has a `data_type` field of type `ScalarType`. The latter holds some information about the type's size and striding.
- In `struct.wgsl`, I define a type `Scalar` to be the 'default' scalar type (all ops except those that do conversions work in a single type), based on the template variable `scalar_type`. It also defines `Vec3`, `Vec4`, `Mat4x4` etc., which are the appropriate vector/matrix types. All shaders use these type names instead of the raw ones.
- The `scalar_type` variable is set by looking at the types of the inputs/outputs. They should all agree, or else an error is thrown. The exception is ops whose inputs/outputs legitimately differ in type, e.g. `Reshape`, where the first input is the data and the second input is a shape with an integer data type.
- Input tensors are now passed as e.g. `InputTensor::F32(my_input.as_slice())`. Output tensors are currently still all of type `Vec<f32>` (when the actual output data type is different, the current version casts the elements to `f32` first).

See `tests/arithmetic.rs` for an example of integer addition.
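The notes on `ScalarType` and the `scalar_type` template variable can be illustrated with a small sketch. Only the name `ScalarType` comes from this PR; the variants, `stride`, `wgsl_type_name`, `TypeMismatch` and `infer_scalar_type` are hypothetical names chosen for illustration, not the actual code in this branch. The sketch shows a type descriptor that knows its element size and WGSL name, plus an "all types must agree" check.

```rust
/// Sketch of a scalar type descriptor. The name `ScalarType` comes from the PR;
/// the variants and methods below are assumptions made for illustration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarType {
    F32,
    I32,
}

impl ScalarType {
    /// Size of one element in bytes (both supported types are 32 bits wide).
    pub fn stride(&self) -> usize {
        match self {
            ScalarType::F32 | ScalarType::I32 => 4,
        }
    }

    /// Name of the type in WGSL, as it could be fed to the `scalar_type` template variable.
    pub fn wgsl_type_name(&self) -> &'static str {
        match self {
            ScalarType::F32 => "f32",
            ScalarType::I32 => "i32",
        }
    }
}

/// Hypothetical error returned when the tensors of an op disagree on their scalar type.
#[derive(Debug, PartialEq, Eq)]
pub struct TypeMismatch {
    pub expected: ScalarType,
    pub found: ScalarType,
}

/// Hypothetical helper: pick the single scalar type for an op from the types of the
/// tensors that carry its data. Returns `Ok(None)` for an empty list and an error as
/// soon as a type differs from the first one.
pub fn infer_scalar_type(types: &[ScalarType]) -> Result<Option<ScalarType>, TypeMismatch> {
    let (first, rest) = match types.split_first() {
        Some((first, rest)) => (*first, rest),
        None => return Ok(None),
    };
    for &t in rest {
        if t != first {
            return Err(TypeMismatch { expected: first, found: t });
        }
    }
    Ok(Some(first))
}
```

For an op like `Reshape`, a caller using this sketch would pass only the data input and the output, leaving out the integer shape input, so that the legitimate type difference does not trigger an error.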
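The input/output note (typed inputs, outputs always returned as `Vec<f32>`) can be sketched as follows. Only `InputTensor::F32(...)` appears in the PR; the `I32` variant, `OutputTensor` and `into_f32` are hypothetical and only illustrate the "cast everything to `f32` on the way out" behaviour described above.

```rust
/// Sketch of a borrowed, typed input tensor. `InputTensor::F32` is named in the PR;
/// the `I32` variant is an assumption matching the new i32 support.
#[derive(Debug, Clone, Copy)]
pub enum InputTensor<'a> {
    F32(&'a [f32]),
    I32(&'a [i32]),
}

impl<'a> InputTensor<'a> {
    /// Number of elements, regardless of the element type.
    pub fn len(&self) -> usize {
        match self {
            InputTensor::F32(data) => data.len(),
            InputTensor::I32(data) => data.len(),
        }
    }
}

/// Hypothetical output buffer as read back from the GPU, in the model's own type.
#[derive(Debug, PartialEq)]
pub enum OutputTensor {
    F32(Vec<f32>),
    I32(Vec<i32>),
}

impl OutputTensor {
    /// Mirrors the current behaviour: every output is cast element-wise to f32.
    /// Integers larger than 2^24 in magnitude are not all exactly representable in f32.
    pub fn into_f32(self) -> Vec<f32> {
        match self {
            OutputTensor::F32(data) => data,
            OutputTensor::I32(data) => data.into_iter().map(|x| x as f32).collect(),
        }
    }
}
```

The precision caveat in `into_f32` is one reason why, as the PR says, the output type should eventually follow the model's actual output type instead of always being `f32`.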
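The PR suggests re-interpreting `int64` model inputs (as used by BERT) as `int32`, because WGSL has no 64-bit integer type. The PR does not say how; one possible host-side approach, sketched here as an assumption, is a checked narrowing that refuses values which do not fit instead of silently truncating them. The function name `narrow_i64_to_i32` is hypothetical.

```rust
use std::convert::TryFrom;

/// Hypothetical helper: convert int64 tensor data (e.g. token ids) to i32 so it can be
/// uploaded to a WGSL shader. Returns the index and value of the first element that does
/// not fit in an i32, rather than truncating it.
pub fn narrow_i64_to_i32(values: &[i64]) -> Result<Vec<i32>, (usize, i64)> {
    values
        .iter()
        .enumerate()
        .map(|(i, &v)| i32::try_from(v).map_err(|_| (i, v)))
        .collect()
}
```

A bitwise reinterpretation (keeping only the low 32 bits) would avoid the check but could silently change large values; for inputs such as token ids, which normally fit in 32 bits, the checked variant makes any overflow visible.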