tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

CCL Mid-Level (Python) Programming API #7205

Open SeanNijjar opened 4 months ago

SeanNijjar commented 4 months ago

This issue tracks a larger body of work to bring up functionality that lets a user write CCL ops in Python, at a higher level than is possible today. For now, this is a high-level proposal and I'm soliciting feedback from stakeholders - it is not yet plan of record.

Motivation

There are three motivating factors for this work:

1) Simplify CCL op development - eventually enable all CCL ops to be implemented in Python
2) Empower model developers to directly inspect, modify, and extend existing CCL ops without having to leave Python
3) Enable more fine-grained control over the overlap of CCL (data movement) and compute ops

Programming Model/API

We want to enable a consistent programming model across single-chip and multi-chip. With @cfjchu's multi-chip APIs, I'm hoping we can land on a pythonic, usable API with low cognitive load. Ideally, it would also act as a precursor to Triton-like syntax in our API (or direct integration with Triton itself).

As a starting point, we'd only enable device-to-device data movement via Python array slicing and have it work in conjunction with per-chip ops. Consider a basic example:

Pythonic Array Slice Example

Consider a multi-chip tensor that is width-sharded across a ring/line topology of 8 chips. Each chip holds a [1024, 128] input and output tensor; in total, the canonical tensor shape is [1024, 1024]. In standard Python, a user could assign a slice of the input tensor to a slice of the output tensor. In this case, we have some arbitrarily chosen assignment:

image

Under the hood, this would represent a send from chip 7 -> chip 0. The user specifies an array slice assignment, but behind the scenes this would invoke a ccl::send_data (not yet available) with some generated lower-level indexing and a connection to the erisc data mover (EDM). This is the basic building block.
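To make the semantics concrete, here is a plain-numpy stand-in for the kind of slice assignment being proposed. The exact slice from the screenshot above isn't reproduced here; the one below is chosen so that it implies a chip 7 -> chip 0 send, with chip i owning columns [i*128, (i+1)*128) of the canonical tensor:

import numpy as np

ring_size = 8
chunk_width = 128

# Canonical [1024, 1024] tensors; each chip owns a [1024, 128] column shard.
input_tensor = np.arange(1024 * 1024, dtype=np.float32).reshape(1024, 1024)
output_tensor = np.zeros_like(input_tensor)

# The multi-chip API would let a user write something like the line below:
# assign chip 7's input shard to chip 0's output shard. On hardware this would
# lower to a single ccl::send_data (chip 7 -> chip 0) routed through the EDM.
output_tensor[:, 0:chunk_width] = input_tensor[:, (ring_size - 1) * chunk_width:]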

More Advanced Usage

The above example provides basic functionality, but not the value we'd really like to have. The goal is to be able to specify many of these (e.g. in a loop nest), which would effectively launch multiple sends in parallel. Furthermore, these would be invokable in sequence with compute ops so that compute and data movement can be overlapped, like the JAX example (a rough sketch of that pattern follows).
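For reference, the JAX pattern alluded to above - a ring collective fused with a matmul, where each step's ring shift (ppermute) overlaps with the matmul of the chunk already on hand - looks roughly like the sketch below. The sharding (row-sharded activations, column-sharded weights) and all shapes here are illustrative only and do not match the width-sharded layout described in this issue.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices())
mesh = Mesh(devices, ('x',))
n = devices.size
M, K, N = 128 * n, 256, 64 * n

def fused_allgather_matmul(x_blk, w_blk):
    # x_blk: [M // n, K] row-shard of activations; w_blk: [K, N // n] column-shard of weights
    me = jax.lax.axis_index('x')
    chunk = x_blk
    out = jnp.zeros((M, w_blk.shape[1]), x_blk.dtype)
    for step in range(n):
        # Kick off the ring shift, then compute with the chunk we already hold,
        # so the communication and the matmul partial can overlap.
        nxt = jax.lax.ppermute(chunk, 'x', perm=[(j, (j + 1) % n) for j in range(n)])
        src = (me - step) % n  # which device this chunk originated from
        out = jax.lax.dynamic_update_slice(out, chunk @ w_blk, (src * (M // n), 0))
        chunk = nxt
    return out

y = shard_map(fused_allgather_matmul, mesh,
              in_specs=(P('x', None), P(None, 'x')), out_specs=P(None, 'x'))(
    jnp.ones((M, K), jnp.float32), jnp.ones((K, N), jnp.float32))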

For example (forgive the currently fuzzy syntax and the probably incorrect sharded-matmul code - don't get hung up on the details until the proposal is more mature), you might choose to specify the fused CCL-matmul kernel in the following way:

chunk_width = 128
chunk_height = 1024

# Implement a fused ring all-gather with matmul where we overlap data movement with
# the compute of matmul partials (for each timestep).
# In this case, we shift the input tensor through the ring (one hop) and then we
# compute the partial result.
# I didn't include the code to describe it as a systolic array (which would involve
# parking the data in a chip-local buffer before the matmul).
for timestep in range(ring_size):      # 0 .. ring_size - 1
  for r in range(ring_size):
    # compute the partial for one timestep per chip
    act_chunk = activation[r * chunk_height:(r + 1) * chunk_height,
                           r * chunk_width:(r + 1) * chunk_width]
    output[r * chunk_height:(r + 1) * chunk_height,
           r * chunk_width:(r + 1) * chunk_width] += mm(
        act_chunk, w[:, r * chunk_width:(r + 1) * chunk_width])

These array slices and compute operations would be non-blocking relative to the loop nest's inner loop iterations. Of course, the matmul would wait for the array slice data to arrive.

As a starting point, we don't have this dependence analysis; we require the user to ensure their code has no dependences, or only forward dependences. Any cases that don't satisfy these conditions can, for the time being, use temporary syntax that lets the user barrier explicitly (or causes the backend to emit multiple programs).

The next snippet contains an example of how the prior snippet could be updated to specify this:

chunk_width = 128
chunk_height = 1024

for timestep in range(ring_size):      # 0 .. ring_size - 1

  program_epoch.open()  # Signals that we will start collecting "commands"

  for r in range(ring_size):
    # compute the partial for one timestep per chip
    act_chunk = activation[r * chunk_height:(r + 1) * chunk_height,
                           r * chunk_width:(r + 1) * chunk_width]
    output[r * chunk_height:(r + 1) * chunk_height,
           r * chunk_width:(r + 1) * chunk_width] = mm(
        act_chunk, w[:, r * chunk_width:(r + 1) * chunk_width])

  # Signals a barrier of some sort (e.g. terminates the "program" for the inner loop).
  # You could imagine this double loop nest as analogous to a composite op where
  # each program_epoch.close() call ends a metal program.
  program_epoch.close()

Although insufficient for good performance, this would enable functional bringup of the capability with minimal new feature work (a new point-to-point op and EDM channel schedule analysis).

Some design work is still needed to let us remove this open/close behaviour and express the full loop-nest as a single program.

Implementation

Although the proposal, syntax, and capabilities still need finalizing, there are a few obvious and necessary pieces of work needed to support this functionality. They're outlined below:

Looking to the Future

Removing Program Open/Close

I'm initially thinking that I may introduce a very lightweight command interpreter to the EDM that would let it establish connectivity with workers dynamically. These commands would specify the worker, the sync (semaphore) addresses, and a message count (a rough sketch of what such a command descriptor could carry appears at the end of this comment). This may be easier to support on BlackHole, but it still seems feasible on WH. It would enable us to support a relatively large number of independent sends/receives. For the data-dependence problem, we'll need to add some data-dependence analysis under the hood; data-dependence synchronization may be implementable on top of the events runtime infra.

Open Questions

- [ ] Q: How do we support virtual worker grids?
  - A?: Python decorators. Obviously, every underlying op needs to support virtual worker grids.
- [ ] Q: How can we express data dependencies for asynchronous compute/data movement (as in the examples above)?
  - A?: Could we use a mix of events (cross-core and cross-partial-"op") and CBs (intra-core) for data dependencies?
  - How do we avoid making dependent operations wait for the _full_ producer partial op to complete? E.g. let the consumer start streaming as soon as the first tiles are available, but no earlier.
- [ ] Q: How do we get past the current hard limits on the number of live EDM channels?
  - Option 1: EDM exposes a channel-sharing option - it can use this to cycle through producers/consumers per channel.
    - This will still eventually hit the limit of # producer cores (if each core needs to complete different transfers over time, they may need to be described as distinct "connections").
  - Option 2: Add command interpreter capabilities to the EDM? Transfers can be stored in DRAM and the EDM can (pre)fetch these over time.
  - Option 3: ?

Info Gathering Tasks

- [ ] Collect all tunable CCL knobs (EDM buffer size, count)
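As a purely hypothetical illustration of the command-interpreter idea (none of these names or fields exist in tt-metal today), a host-built EDM command descriptor might carry the fields mentioned above - worker coordinates, semaphore addresses, and a message count - and a schedule could simply be an ordered list of such descriptors packed into DRAM for the EDM to (pre)fetch:

from dataclasses import dataclass

@dataclass
class EdmCommand:
    # All names/fields are hypothetical - for illustration only.
    worker_noc_x: int           # worker core to connect to
    worker_noc_y: int
    worker_semaphore_addr: int  # L1 address the EDM bumps when a buffer is ready for the worker
    edm_semaphore_addr: int     # L1 address the worker bumps when it has drained a buffer
    message_count: int          # number of buffer-sized messages on this connection

# A "schedule" could just be an ordered list of commands the EDM walks through,
# establishing and tearing down worker connections dynamically.
schedule = [
    EdmCommand(worker_noc_x=1, worker_noc_y=0,
               worker_semaphore_addr=0x2000, edm_semaphore_addr=0x2010,
               message_count=4),
]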
SeanNijjar commented 4 months ago

Adding additional thoughts/feedback based on some discussions. Consider incorporating into the main issue description above.

Considerations this work should take into account:

Some miscellaneous ideas:

SeanNijjar commented 4 months ago

I've had a bit of time to think through some of this more while machine access was lost. Dumping some notes/diagrams here. I have an excalidraw on my machine that I'm adding to over time. I've also added some new open questions that came up as I was working through this.

In progress high level schedule decision tree

image

Fine Grain Synchronization

image

Indexing Strategies:

image

Indexing Examples (WIP)

Ex1

image

Ex2

image

Ex3

image

Non-blocking reader kernel code snippet (pseudocode):

ccl_done = False
op_done = False
s = AddrGen(...)  # Address generator for the producer (CCL) side
d = AddrGen(...)  # Address generator for the consumer (matmul) side
producer_input_available = False
ccl_read_in_progress = False
ccl_wait_for_writes_complete = False
ccl_write_to_tensor_in_progress = False
op_read_in_progress = False
edm_pages_read = 0
edm_pages_per_buf = <arg_val>  # kernel runtime arg

# Updates to ccl_done/op_done (tracking pages remaining) are omitted for brevity.
while not (ccl_done and op_done):

  ## CCL BLOCK - Non-Blocking - imported from EDM-worker interface headers
  if not ccl_done:
    # EDM signalled that a buffer of pages is ready for us to pull
    if not producer_input_available and *semaphore_from_edm != 0 and not ccl_wait_for_writes_complete:
      producer_input_available = True
      *semaphore_from_edm = 0

    # Start pulling the EDM buffer into the local staging buffer (one read in flight at a time)
    if producer_input_available and not ccl_read_in_progress:
      ccl_pages_to_read = min(edm_pages_per_buf, ccl_pages_left)
      noc_read(edm_buffer, ccl_staging_buf, ccl_pages_to_read)
      ccl_read_in_progress = True

    if ccl_read_in_progress and ccl_read_complete():  ## Need to track num transfers for CCL and op independently
      ccl_wait_for_writes_complete = True
      ccl_read_in_progress = False
      producer_input_available = False
      noc_semaphore_inc(edm_receiver_worker_read_complete_semaphore, 1)  # tells EDM to get the next data

    if ccl_write_to_tensor_in_progress and ccl_writer_complete():
      ## This assumes a 1:1 association between reader and writer.
      ## Need to figure out how to synchronize for a many:many association (since pages
      ## written to the canonical tensor will not be dense/contiguous).
      noc_semaphore_inc(canonical_tensor_pages_read_semaphore, ccl_pages_to_read)
      ccl_wait_for_writes_complete = False

  ## MM BLOCK - Non-Blocking
  if not op_done:
    canonical_idx_last = op_reader.last_canonical_idx_in_contiguous_row
    canonical_idx_first = op_reader.current_canonical_idx
    # Only issue the read once the CCL writer has landed all the pages this row depends on
    if *op_reader_semaphore >= canonical_idx_last and not op_read_in_progress:
      l1_addr = cb_reserve_back(op_num_pages_per_contiguous_row)
      read_pages_from_interleaved(d, canonical_idx_first, l1_addr, op_num_pages_per_contiguous_row)
      op_reader.advance_n_pages(op_num_pages_per_contiguous_row)
      op_read_in_progress = True

    if op_read_in_progress and op_read_complete():  ## Need to track num transfers for CCL and op independently
      cb_push_back(op_num_pages_per_contiguous_row)
      op_read_in_progress = False