tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
https://docs.tenstorrent.com/ttnn/latest/index.html
Apache License 2.0

Speculative Execution of LLMs on Metal: Master Issue #14240

Open caixunshiren opened 4 weeks ago

caixunshiren commented 4 weeks ago

Issue

We propose fine-grained speculative execution on Metal. Specifically, we introduce a new type of parallelism in which devices pre-execute ops using speculative output.

def llama(x_multidevice):
    # tensor1 contains sender information for the speculation op.
    # Initially, device 0 has 1, which means it sends its speculation result to device 1.
    # Device 1 has -1, which means it is a receiver of the speculative result.
    tensor1 = ttnn.multi_device_tensor([
        [1],   # for device 0
        [-1],  # for device 1
    ])
    # tensor2 contains skip-compute information.
    # Initially, both devices contain 0, which means no device skips compute.
    tensor2 = ttnn.multi_device_tensor([
        [0],  # for device 0
        [0],  # for device 1
    ])
    for i in range(num_layers):
        with ttnn.skip(tensor2):
            # Under this scope, if a device's tensor2 is 1, it skips the op's compute. Otherwise it runs normally.
            x_multidevice = run_ops_before_speculative_flash_decode(x_multidevice)
        # Wait for both devices to execute this op, and decide which device is sender and which is receiver for speculation.
        ttnn.sync_before_spec(tensor1, tensor2)
        # Sync the KV cache based on tensor1 info.
        ttnn.sync_cache(K, tensor1)
        ttnn.sync_cache(V, tensor1)
        x_multidevice = speculative_flash_decode(x_multidevice, tensor1, tensor2)
        with ttnn.skip(tensor2):
            x_multidevice = run_ops_after_speculative_flash_decode(x_multidevice)

    ttnn.sync_before_spec(tensor1, tensor2)
    return x_multidevice, tensor1

# given input tokens from user
new_tok = llama_prefill(input_tokens)
while decoding:
    x = llama_embedding(new_tok)
    x_multidevice = ttnn.replicate_tensor_to_multi_device(x)
    x_multidevice, tensor1 = llama(x_multidevice)
    # choose the x from the device containing the correct tensor value
    x = choose_x_based_on_tensor(tensor1)
    logits = llama_lm_head(x)
    new_tok = do_sampling(logits)

- Proposed new ttnn ops

ttnn.skip(tensor2): When launching an op, read the value at tensor2's address. If the value is 1, skip the kernel.

ttnn.sync_before_spec(tensor1, tensor2): Perform a handshake between device 0 and device 1 (all gather with local tensor update). The outcome depends on the value of tensor2 on the two devices:

- (0,0): the previous speculation failed. Exit and do nothing.
- (1,1): this is an erroneous state. Raise an error.
- (1,0): device 0's speculation was successful and its subsequent computes have been skipped, so device 1 contains the main path. Set tensor1 to (-1, 0) and tensor2 to (0,0).
- (0,1): device 1's speculation was successful and its subsequent computes have been skipped, so device 0 contains the main path. Set tensor1 to (1, -1) and tensor2 to (0,0).
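The four cases of the handshake can be written out as a small host-side state transition. This is a sketch for two devices only, with tensor1 and tensor2 modeled as Python lists indexed by device id; it says nothing about how the all gather would be implemented on device.

```python
def sync_before_spec(tensor1, tensor2):
    """Host-side model of the proposed handshake for two devices.

    tensor1[d]: id of the device that d sends to, or -1 if d is the receiver.
    tensor2[d]: 1 if device d skipped its compute, else 0.
    Returns the updated (tensor1, tensor2) as new lists.
    """
    flags = tuple(tensor2)
    if flags == (0, 0):
        # Previous speculation failed: roles stay as they are.
        return list(tensor1), list(tensor2)
    if flags == (1, 1):
        raise RuntimeError("both devices skipped compute; no main path remains")
    if flags == (1, 0):
        # Device 1 holds the main path and becomes the sender to device 0.
        return [-1, 0], [0, 0]
    if flags == (0, 1):
        # Device 0 holds the main path and becomes the sender to device 1.
        return [1, -1], [0, 0]
    raise ValueError(f"unexpected skip flags: {flags}")
```

After a successful speculation the device that kept computing always ends up as the sender, and the skip flags are cleared for the next layer.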

ttnn.sync_cache(K, tensor1): If tensor1 is -1 for this device, do nothing. Otherwise, do a remote write of the cache to the device whose id is in tensor1 (pad and then reduce scatter).
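The intended effect of the cache sync, ignoring how the pad and reduce scatter would realize it, can be sketched as follows. Caches are modeled as plain Python lists and the function returns new lists rather than writing in place; both choices are assumptions made for illustration.

```python
def sync_cache(caches, tensor1):
    """Model of the proposed op: the sender overwrites the receiver's cache.

    caches[d] is device d's cache (a list of floats here); tensor1[d] is the
    destination device id, or -1 when device d is the receiver.
    """
    synced = [list(c) for c in caches]
    for device_id, dest in enumerate(tensor1):
        if dest == -1:
            continue  # receivers do not write
        synced[dest] = list(caches[device_id])  # remote write of the sender's cache
    return synced
```

With tensor1 = [1, -1], device 0's cache is copied onto device 1; with tensor1 = [-1, 0], the direction is reversed.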

speculative_flash_decode(x_multidevice, tensor1, tensor2):

Case 1, tensor1 is -1: this device is the receiver chip, so it waits for the speculative result to be sent.

Case 2, tensor1 is the receiver chip id: as in normal flash decode, there is a reducer core and worker cores. However, the reducer core first processes only the first and last chunks, then writes the speculative result and signals the EDM to send it to the receiver chip. Afterwards, the reducer core waits for the results from the other workers and completes the full flash decode as usual. At the end of flash decode, it computes the L2 distance between the speculative result and the full result. If the distance is within the threshold, it writes 1 to tensor2; otherwise it writes 0.
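The sender-side logic can be illustrated with a pure-Python, single-query sketch. It computes every chunk sequentially on the host, whereas the proposal distributes chunks across worker cores; the helper names `chunk_partial` and `merge` and the explicit `threshold` parameter are assumptions introduced here.

```python
import math

def chunk_partial(q, keys, values):
    """Return (max_score, sum_exp, weighted_value_sum) for one KV chunk."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    dim = len(values[0])
    acc = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return m, sum(weights), acc

def merge(partials):
    """Combine per-chunk partials into one softmax-weighted output (the reducer's job)."""
    m = max(p[0] for p in partials)
    denom = sum(p[1] * math.exp(p[0] - m) for p in partials)
    dim = len(partials[0][2])
    numer = [sum(p[2][j] * math.exp(p[0] - m) for p in partials) for j in range(dim)]
    return [n / denom for n in numer]

def speculative_flash_decode(q, key_chunks, value_chunks, threshold):
    partials = [chunk_partial(q, k, v) for k, v in zip(key_chunks, value_chunks)]
    first_last = [partials[0], partials[-1]] if len(partials) > 1 else partials
    speculative = merge(first_last)  # would be sent to the receiver chip early
    full = merge(partials)           # exact result once all chunks are reduced
    distance = math.sqrt(sum((s - f) ** 2 for s, f in zip(speculative, full)))
    skip_flag = 1 if distance <= threshold else 0  # value written to tensor2
    return speculative, full, skip_flag
```

When the middle chunks carry little attention weight, the speculative and full outputs nearly coincide and the flag is 1; when a middle chunk dominates the softmax, the distance exceeds the threshold and the flag is 0.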

avoraTT commented 2 weeks ago

Here's a diagram for the idea behind TtSyncTensor. It's essentially implemented as an all reduce, with a preceding masking step that ensures that all tensors, except on the sender chip, are 0 before performing the sync.
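The mask-then-all-reduce idea can be modeled in a few lines of Python. `sync_tensor` is a hypothetical host-side stand-in, since the actual TtSyncTensor interface is not shown in this thread, and a sum is assumed as the reduction.

```python
def sync_tensor(per_device_tensors, sender_id):
    """Mask every non-sender tensor to zero, then sum-all-reduce across devices."""
    masked = [
        list(t) if device_id == sender_id else [0] * len(t)
        for device_id, t in enumerate(per_device_tensors)
    ]
    reduced = [sum(column) for column in zip(*masked)]  # element-wise sum over devices
    # All-reduce: every device ends up holding the same reduced tensor.
    return [list(reduced) for _ in per_device_tensors]
```

Because every tensor except the sender's is zeroed first, the sum equals the sender's tensor, so the all reduce acts as a broadcast from the sender.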

image
caixunshiren commented 2 weeks ago

Looks great!!