google / gemma.cpp

A lightweight, standalone C++ inference engine for Google's Gemma models.
Apache License 2.0

Parallel prefill seems to lead to completely different results #318

Closed by ufownl 1 month ago

ufownl commented 1 month ago

I tried using the latest dev branch to handle a long prefill, and it gave completely different results than the previous non-parallel version. When I ran it in a single thread, it behaved just as before. At first I thought the parallelism had caused a data race in my code, but even after I added a lock around the stream_token callback, the problem persisted.

So I looked into it and found that the main difference is that the parallel version uses a separate Activations object for each thread, each updated by only a portion of the tokens and in an undefined order. Previously, a single Activations object was updated with all prefill tokens in the order they were entered. Intuitively, this seems likely to have a big impact on the KV cache, and in my experiments the KV cache appears to contain only a small portion of the prefilled content.
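For illustration, here is a heavily simplified sketch of the pattern described above. This is not the actual gemma.cpp prefill code; Activations, KVCache, and ProcessToken are hypothetical stand-ins. The point is only that each worker owns its own Activations and slice of tokens, while all workers write into one shared KV cache, so the interleaving of those writes is undefined.

#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical stand-ins for the real types; only the control flow matters.
struct Activations { /* per-worker scratch buffers */ };
struct KVCache { /* shared key/value cache, indexed by token position */ };

// Placeholder: run one token through the model and write its K/V entries.
void ProcessToken(std::size_t /*pos*/, Activations& /*act*/, KVCache& /*kv*/) {}

// Each worker gets its own Activations and a contiguous slice of positions,
// but they all write into the same KVCache in an undefined order.
void ParallelPrefillSketch(std::size_t num_tokens, std::size_t num_workers,
                           KVCache& kv) {
  std::vector<Activations> acts(num_workers);
  std::vector<std::thread> threads;
  const std::size_t per_worker = (num_tokens + num_workers - 1) / num_workers;
  for (std::size_t w = 0; w < num_workers; ++w) {
    threads.emplace_back([&, w] {
      const std::size_t begin = w * per_worker;
      const std::size_t end = std::min(num_tokens, begin + per_worker);
      for (std::size_t pos = begin; pos < end; ++pos) {
        // May read K/V entries that another worker has not written yet.
        ProcessToken(pos, acts[w], kv);
      }
    });
  }
  for (auto& t : threads) t.join();
}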

I'm not sure whether this is the root cause of the issue; I'll take a closer look later if I have time.

ufownl commented 1 month ago

Today I looked into it again and may have found the root cause of this issue.

This code snippet comes from GemmaAttention:

...

// Attention scores: dot product of the query with each cached key
// inside the attention window ending at `pos`.
const size_t start_pos =
    pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
  const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
  const size_t kv_offset =
      cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
  const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
  const float score = Dot(q, k2, kQKVDim);
  head_att[pos2 % kSeqLen] = score;
}

...

// Weighted sum of the cached values, using the scores computed above.
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
  const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
  const size_t kv_offset =
      cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
  float* HWY_RESTRICT v2 =
      kv_cache.kv_cache.get() + kv_offset + kQKVDim;
  MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
}

...

This snippet shows that the attention calculation depends on the keys and values of previous tokens stored in the KV cache. The attention window is much larger than the prefill batch size, so during the prefill phase those previous tokens are computed in parallel in different batches, and there is no guarantee that the earlier computations have completed before their results are read.
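To make that dependency concrete, here is a small standalone sketch that just replays the start_pos arithmetic from the snippet above; the window size and batch size are made-up example values, not the real model configuration.

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  // Example values only; the real kAttentionWindowSizes[layer] and
  // kPrefillBatchSize come from the model configuration.
  const std::size_t kWindow = 4096;
  const std::size_t kPrefillBatchSize = 64;

  // Attention for token `pos` reads K/V for every position in
  // [start_pos, pos], exactly as in the quoted loops.
  const std::size_t pos = 100;
  const std::size_t start_pos = pos - std::min(kWindow - 1, pos);  // 0 here

  const std::size_t batch_of_pos = pos / kPrefillBatchSize;            // batch 1
  const std::size_t first_batch_read = start_pos / kPrefillBatchSize;  // batch 0

  std::printf("token %zu (prefill batch %zu) reads K/V back to position %zu,\n"
              "i.e. from batch %zu onward; if batches run concurrently, those\n"
              "entries may not have been written yet.\n",
              pos, batch_of_pos, start_pos, first_batch_read);
  return 0;
}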

jan-wassenberg commented 1 month ago

Thank you for pointing this out. Unfortunately I only tested short contexts with msan, not long contexts. The good news is that msan immediately catches this.

The reason we have multiple Activations is to support batches of independent queries. I'm out this week, but I will start thinking about how best to handle the various combinations (num_queries > batch, etc.). In the meantime, I think passing a ThreadPool(0) to PrefillState will get the old behavior back without affecting Decode speed.
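For reference, a minimal sketch of what the suggested workaround relies on, assuming the usual Highway ThreadPool semantics: a pool constructed with zero extra worker threads runs every task on the calling thread, in order, so the prefill batches execute sequentially. The header path and Run signature below are stated to the best of my knowledge and should be double-checked against the Highway version in use.

#include <cstdint>
#include <cstdio>

#include "hwy/contrib/thread_pool/thread_pool.h"

int main() {
  // Assumption: a pool constructed with 0 worker threads executes all tasks
  // on the calling thread, so work dispatched through it runs sequentially
  // and the KV cache is filled in order, as in the non-parallel version.
  hwy::ThreadPool prefill_pool(0);

  prefill_pool.Run(0, 4, [](uint64_t task, size_t thread) {
    // With zero workers, `thread` is always 0 and tasks arrive in order.
    std::printf("task %llu on worker %zu\n",
                static_cast<unsigned long long>(task), thread);
  });
  return 0;
}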

ufownl commented 1 month ago

@jan-wassenberg Passing ThreadPool(0) to PrefillState seems to change the number of workers in the inner pool passed into TransformerLayer, which hurts its performance. My current workaround is to set outer_workers to 1 directly; it seems to work fine.

jan-wassenberg commented 1 month ago

Sounds good. Triggering the if (outer_workers <= 1) condition is simpler and more elegant, nice.
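For readers following along, here is a rough paraphrase of the two paths being discussed. This is not the actual PrefillState code; PrefillBatch and the types are hypothetical. With outer_workers <= 1, all batches go through a single Activations object strictly in order, which is why forcing that path restores the old results.

#include <cstddef>
#include <vector>

// Hypothetical stand-ins; only the branching on outer_workers matters here.
struct Activations {};
struct KVCache {};
void PrefillBatch(std::size_t /*batch*/, Activations& /*act*/, KVCache& /*kv*/) {}

void PrefillSketch(std::size_t num_batches, std::size_t outer_workers,
                   std::vector<Activations>& acts, KVCache& kv) {
  if (outer_workers <= 1) {
    // The workaround path: one Activations object, batches processed strictly
    // in order, so every K/V entry a later token reads was already written.
    for (std::size_t b = 0; b < num_batches; ++b) {
      PrefillBatch(b, acts[0], kv);
    }
  } else {
    // Parallel path: batches distributed across outer workers with per-worker
    // Activations; their completion order is undefined, which is what caused
    // the differing results in this issue. (Thread-pool dispatch elided.)
  }
}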

jan-wassenberg commented 1 month ago

#324 fixes this :)

jan-wassenberg commented 1 month ago

Landed; please reopen if you still notice unexpected results. We tested with msan and both long and short contexts, and expanded the batching test in gemma_test.