webonnx / wonnx

A WebGPU-accelerated ONNX inference run-time written 100% in Rust, ready for native and the web

Feature/batch normalization + renaming all around #29

Closed: pixelspark closed this pull request 2 years ago

pixelspark commented 2 years ago

OK, this seems to work more or less. Batch normalization now works on input tensors of shape NxCxWxH (N and C are optional; if omitted they are 1), where the input statistics (mean, variance, bias, scale) are provided per channel C, i.e. as tensors of shape [C].
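
For reference, here is a minimal CPU sketch in Rust of what is being computed per channel; the function name and signature are illustrative, not the actual wonnx code:

// Reference batch normalization over a flattened NxCxWxH tensor (row-major).
// Each channel ci uses its own mean, variance, scale and bias.
fn batch_norm_reference(
    x: &[f32],
    (n, c, w, h): (usize, usize, usize, usize),
    mean: &[f32],
    variance: &[f32],
    scale: &[f32],
    bias: &[f32],
    epsilon: f32,
) -> Vec<f32> {
    let image = w * h;
    let mut y = vec![0.0; x.len()];
    for ni in 0..n {
        for ci in 0..c {
            let base = (ni * c + ci) * image;
            for i in 0..image {
                y[base + i] =
                    scale[ci] * (x[base + i] - mean[ci]) / (variance[ci] + epsilon).sqrt()
                        + bias[ci];
            }
        }
    }
    y
}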

My initial attempt used an ArrayVector (i.e. vec4<f32>) as the element type, which only works when W*H is a multiple of 4. I struggled a bit with getting the alignment right for cases where W*H is not divisible by 4 (my test uses 1x2x5x5 for instance). I came up with the following solution:

I have not seen the 'dynamic type' thing anywhere in your other shaders @haixuanTao, but it seems to me this is quite an elegant solution which makes it easier to write shaders supporting different shapes. Curious to hear what you think about this.

I still have to manually check the outcomes of the batchnormalization test, although they seem to look correct now.

haixuanTao commented 2 years ago

I am not sure what you mean by dynamic types?

pixelspark commented 2 years ago

I am not sure what you mean by dynamic types?

In batchnormalization.wgsl, the type of the input vector is defined dynamically:

[[block]]
struct Block {
    data: [[stride({{elem_stride}})]] array<{{ elem_type }}>;
}; 

// X (input)
[[group(0), binding(0)]]
var<storage, read> {{ inputs[0] }}: Block;

This allows the code to select vec3 when the input size is divisible by 3, vec4 when it is divisible by 4, etc. This was necessary to support e.g. batch normalization on 33x33 inputs. I am not sure if you have encountered this type of issue before, or know of a better way of dealing with it, but this seemed like a neat solution because it also allows the x,y,z input coordinates to refer to the position in the 'image', the channel and the batch index respectively.
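
To make that last point concrete, the index arithmetic boils down to something like the following Rust sketch (function and parameter names are illustrative, not the actual shader code):

// Map a 3D invocation id to the flat chunk index in the input buffer, where each
// chunk holds chunk_size consecutive f32 values (e.g. 4 for vec4<f32>).
// x = chunk position within one WxH image, y = channel, z = batch index.
fn chunk_index(x: u64, y: u64, z: u64, channels: u64, image_size: u64, chunk_size: u64) -> u64 {
    let chunks_per_image = image_size / chunk_size;
    (z * channels + y) * chunks_per_image + x
}

The per-channel statistics are then simply indexed by the y coordinate.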

pixelspark commented 2 years ago

I just added a test (and hand-checked the output against what I expect it to be, see code) so this is ready to check and merge as far as I am concerned!

haixuanTao commented 2 years ago

OK. I think there is an issue with git: the diff shows some changes that should already have been integrated. Would it be possible to do some kind of reset of the commits, or a rebase? That might be difficult though.

pixelspark commented 2 years ago

OK, so now this thing is squashed to a single commit ready to merge.

Note that previously all CI checks passed except for Linux/amd64. This was because one floating point value differed (by 0.0001). I have now added an epsilon-based comparison which effectively only compares the first two digits after the decimal point, to prevent this sort of issue.
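
The comparison amounts to something like the following Rust sketch (an absolute tolerance; the actual helper in the test code may look different):

// Compare two output slices with an absolute tolerance, so tiny
// backend-dependent floating point differences do not fail the tests.
fn assert_close(actual: &[f32], expected: &[f32], epsilon: f32) {
    assert_eq!(actual.len(), expected.len());
    for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
        assert!(
            (a - e).abs() <= epsilon,
            "element {}: got {}, expected {} (tolerance {})",
            i, a, e, epsilon
        );
    }
}

// For example: assert_close(&output, &expected, 0.01);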

haixuanTao commented 2 years ago

In batchnormalization.wgsl, the type of the input vector is defined dynamically [...] it also allows the x,y,z input coordinates to refer to the position in the 'image', the channel and the batch index respectively.

Yeah, I see.

In principle, I don't have any problem with this solution.

But take into account that the array stride for vec3 is 16 (https://www.w3.org/TR/WGSL/#alignment-and-size), so you have to add padding to get the intended result. I did this for the 3x3 convolution kernels where I use vec3: https://github.com/haixuanTao/wonnx/blob/master/templates/pool/conv_kernel_3.wgsl, see https://github.com/haixuanTao/wonnx/blob/656398f229fc44245ff0644c1f89fa27ac06ae41/src/sequencer.rs#L59

I think I would probably have stuck with vec4, with padding for the last vec4 handled within the shader.
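
For illustration, on the host side that padding boils down to something like this Rust sketch (the function name is made up; the actual wonnx code is in the sequencer.rs link above):

// Pad a flat f32 buffer so every 3 values occupy 4 slots, matching the
// 16-byte array stride WGSL requires for vec3<f32> elements.
fn pad_for_vec3(data: &[f32]) -> Vec<f32> {
    let mut padded = Vec::with_capacity((data.len() + 2) / 3 * 4);
    for chunk in data.chunks(3) {
        padded.extend_from_slice(chunk);
        // Zero-fill the unused fourth component (and any short final chunk).
        padded.resize(padded.len() + (4 - chunk.len()), 0.0);
    }
    padded
}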

pixelspark commented 2 years ago

@haixuanTao I just implemented your solution and I am afraid it will not work here without using 'Array' (f32) instead of 'ArrayVector' (vec4), which requires increasing the number of threads (to N*C*W*H).

The reason is as follows: if your input is 1x2x2x2, you have 2 'images' of 2x2. Let's say [[1,2,3,4],[5,6,7,8]]. For batch normalization you then have a vector of 2 means, one for each channel (say a and b). This means the following should be the output:

[[1*a, 2*a, 3*a, 4*a], [5*b, 6*b, 7*b, 8*b]]

This example works if you use vec4 because then the first four elements get the mean a (input[3][channel]). Now if your image is 3x3, things become problematic if you use vec4:

[[1*a, 2*a, 3*a, 4*a, 5*a, 6*a, 7*a, 8*a, 9*a], [10*a, 11*a, 12*a, 13*b, 14*b, 15*b, 16*b, 17*b, 18*b]]

What happens here: the first four elements get mean a, the next four get mean a, and the chunk containing elements 9 through 12 gets mean a as well (because that chunk still starts in the first channel). But only element 9 should get a; elements 10 through 12 need b instead.

When I implement your solution with Array instead of ArrayVector, everything is fine. However, this requires many more threads than my solution and is probably less efficient than using vec4...
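
To make the failure mode concrete, here is a small illustrative Rust sketch (0-indexed; not the shader code) showing how deriving the channel from a chunk's first element goes wrong as soon as a chunk straddles a channel boundary, as happens for a 3x3 image:

// Channel of a single element in a flattened 1xCxWxH tensor.
fn channel_of_element(i: usize, image_size: usize) -> usize {
    i / image_size
}

fn main() {
    let image_size = 9; // 3x3 image, two channels -> 18 elements
    for chunk in 0..(18 + 3) / 4 {
        let first = chunk * 4;
        // A vec4-based shader would use one channel for the whole chunk...
        let chunk_channel = channel_of_element(first, image_size);
        // ...but the elements inside it may belong to different channels:
        for i in first..(first + 4).min(18) {
            let real = channel_of_element(i, image_size);
            if real != chunk_channel {
                println!("element {} needs channel {} but chunk {} uses {}", i, real, chunk, chunk_channel);
            }
        }
    }
}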

pixelspark commented 2 years ago

But take into account that the stride for vec3 is 16: https://www.w3.org/TR/WGSL/#alignment-and-size and so you have to add padding to have the intended result. I did this for 3x3 convolution kernels where I used vec3: https://github.com/haixuanTao/wonnx/blob/master/templates/pool/conv_kernel_3.wgsl see

Good catch, did not know that! So I think I should leave out the vec3 variant then (f32, vec2 and vec4 have the expected stride of 4*N bytes).
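
Roughly sketched in Rust (names are illustrative), the selection then reduces to the chunk sizes whose stride has no implicit padding:

// Pick a WGSL element type for the input buffer based on the size of one
// image (W*H). Only f32 (stride 4), vec2<f32> (8) and vec4<f32> (16) are
// used, since vec3<f32> would need a padded 16-byte stride.
fn pick_element_type(image_size: u64) -> (&'static str, u64) {
    if image_size % 4 == 0 {
        ("vec4<f32>", 16)
    } else if image_size % 2 == 0 {
        ("vec2<f32>", 8)
    } else {
        ("f32", 4)
    }
}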

pixelspark commented 2 years ago

Also renaming the new 'Dims' type to 'Shape' (and all other mentions of 'dims') because that is apparently the commonly used term for this :-)

haixuanTao commented 2 years ago

Hey, so I double-checked the changes. I'm OK to merge as is, if you don't want to make the changes to the batchnorm WGSL. Let me know when it is ready.

haixuanTao commented 2 years ago

Could you maybe just rename all _dims to _shape within *.wgsl and README. Thanks :)

pixelspark commented 2 years ago

Could you maybe just rename all _dims to _shape within *.wgsl and README. Thanks :)

Sure thing, I changed them in the README, but as far as I can see there is no mention of '_dims' left in any WGSL file?

 wonnx % rg dims .
./src/utils.rs
21:    dims: Vec<u64>,
27:            dims: ds.iter().map(|x| *x as u64).collect(),
32:        self.dims.is_empty()
36:        self.dims.len()
40:        self.dims.iter().product()
49:        self.dims[idx]
54:        let ds = &self.dims;
55:        for i in 1..self.dims.len() {
120:            self.dims
203:        let shape = Shape::from(info.get_dims());

./src/onnx.rs
3392:    pub dims: ::std::vec::Vec<i64>,
3422:    // repeated int64 dims = 1;
3425:    pub fn get_dims(&self) -> &[i64] {
3426:        &self.dims
3428:    pub fn clear_dims(&mut self) {
3429:        self.dims.clear();
3433:    pub fn set_dims(&mut self, v: ::std::vec::Vec<i64>) {
3434:        self.dims = v;
3438:    pub fn mut_dims(&mut self) -> &mut ::std::vec::Vec<i64> {
3439:        &mut self.dims
3443:    pub fn take_dims(&mut self) -> ::std::vec::Vec<i64> {
3444:        ::std::mem::replace(&mut self.dims, ::std::vec::Vec::new())
3822:                    ::protobuf::rt::read_repeated_int64_into(wire_type, is, &mut self.dims)?;
3879:        for value in &self.dims {
3929:        for v in &self.dims {
4042:        self.dims.clear();
4369:    pub dims: ::std::vec::Vec<i64>,
4452:    // repeated int64 dims = 3;
4455:    pub fn get_dims(&self) -> &[i64] {
4456:        &self.dims
4458:    pub fn clear_dims(&mut self) {
4459:        self.dims.clear();
4463:    pub fn set_dims(&mut self, v: ::std::vec::Vec<i64>) {
4464:        self.dims = v;
4468:    pub fn mut_dims(&mut self) -> &mut ::std::vec::Vec<i64> {
4469:        &mut self.dims
4473:    pub fn take_dims(&mut self) -> ::std::vec::Vec<i64> {
4474:        ::std::mem::replace(&mut self.dims, ::std::vec::Vec::new())
4504:                    ::protobuf::rt::read_repeated_int64_into(wire_type, is, &mut self.dims)?;
4526:        for value in &self.dims {
4545:        for v in &self.dims {
4592:        self.dims.clear();

./tests/conv.rs
268:    initializer_w.set_dims(vec![m as i64, c as i64, kernel_n as i64, kernel_n as i64]);

./src/onnx_proto3.rs
3182:    pub dims: ::std::vec::Vec<i64>,
3212:    // repeated int64 dims = 1;
3215:    pub fn get_dims(&self) -> &[i64] {
3216:        &self.dims
3218:    pub fn clear_dims(&mut self) {
3219:        self.dims.clear();
3223:    pub fn set_dims(&mut self, v: ::std::vec::Vec<i64>) {
3224:        self.dims = v;
3228:    pub fn mut_dims(&mut self) -> &mut ::std::vec::Vec<i64> {
3229:        &mut self.dims
3233:    pub fn take_dims(&mut self) -> ::std::vec::Vec<i64> {
3234:        ::std::mem::replace(&mut self.dims, ::std::vec::Vec::new())
3574:                    ::protobuf::rt::read_repeated_int64_into(wire_type, is, &mut self.dims)?;
3631:        for value in &self.dims {
3681:        for v in &self.dims {
3794:        self.dims.clear();
4113:    pub dims: ::std::vec::Vec<i64>,
4196:    // repeated int64 dims = 3;
4199:    pub fn get_dims(&self) -> &[i64] {
4200:        &self.dims
4202:    pub fn clear_dims(&mut self) {
4203:        self.dims.clear();
4207:    pub fn set_dims(&mut self, v: ::std::vec::Vec<i64>) {
4208:        self.dims = v;
4212:    pub fn mut_dims(&mut self) -> &mut ::std::vec::Vec<i64> {
4213:        &mut self.dims
4217:    pub fn take_dims(&mut self) -> ::std::vec::Vec<i64> {
4218:        ::std::mem::replace(&mut self.dims, ::std::vec::Vec::new())
4248:                    ::protobuf::rt::read_repeated_int64_into(wire_type, is, &mut self.dims)?;
4270:        for value in &self.dims {
4289:        for v in &self.dims {
4336:        self.dims.clear();

(Obviously the code that touches the ONNX protos still uses 'get_dims' and the like.)

pixelspark commented 2 years ago

Hey, so I double-checked the changes. I'm OK to merge as is, if you don't want to make the changes to the batchnorm WGSL. Let me know when it is ready.

Yep good to go!

haixuanTao commented 2 years ago

Yes, agreed! I forgot I was using i_chunks and not i_dims.

Great addition all around!