tairov / llama2.mojo

Inference Llama 2 in one file of pure 🔥
https://www.modular.com/blog/community-spotlight-how-i-built-llama2-by-aydyn-tairov
MIT License
2.09k stars, 139 forks

Changed vectorize function for tile (with a nelts list) in batch_matmul #77

Open andresnowak opened 9 months ago

andresnowak commented 9 months ago

This branch implements an idea from @mikowals: replacing the vectorize function with tile in the batch_matmul function. Benchmarks with lamatune suggest that swapping vectorize for tile makes the program slightly faster. (These tests were run on a Ryzen 3600X.)

mikowals commented 9 months ago

I think the key change here is making 32 the first entry in VariadicList(32, 16, 8, 4, 2, 1). I think the number of columns in all our matmuls is evenly divisible by 32, so none of the other entries in the list are used and there are no tail elements. So I think what you are finding is that 4 * simdwidthof[DType.float32]() does not equal 32 on your system and that 32 is a better setting for the matmul vectorization.

Since 4 * simdwidthof[DType.float32]() was chosen by experiments on M1 Macs, it isn't surprising that it is not optimal on other systems. It would be interesting to experiment with a broad range of values. For values larger than 32 you would probably want to convert every vectorize call to tile, so that tail iterations don't obscure otherwise good settings. I have done this on M1 and could not find anything better than 4 * simdwidthof[DType.float32]().

Edit: In my comments above, I overlooked the case where 4 * simdwidthof[DType.float32]() is greater than 32 on your system. So it could be you had a lot of tail loops disappear by using tile.
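The tail-loop argument above can be made concrete with a small model. The following is a minimal Python sketch, not the Mojo code from this PR: `tile_spans` and `vectorize_spans` are hypothetical helpers that only list which (offset, width) chunks each strategy would visit, assuming tile consumes the widths greedily from widest to narrowest and vectorize falls back to one-element tail iterations. The sizes 288 and 300 are illustrative.

```python
def tile_spans(size, widths=(32, 16, 8, 4, 2, 1)):
    """Greedy tiling: cover [0, size) with (offset, width) chunks, widest first."""
    spans = []
    offset = 0
    for width in widths:
        while offset + width <= size:
            spans.append((offset, width))
            offset += width
    return spans


def vectorize_spans(size, width):
    """Single-width loop, then a scalar tail handled one element at a time."""
    full = size - size % width
    spans = [(offset, width) for offset in range(0, full, width)]
    spans += [(offset, 1) for offset in range(full, size)]
    return spans


if __name__ == "__main__":
    # A size divisible by 32: both strategies issue the same nine chunks.
    print(len(tile_spans(288)), len(vectorize_spans(288, 32)))
    # A size that is not divisible by 32: tile finishes with an 8 and a 4,
    # while vectorize needs twelve scalar iterations.
    print(len(tile_spans(300)), len(vectorize_spans(300, 32)))
    # A width larger than 32 on a size divisible only by 32: long scalar tail.
    print(len(vectorize_spans(288, 64)))
```

Under these assumptions, the two strategies only differ when the size is not a multiple of the vectorize width, which is the case described in the edit above.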

andresnowak commented 9 months ago

On my machine, 4 x simdwidthof gives a simd_width of 32, so I don't think the speedup comes from the first entry of nelts_list (32). It is true that most of the sizes are divisible by 32, so I don't know in which cases tile is helping enough to give that small performance gain. Another thing: when I tried a value of 64 in the nelts list, I started getting incorrect values from the matmul operation. Using a SIMD width of 32 (on the AMD CPU) and calling reduce_add() every time inside the _multiply_tail function (so that the temp variable would only need to be SIMD[f32, 1]) also started giving incorrect values on my machine. If I'm not wrong, I remember a Modular stream saying that a SIMD width bigger than what the machine is meant to use gives lower performance or incorrect behavior.

That's why I don't understand why 2 x simdwidthof, or 4 x simdwidthof on Mac, sometimes gives better results. Testing now on an M1 Pro Mac, I see that 4 x simdwidthof gives 16. On the M1 Pro a simd_width of 32 gives incorrect results in the matmul operation, while a nelts list with a simd_width of 16 or less causes no problems. So maybe a simd_width bigger than the cache line causes problems (the M1 cache line is 128 bytes), but that wouldn't explain the AMD CPU (its cache line is 64 bytes). So I don't know why a size bigger than 4 x simdwidthof causes problems, or why a size bigger than 2 x simdwidthof in the reduce_add part also causes undefined behavior.

So maybe I have to create a tile function that first checks, with an if, whether the nelts size it is about to use is bigger than 4 x simdwidthof.

Extra: the SIMD size is 128 bits on the M1 and 256 bits on my Ryzen.
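The width arithmetic in this comment can be checked with a few lines of Python. This is only a sketch of the numbers, assuming float32 elements (32 bits) and the register sizes quoted above; `lanes` and `cap_widths` are hypothetical helpers, and `cap_widths` illustrates the proposed guard of dropping any nelts entry larger than 4 x simdwidthof.

```python
def lanes(register_bits, element_bits=32):
    """Number of elements that fit in one SIMD register."""
    return register_bits // element_bits


def cap_widths(widths, limit):
    """Keep only the tile widths that do not exceed the given limit."""
    return tuple(width for width in widths if width <= limit)


if __name__ == "__main__":
    candidate_widths = (64, 32, 16, 8, 4, 2, 1)
    for name, bits in (("M1, 128-bit", 128), ("Ryzen, 256-bit", 256)):
        native = lanes(bits)
        limit = 4 * native
        # Bytes touched by one chunk of `limit` float32 values.
        chunk_bytes = limit * 4
        print(name, native, limit, chunk_bytes, cap_widths(candidate_widths, limit))
```

With these inputs the limit is 16 elements (64 bytes) on the M1 and 32 elements (128 bytes) on the Ryzen, which matches the values reported in the thread. The sketch says nothing about why wider values misbehave; it only reproduces the arithmetic.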

mikowals commented 9 months ago

I think your error with large nelts values stems from hardcoding stack_size at 32. That variable should be deleted; just allocate nelts-length pointers in stack_allocation. By fixing it at 32 you are storing and loading out of bounds when _nelts is larger than 32.
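The out-of-bounds concern can be modeled in Python, where an oversized index raises an error instead of silently corrupting memory. This is a hedged sketch rather than the PR's code: `accumulate` is a hypothetical function, and its `stack_size` parameter stands in for the hardcoded buffer length described above.

```python
def accumulate(values, nelts, stack_size=None):
    """Sum values lane by lane through a temporary buffer.

    With stack_size=None the buffer has exactly nelts slots, which is the
    suggested fix. A fixed stack_size smaller than nelts overruns the buffer.
    """
    size = nelts if stack_size is None else stack_size
    temp = [0.0] * size
    for index, value in enumerate(values):
        lane = index % nelts
        temp[lane] += value  # raises IndexError when lane >= size
    return sum(temp)


if __name__ == "__main__":
    data = list(range(128))
    print(accumulate(data, nelts=64))  # buffer sized from nelts: fine
    try:
        accumulate(data, nelts=64, stack_size=32)
    except IndexError:
        print("lane index ran past a 32-slot buffer")
```

In a language without bounds checks the second call would not fail loudly, so wrong matmul values would be one plausible symptom.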

andresnowak commented 9 months ago

No, I also change that value: if I change the nelts list to 64, I set the stack size to 64. The same code that works on the AMD CPU (a stack size of 32 and a nelts list of 32) gives an error on the M1 CPU, so I am sure the problem has to be the simd_width size (but I will change that; as you said, it can introduce errors). If I change the nelts list value to 16 on the M1 CPU the error goes away, so the error seems to come from operating with a bigger SIMD size, but I don't know why a value bigger than 4 x simdwidth gives an error.

I'm going to test just making nelts bigger in the original matmul, but maybe there won't be an error. As I said before, reduce_add behaves differently: with a value of 4 x simdwidth you get incorrect values, and with 2 x simdwidth you don't get an error. Based on that, it seems that a SIMD value bigger than what the CPU allows can give undefined behavior. The question is why it happens only sometimes.

(But if you have time and you can test it I would appreciate it, I'm 99% sure what I did is correct and the problem is the simd sizes, but if you can confirm it I would appreciate it so I don't give incorrect information)

Update: If you use 8 x simdwidth (a SIMD width of 64 on my AMD CPU) with the original matmul code, you also get undefined behavior: the matmul operation gives incorrect values. So I think this confirms that there is undefined behavior when using a SIMD width bigger than what the CPU accepts.
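One way to pin down reports like this is a harness that compares each candidate width against a scalar reference. The Python sketch below shows the shape of such a check only; it cannot reproduce hardware-specific SIMD behavior, so every width passes here. `dot_reference`, `dot_chunked`, and `check_widths` are hypothetical names, and the final `sum` over the accumulator plays the role of reduce_add.

```python
import math
import random


def dot_reference(a, b):
    """Scalar reference dot product with compensated summation."""
    return math.fsum(x * y for x, y in zip(a, b))


def dot_chunked(a, b, width):
    """Dot product accumulated in `width` lanes, reduced once at the end."""
    acc = [0.0] * width
    full = len(a) - len(a) % width
    for offset in range(0, full, width):
        for lane in range(width):
            acc[lane] += a[offset + lane] * b[offset + lane]
    total = sum(acc)  # analogue of a single reduce_add
    for index in range(full, len(a)):  # scalar tail
        total += a[index] * b[index]
    return total


def check_widths(n, widths, seed=0, tol=1e-9):
    """Return {width: True/False} for agreement with the reference."""
    rng = random.Random(seed)
    a = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    b = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    expected = dot_reference(a, b)
    return {
        width: math.isclose(dot_chunked(a, b, width), expected,
                            rel_tol=tol, abs_tol=tol)
        for width in widths
    }


if __name__ == "__main__":
    print(check_widths(288, (8, 16, 32, 64)))
```

Running the same comparison against the real matmul on each machine would show which widths produce wrong values, and whether the failure depends on the input size.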

tairov commented 9 months ago

Thanks for taking the time to research this topic. I saw in the other PR that you tried to leverage autotune to find the optimal nelts. The autotune-based implementation is probably the better choice, since it should help get rid of the VariadicList complication inside the dot function.
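The general idea behind tuning nelts, measuring each candidate and keeping the cheapest, can be sketched in Python. This is not Mojo's autotune API; `pick_best` and `time_call` are hypothetical helpers, and the workload in the usage example is a stand-in for the matmul kernel.

```python
import time


def time_call(run, repeats=5):
    """Best wall-clock time, in seconds, over several calls of run()."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        run()
        best = min(best, time.perf_counter() - start)
    return best


def pick_best(candidates, measure):
    """Return the candidate with the lowest measured cost."""
    return min(candidates, key=measure)


if __name__ == "__main__":
    data = [float(i) for i in range(4096)]

    def workload(width):
        # Stand-in kernel: sum the data in chunks of `width` elements.
        return lambda: sum(sum(data[i:i + width])
                           for i in range(0, len(data), width))

    widths = (4, 8, 16, 32)
    print(pick_best(widths, lambda width: time_call(workload(width))))
```

Because the result depends on the machine, the chosen width is not shown here; that machine dependence is the reason for tuning instead of hardcoding a value.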

tairov commented 9 months ago

BTW, I don't see any use of the imported tile primitive.

andresnowak commented 9 months ago

tairov commented 4 months ago

Hi @andresnowak, any chance you can validate the ideas from this PR on the latest Mojo release? Is there anything we can improve?