Closed by pixelspark 2 years ago
I am not sure what you mean by dynamic types?
In batchnormalization.wgsl, the type of the input vector is defined dynamically:
[[block]]
struct Block {
data: [[stride({{elem_stride}})]] array<{{ elem_type }}>;
};
// X (input)
[[group(0), binding(0)]]
var<storage, read> {{ inputs[0] }}: Block;
This allows the code to select vec3 when the input size is divisible by 3, vec4 when it is divisible by 4, etc. This was necessary to support e.g. batch normalization on 33x33 inputs. I am not sure if you have encountered this type of issue before, or know of a better way of dealing with it, but this seemed like a neat solution because it also allows the x, y, z input coordinates to refer to the position in the 'image', the channel and the batch index respectively.
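For illustration, here is a minimal sketch of how the template parameters could be derived on the Rust side. This is a hypothetical helper (elem_type_for is not the actual wonnx sequencer code), just to show the idea:
// Hypothetical helper illustrating how {{ elem_type }} and {{ elem_stride }}
// could be chosen for the template above; not the actual wonnx sequencer code.
fn elem_type_for(image_size: u64) -> (&'static str, u64, u64) {
    // Returns (elem_type, elem_stride in bytes, scalars per element).
    if image_size % 4 == 0 {
        ("vec4<f32>", 16, 4)
    } else if image_size % 3 == 0 {
        ("vec3<f32>", 16, 3) // note: vec3<f32> has 16-byte alignment in WGSL
    } else if image_size % 2 == 0 {
        ("vec2<f32>", 8, 2)
    } else {
        ("f32", 4, 1)
    }
}
fn main() {
    // A 33x33 image is divisible by 3 but not by 4, so vec3<f32> would be selected.
    let (elem_type, elem_stride, scalars) = elem_type_for(33 * 33);
    println!("data: [[stride({elem_stride})]] array<{elem_type}>; // {scalars} scalars per element");
}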
I just added a test (and hand-checked the output against what I expect it to be, see code) so this is ready to check and merge as far as I am concerned!
OK. I think there is an issue with git, showing some diffs that should already have been integrated. Would it be possible to do some kind of reset of the commits, or a rebase? Although this might be difficult.
OK, so now this thing is squashed to a single commit ready to merge.
Note that previously all CI checks passed except for Linux/amd64. This was because one floating-point value differed (by 0.0001). I have now added an 'epsilon-based comparison' which only compares the first two digits after the decimal point, to prevent this sort of issue.
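For reference, an epsilon comparison along these lines could look as follows (illustrative sketch only; the helper name and the exact tolerance are assumptions, not necessarily what the test code uses):
// Illustrative sketch of an epsilon-based comparison for test outputs.
fn assert_approx_eq(actual: &[f32], expected: &[f32], epsilon: f32) {
    assert_eq!(actual.len(), expected.len());
    for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
        assert!(
            (a - e).abs() <= epsilon,
            "element {i}: {a} differs from {e} by more than {epsilon}"
        );
    }
}
fn main() {
    // Comparing only the first two digits after the decimal point roughly
    // corresponds to a tolerance of 0.005, which absorbs the observed
    // 0.0001 difference on Linux/amd64.
    assert_approx_eq(&[0.3301], &[0.3300], 0.005);
}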
I am not sure what you mean by dynamic types?
In batchnormalization.wgsl, the type of the input vector is defined dynamically:
[[block]]
struct Block {
data: [[stride({{elem_stride}})]] array<{{ elem_type }}>;
};
// X (input)
[[group(0), binding(0)]]
var<storage, read> {{ inputs[0] }}: Block;
This allows the code to select vec3 when the input size is divisible by 3, vec4 when it is divisible by 4, etc. This was necessary to support e.g. batch normalization on 33x33 inputs. I am not sure if you have encountered this type of issue before, or know of a better way of dealing with it, but this seemed like a neat solution because it also allows the x, y, z input coordinates to refer to the position in the 'image', the channel and the batch index respectively.
Yeah, I see.
In principle, I don't have any problem with this solution.
But take into account that the stride for vec3 is 16: https://www.w3.org/TR/WGSL/#alignment-and-size and so you have to add padding to have the intended result. I did this for 3x3 convolution kernels where I used vec3: https://github.com/haixuanTao/wonnx/blob/master/templates/pool/conv_kernel_3.wgsl see https://github.com/haixuanTao/wonnx/blob/656398f229fc44245ff0644c1f89fa27ac06ae41/src/sequencer.rs#L59
I think I would probably have stayed with vec4, with padding within the shader for the last vec4.
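For concreteness, here is a sketch of what that padding means on the host side (hypothetical helper, not taken from the linked wonnx code): because vec3<f32> has size 12 but alignment 16 in WGSL, an array<vec3<f32>> has an effective stride of 16 bytes, so tightly packed f32 data needs one dummy float inserted after every third value before upload.
// Hypothetical sketch: pad tightly packed f32 data to the 16-byte stride
// that an array<vec3<f32>> storage buffer expects.
fn pad_to_vec3_stride(data: &[f32]) -> Vec<f32> {
    assert!(data.len() % 3 == 0, "input length must be a multiple of 3");
    let mut padded = Vec::with_capacity(data.len() / 3 * 4);
    for chunk in data.chunks(3) {
        padded.extend_from_slice(chunk);
        padded.push(0.0); // padding element to reach the 16-byte stride
    }
    padded
}
fn main() {
    let padded = pad_to_vec3_stride(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    assert_eq!(padded, vec![1.0, 2.0, 3.0, 0.0, 4.0, 5.0, 6.0, 0.0]);
}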
@haixuanTao I just implemented your solution and I am afraid it will not work here without using 'Array' (f32, i.e. one thread per element, so N*C*W*H threads) instead of 'ArrayVector' (vec4).
The reason is as follows: if your input is 1x2x2x2, you have 2 'images' of 2x2, let's say [[1,2,3,4],[5,6,7,8]]. For batch normalization you have a [2] vector of means (say a and b), one for each image. This means the following should be the output:
[[1*a, 2*a, 3*a, 4*a], [5*b, 6*b, 7*b, 8*b]]
This example works if you use vec4, because then the first four elements get the mean a (input[3][channel]). Now if your image is 3x3, things become problematic if you use vec4:
[[1*a, 2*a, 3*a, 4*a, 5*a, 6*a, 7*a, 8*a, 9*a], [10*a, 11*a, 12*a, 13*b, 14*b, 15*b, 16*b, 17*b, 18*b]]
What happens here: the first four elements get mean a, the second four get mean a, and then the next four (elements 9 to 12) all get mean a as well, because the computed channel is still the first one at that point; but only element 9 should get a, while elements 10, 11 and 12 already need b.
When I implement your solution with Array instead of ArrayVector, everything is fine. However this requires many more threads than my solution and is probably less efficient than using vec4...
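To make the indexing difference concrete, here is a small illustrative sketch (not code from the project) comparing the channel each element gets under per-element indexing (the Array/f32 scheme) versus per-vec4-chunk indexing, for a 1x2x3x3 input:
// Illustrative sketch: channel assignment for a 1x2x3x3 input
// (two channels of 9 elements each, elements numbered from 0).
fn main() {
    let w_h = 9; // elements per channel for a 3x3 image
    let total = 2 * w_h; // two channels
    for i in 0..total {
        let per_element = i / w_h;         // Array/f32 scheme: one thread per element
        let per_chunk = (i / 4) * 4 / w_h; // vec4 scheme: the whole chunk shares one channel
        println!("element {i:2}: per-element channel {per_element}, vec4-chunk channel {per_chunk}");
    }
    // The chunk covering elements 8..=11 starts in channel 0, so the vec4
    // scheme assigns channel 0 (mean a) to elements 9, 10 and 11 as well,
    // even though they already belong to the second channel (mean b).
}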
But take into account that the stride for vec3 is 16: https://www.w3.org/TR/WGSL/#alignment-and-size and so you have to add padding to have the intended result. I did this for 3x3 convolution kernels where I used vec3: https://github.com/haixuanTao/wonnx/blob/master/templates/pool/conv_kernel_3.wgsl see https://github.com/haixuanTao/wonnx/blob/656398f229fc44245ff0644c1f89fa27ac06ae41/src/sequencer.rs#L59
Good catch, did not know that! So I think I should leave out the vec3 variant then (vec2, f32 and vec4 have the expected stride of 4*N).
Also renaming the new 'Dims' type to 'Shape' (and all other mentions of 'dims') because that is apparently the commonly used term for this :-)
Hey, so I double-checked the changes. I'm OK to merge as-is, if you don't want to make the changes on the batchnorm WGSL. Let me know if it is ready.
Could you maybe just rename all _dims to _shape within *.wgsl and README. Thanks :)
Sure thing, I changed them in the README, but as far as I can see there is no mention of '_dims' left in any WGSL file?
wonnx % rg dims .
./src/utils.rs
21: dims: Vec<u64>,
27: dims: ds.iter().map(|x| *x as u64).collect(),
32: self.dims.is_empty()
36: self.dims.len()
40: self.dims.iter().product()
49: self.dims[idx]
54: let ds = &self.dims;
55: for i in 1..self.dims.len() {
120: self.dims
203: let shape = Shape::from(info.get_dims());
./src/onnx.rs
3392: pub dims: ::std::vec::Vec<i64>,
3422: // repeated int64 dims = 1;
3425: pub fn get_dims(&self) -> &[i64] {
3426: &self.dims
3428: pub fn clear_dims(&mut self) {
3429: self.dims.clear();
3433: pub fn set_dims(&mut self, v: ::std::vec::Vec<i64>) {
3434: self.dims = v;
3438: pub fn mut_dims(&mut self) -> &mut ::std::vec::Vec<i64> {
3439: &mut self.dims
3443: pub fn take_dims(&mut self) -> ::std::vec::Vec<i64> {
3444: ::std::mem::replace(&mut self.dims, ::std::vec::Vec::new())
3822: ::protobuf::rt::read_repeated_int64_into(wire_type, is, &mut self.dims)?;
3879: for value in &self.dims {
3929: for v in &self.dims {
4042: self.dims.clear();
4369: pub dims: ::std::vec::Vec<i64>,
4452: // repeated int64 dims = 3;
4455: pub fn get_dims(&self) -> &[i64] {
4456: &self.dims
4458: pub fn clear_dims(&mut self) {
4459: self.dims.clear();
4463: pub fn set_dims(&mut self, v: ::std::vec::Vec<i64>) {
4464: self.dims = v;
4468: pub fn mut_dims(&mut self) -> &mut ::std::vec::Vec<i64> {
4469: &mut self.dims
4473: pub fn take_dims(&mut self) -> ::std::vec::Vec<i64> {
4474: ::std::mem::replace(&mut self.dims, ::std::vec::Vec::new())
4504: ::protobuf::rt::read_repeated_int64_into(wire_type, is, &mut self.dims)?;
4526: for value in &self.dims {
4545: for v in &self.dims {
4592: self.dims.clear();
./tests/conv.rs
268: initializer_w.set_dims(vec![m as i64, c as i64, kernel_n as i64, kernel_n as i64]);
./src/onnx_proto3.rs
3182: pub dims: ::std::vec::Vec<i64>,
3212: // repeated int64 dims = 1;
3215: pub fn get_dims(&self) -> &[i64] {
3216: &self.dims
3218: pub fn clear_dims(&mut self) {
3219: self.dims.clear();
3223: pub fn set_dims(&mut self, v: ::std::vec::Vec<i64>) {
3224: self.dims = v;
3228: pub fn mut_dims(&mut self) -> &mut ::std::vec::Vec<i64> {
3229: &mut self.dims
3233: pub fn take_dims(&mut self) -> ::std::vec::Vec<i64> {
3234: ::std::mem::replace(&mut self.dims, ::std::vec::Vec::new())
3574: ::protobuf::rt::read_repeated_int64_into(wire_type, is, &mut self.dims)?;
3631: for value in &self.dims {
3681: for v in &self.dims {
3794: self.dims.clear();
4113: pub dims: ::std::vec::Vec<i64>,
4196: // repeated int64 dims = 3;
4199: pub fn get_dims(&self) -> &[i64] {
4200: &self.dims
4202: pub fn clear_dims(&mut self) {
4203: self.dims.clear();
4207: pub fn set_dims(&mut self, v: ::std::vec::Vec<i64>) {
4208: self.dims = v;
4212: pub fn mut_dims(&mut self) -> &mut ::std::vec::Vec<i64> {
4213: &mut self.dims
4217: pub fn take_dims(&mut self) -> ::std::vec::Vec<i64> {
4218: ::std::mem::replace(&mut self.dims, ::std::vec::Vec::new())
4248: ::protobuf::rt::read_repeated_int64_into(wire_type, is, &mut self.dims)?;
4270: for value in &self.dims {
4289: for v in &self.dims {
4336: self.dims.clear();
(Obviously the code that touches the ONNX protos still uses 'get_dims' and the like.)
Hey, so I double-checked the changes. I'm OK to merge as-is, if you don't want to make the changes on the batchnorm WGSL. Let me know if it is ready.
Yep good to go!
Yes, agreed! I forgot I was using i_chunks and not i_dims.
Great addition all around!
OK, this seems to work more or less. So batch normalization works on input tensors of shape NxCxWxH (N and C are optional; if omitted they are 1), where input statistics (mean, variance, bias, scale) are provided for each C (channel), i.e. as tensors of shape C.
My initial attempt used an ArrayVector (i.e. vec4<f32>) as type, which only works when W*H is a multiple of 4. I struggled a bit with getting the alignment right for cases where W*H is not divisible by 4 (my test is 1x2x5x5 for instance). I came up with the following solution:
- the x coordinate refers to the position in the 'image'
- the y coordinate refers to the channel (y_threads=C, this means channels are calculated in parallel which is good)
- the z component refers to the batch index (z_threads=N, this means batches are calculated in parallel which is good)
- the element type is chosen based on the greatest common divisor of W*H and 4: e.g. when W*H is divisible by 4, use vec4<f32>, if divisible by 3 use vec3<f32>, and so on (I tried also using mat4x4<f32>, but apparently WGSL doesn't allow elementwise subtraction of matrices? Need to look into that). The number of x threads is then W*H / gcd(W*H, 4).
I have not seen the 'dynamic type' thing anywhere in your other shaders @haixuanTao, but it seems to me this is quite an elegant solution which makes it easier to write shaders supporting different shapes. Curious to hear what you think about this.
I still have to manually check the outcomes of the batchnormalization test, although they seem to look correct now.
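As a closing illustration of the dispatch layout described above, here is a hypothetical sketch (not the actual wonnx code, and restricted to the f32/vec2/vec4 variants kept after the stride discussion) of how the thread counts could be derived:
// Hypothetical sketch of the dispatch layout: x covers the elements of one
// image (W*H divided by the chosen element width), y covers the channels
// and z covers the batches.
fn gcd(a: u64, b: u64) -> u64 {
    if b == 0 { a } else { gcd(b, a % b) }
}
fn dispatch_for(n: u64, c: u64, w: u64, h: u64) -> (u64, u64, u64) {
    let image_size = w * h;
    let elem_width = gcd(image_size, 4); // scalars per element: 1, 2 or 4
    (image_size / elem_width, c, n)      // (x_threads, y_threads, z_threads)
}
fn main() {
    // The 1x2x5x5 test case: W*H = 25, gcd(25, 4) = 1, so plain f32 elements are used.
    assert_eq!(dispatch_for(1, 2, 5, 5), (25, 2, 1));
    // W*H divisible by 4: vec4 elements, so a quarter of the x threads.
    assert_eq!(dispatch_for(1, 3, 4, 4), (4, 3, 1));
}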