In Llama3 on TG, we slice the 32 users into 4 groups (see the diagram below).
We currently don't support slicing multi-device tensors. As a workaround, we use a slice matmul to achieve the same result; the slice matmul takes around 4600 ns on device.
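One way to picture a slice matmul is as a product with a one-hot selection matrix: if `S` has a single 1 in row `i` at column `start + i`, then `S @ x` returns rows `[start, start + n)` of `x`. The sketch below assumes that formulation (the actual TT-NN kernel may be built differently) and uses plain Python lists. `selection_matrix`, `matmul`, and `slice_by_matmul` are hypothetical helpers for illustration, not TT-NN APIs.

```python
def selection_matrix(n_out, n_in, start):
    """Hypothetical helper: n_out x n_in one-hot matrix whose row i picks input row start + i."""
    return [[1.0 if j == start + i else 0.0 for j in range(n_in)] for i in range(n_out)]


def matmul(a, b):
    """Plain nested-list matrix multiply: (m x k) @ (k x n) -> (m x n)."""
    k_dim = len(b)
    n_cols = len(b[0])
    return [[sum(a_row[k] * b[k][j] for k in range(k_dim)) for j in range(n_cols)]
            for a_row in a]


def slice_by_matmul(x, start, n):
    """Hypothetical helper: take n consecutive rows of x, starting at `start`, via a matmul."""
    s = selection_matrix(n, len(x), start)
    return matmul(s, x)


if __name__ == "__main__":
    users, hidden = 32, 4  # hidden shrunk from 1280 to 4 to keep the demo small
    x = [[float(u * 10 + h) for h in range(hidden)] for u in range(users)]
    # Each group of 8 users selected by the matmul matches a plain slice.
    for group in range(4):
        start = 8 * group
        assert slice_by_matmul(x, start, 8) == x[start:start + 8]
    print("slice via matmul matches plain slicing")
```

Note that this formulation reads the full 32-row input and performs n × users × hidden multiply-adds (8 × 32 × 1280 = 327,680 for the shapes in this issue) only to copy 8 rows, whereas a native slice would only need to move the selected rows.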
Requirement
Shape of the input tensor on each device: (1, 1, users, hidden) = (1, 1, 32, 1280)
Slice the input tensors on device such that the 1st column of the device cluster gets users [0-7], the next column gets users [8-15], and so on.
After slicing, the shape of the tensor on each device is (1, 1, 8[32], 1280), i.e. 8 users padded to 32.
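The slicing rule amounts to a small mapping from device-cluster column to user range. The sketch below assumes one group per column (4 columns, matching the 4 groups) and a tile height of 32 for the padded dimension shown as `8[32]`. The constant names and the helpers `user_range_for_column` and `sliced_shape` are illustrative, not part of TT-NN or the Llama3 model code.

```python
USERS = 32       # from the issue: 32 users per device
HIDDEN = 1280    # from the issue: hidden size per device
NUM_GROUPS = 4   # from the issue: 4 groups, assumed to be one per device-cluster column
TILE = 32        # assumption: tile height used to pad the user dimension


def user_range_for_column(col):
    """Hypothetical helper: inclusive (first, last) user indices assigned to a column."""
    users_per_group = USERS // NUM_GROUPS
    first = col * users_per_group
    return first, first + users_per_group - 1


def sliced_shape():
    """Hypothetical helper: (logical, padded) per-device shapes after slicing."""
    users_per_group = USERS // NUM_GROUPS
    padded = -(-users_per_group // TILE) * TILE  # round up to a multiple of TILE
    return (1, 1, users_per_group, HIDDEN), (1, 1, padded, HIDDEN)


if __name__ == "__main__":
    for col in range(NUM_GROUPS):
        first, last = user_range_for_column(col)
        print(f"column {col}: users [{first}-{last}]")
    print(sliced_shape())
```

Run as a script, this prints the ranges [0-7], [8-15], [16-23], and [24-31] for the four columns, followed by the logical shape (1, 1, 8, 1280) and the padded shape (1, 1, 32, 1280).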