Here's a diagram for the idea behind TtSyncTensor. It's essentially implemented as an all-reduce, with a preceding masking step that ensures all tensors, except the one on the sender chip, are zero before performing the sync.
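As a rough illustration of that mask-then-all-reduce idea, here is a minimal single-process numpy sketch (`tt_sync_tensor`, the per-device tensor list, and `sender_id` are hypothetical names for illustration, not the actual TtSyncTensor API):

```python
import numpy as np

def tt_sync_tensor(per_device_tensors, sender_id):
    """Broadcast the sender chip's tensor to every device via mask + all-reduce."""
    # Masking step: zero every copy except the one on the sender chip.
    masked = [t if i == sender_id else np.zeros_like(t)
              for i, t in enumerate(per_device_tensors)]
    # All-reduce (sum): non-sender copies are all zeros, so every device
    # ends up holding exactly the sender's tensor.
    reduced = sum(masked)
    return [reduced.copy() for _ in per_device_tensors]

# usage: device 1 is the sender; afterwards every device holds its tensor
tensors = [np.ones((2, 2)), np.full((2, 2), 7.0), np.zeros((2, 2))]
assert all((t == 7.0).all() for t in tt_sync_tensor(tensors, sender_id=1))
```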
Issue
We propose fine-grained speculative execution on metal. Specifically, we introduce a new type of parallelism in which devices are used to pre-execute ops on speculative output.
Speculative op: Flash decode
Proposed Implementation:
Pseudocode:
```python
# given input tokens from user
new_tok = llama_prefill(input_tokens)
while decoding:
    x = llama_embedding(new_tok)
    x_multidevice = ttnn.replicate_tensor_to_multi_device(x)
    # tensor1 contains sender information for the speculation op
    x_multidevice, tensor1 = llama(x_multidevice)
    # choose the x from the device containing the correct tensor value
```

Inside `def llama(x_multidevice):`, the following new ops are used:
ttnn.skip(tensor2): When launching the op, read the value at tensor2's address. If the value is 1, skip the kernel.
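As host-side pseudocode, the intended launch-time check might look like this (a sketch under assumptions: `DEVICE_MEM` and `read_device_word` are hypothetical stand-ins for a real device-memory read, not ttnn APIs):

```python
DEVICE_MEM = {0x1000: 1}  # pretend tensor2 lives at 0x1000 and currently holds 1

def read_device_word(addr):
    # Hypothetical stand-in for reading one word from device memory.
    return DEVICE_MEM.get(addr, 0)

def skip(tensor2_addr, kernel, *args):
    # At launch time, read the value at tensor2's address.
    if read_device_word(tensor2_addr) == 1:
        return None        # value is 1: skip the kernel entirely
    return kernel(*args)   # otherwise launch as usual

# usage: the kernel never runs because the flag at 0x1000 is 1
assert skip(0x1000, lambda v: v * 2, 21) is None
```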
ttnn.sync_before_spec(tensor1, tensor2): Do a handshake between device 0 and device 1 (an all-gather with a local tensor update):
- If tensor2 is 0 on both devices, i.e. (0, 0), the previous speculation failed. Exit and do nothing.
- If tensor2 is 1 on both devices, i.e. (1, 1), this indicates an erroneous state. Raise an error.
- If tensor2 is (1, 0), device 0's speculation was successful and its subsequent computes have been skipped, so device 1 contains the main path. Set tensor1 to (-1, 0) and tensor2 to (0, 0).
- If tensor2 is (0, 1), device 1's speculation was successful and its subsequent computes have been skipped, so device 0 contains the main path. Set tensor1 to (1, -1) and tensor2 to (0, 0).
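The case analysis transcribes directly into a small host-side sketch (plain Python over (device 0, device 1) value pairs, ignoring the actual all-gather mechanics):

```python
def sync_before_spec(tensor1, tensor2):
    """tensor2 is the pair of flag values (on device 0, on device 1).

    tensor1 is overwritten by the handshake, so its incoming value is
    unused here; it is kept only to mirror the proposed op signature.
    """
    if tensor2 == (0, 0):
        return None  # previous speculation failed: exit and do nothing
    if tensor2 == (1, 1):
        raise RuntimeError("erroneous state: both devices report success")
    if tensor2 == (1, 0):
        # device 0 speculated successfully; device 1 holds the main path
        return (-1, 0), (0, 0)
    if tensor2 == (0, 1):
        # device 1 speculated successfully; device 0 holds the main path
        return (1, -1), (0, 0)

# usage: device 0's speculation succeeded, so tensor1 marks device 0 as a
# receiver (-1), stores the receiver id (0) on device 1, and tensor2 resets
assert sync_before_spec((0, 0), (1, 0)) == ((-1, 0), (0, 0))
```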
ttnn.sync_cache(K, tensor1): If tensor1 is -1 for this device, do nothing. Otherwise, do a remote write of the cache to the device id stored in tensor1 (pad and then reduce-scatter).
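Per device, this is just a guarded remote write; a minimal sketch, assuming a hypothetical `write_remote` callback standing in for the pad + reduce-scatter transfer:

```python
def sync_cache(K, tensor1_local, write_remote):
    # tensor1_local: this device's entry of tensor1.
    # write_remote: hypothetical callback doing the actual transfer
    # (in the proposal: pad the cache, then reduce-scatter).
    if tensor1_local == -1:
        return  # this device is a receiver: nothing to send
    write_remote(dst_device=tensor1_local, cache=K)
```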
speculative_flash_decode(x_multidevice, tensor1, tensor2):
Case 1, tensor1 is -1: this device is the receiver chip, so it waits for the speculative result to be sent over.
Case 2, tensor1 is the receiver chip id: as in normal flash decode, we have reducer and worker cores. However, the reducer core first processes only the first and last chunks, writes out the speculative result, and signals the EDM to send it over to the receiver chip. Afterwards, the reducer core waits for results from the other workers and completes the full flash decode as usual. At the end of flash decode, it computes the L2 distance between the speculative result and the full result. If the distance is within a threshold, it writes 1 to tensor2; otherwise it writes 0.
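The acceptance test at the end of the sender path reduces to an L2 comparison; a numpy sketch (the threshold value is a free parameter, not specified in this proposal):

```python
import numpy as np

def accept_speculation(spec_out, full_out, threshold):
    # Compare the speculative (first + last chunk only) result against
    # the full flash-decode result; tensor2 gets 1 on acceptance, else 0.
    dist = np.linalg.norm(spec_out - full_out)
    return 1 if dist <= threshold else 0

# usage: tensor2 would be set to accept_speculation(spec, full, threshold)
```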