apache / lucene

Apache Lucene open-source search software
https://lucene.apache.org/
Apache License 2.0

Reduce the overhead of `IndexInput#prefetch` when data is cached in RAM. #13381

Closed: jpountz closed this 1 month ago

jpountz commented 1 month ago

As Robert pointed out and benchmarks confirmed, there is some (small) overhead to calling madvise via the foreign function API: benchmarks suggest it is on the order of 1-2 µs per call. This is not much for a single call, but it may become non-negligible across many calls. Until now, we have only looked into using prefetch() for terms, skip data and postings start pointers, which amount to a single prefetch() call per segment per term.

But we may want to start using it in cases that could result in many more calls to madvise, e.g. if we start using it for stored fields and a user requests 10k documents. In #13337, Robert wondered if we could take advantage of mincore() to reduce the overhead of IndexInput#prefetch(), which is what this PR does.
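To make the two system calls concrete, here is a minimal sketch, assuming JDK 22+ on Linux, of calling madvise(MADV_WILLNEED) through the foreign function API and skipping it when the mapped range already appears to be resident. It relies on the JDK's `MemorySegment#isLoaded()`, which on Linux the JDK implements with mincore(). The class and method names are hypothetical, `MADV_WILLNEED = 3` is taken from the Linux headers, and this is not Lucene's actual implementation. Note that isLoaded() is itself a system call, so a check like this only pays off if it is cheaper or less frequent than the madvise it replaces.

```java
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

public class MadviseSketch {

  // Assumption: value of MADV_WILLNEED in the Linux headers.
  static final int MADV_WILLNEED = 3;

  // Downcall handle for: int madvise(void *addr, size_t length, int advice)
  static final MethodHandle MADVISE;

  static {
    Linker linker = Linker.nativeLinker();
    MADVISE = linker.downcallHandle(
        linker.defaultLookup().find("madvise").orElseThrow(),
        FunctionDescriptor.of(
            ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.JAVA_LONG, ValueLayout.JAVA_INT));
  }

  /** Returns true if madvise was called, false if all pages already looked resident. */
  static boolean prefetch(MemorySegment mapped) throws Throwable {
    if (mapped.isLoaded()) {
      return false; // residency check says everything is cached: skip madvise
    }
    int rc = (int) MADVISE.invokeExact(mapped, mapped.byteSize(), MADV_WILLNEED);
    if (rc != 0) {
      throw new IOException("madvise failed with rc=" + rc);
    }
    return true;
  }

  public static void main(String[] args) throws Throwable {
    Path file = Files.createTempFile("prefetch", ".bin");
    try {
      Files.write(file, new byte[1 << 20]); // 1MB of zeros
      try (FileChannel ch = FileChannel.open(file, StandardOpenOption.READ);
          Arena arena = Arena.ofConfined()) {
        MemorySegment seg = ch.map(FileChannel.MapMode.READ_ONLY, 0, ch.size(), arena);
        System.out.println("madvise called: " + prefetch(seg));
      }
    } finally {
      Files.delete(file);
    }
  }
}
```

Run it with `--enable-native-access=ALL-UNNAMED` to avoid the JDK's restricted-method warning. A freshly written file is usually still in the page cache, so `false` is the likely output, but that depends on the OS.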

For now, this PR tries not to add new APIs. Instead, IndexInput#prefetch tracks consecutive hits on the page cache and, under the hood, calls madvise less and less frequently as the number of consecutive cache hits grows.
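The backoff described above can be sketched as follows. The description only says madvise is called "less and less frequently"; the power-of-two schedule, the class name and the `LongPredicate` residency probe (a stand-in for a mincore()-style check) are assumptions for illustration, not the PR's code. The counter is incremented on every call, the residency check only runs when the counter's previous value is zero or a power of two, and any miss resets the counter so that checking resumes on every call.

```java
import java.util.function.LongPredicate;

/** Hypothetical sketch: exponential backoff of residency checks on consecutive cache hits. */
public class PrefetchBackoff {

  private final LongPredicate isResident; // stand-in for a mincore()-style probe
  private long consecutiveHits = 0;       // prefetch() calls since the last observed miss

  // Counters exposed for the demo only
  long residencyChecks = 0;
  long madviseCalls = 0;

  PrefetchBackoff(LongPredicate isResident) {
    this.isResident = isResident;
  }

  private static boolean isZeroOrPowerOfTwo(long x) {
    return (x & (x - 1)) == 0;
  }

  void prefetch(long offset) {
    // After many hits in a row, only check at counter values 0, 1, 2, 4, 8, ...
    if (isZeroOrPowerOfTwo(consecutiveHits++) == false) {
      return;
    }
    residencyChecks++;
    if (isResident.test(offset) == false) {
      consecutiveHits = 0; // miss: go back to checking on every call
      madviseCalls++;      // stand-in for madvise(MADV_WILLNEED)
    }
  }

  public static void main(String[] args) {
    PrefetchBackoff allCached = new PrefetchBackoff(offset -> true);
    for (int i = 0; i < 100_000; i++) {
      allCached.prefetch(i);
    }
    // prints "all cached: checks=18 madvise=0"
    System.out.println(
        "all cached: checks=" + allCached.residencyChecks + " madvise=" + allCached.madviseCalls);

    PrefetchBackoff nothingCached = new PrefetchBackoff(offset -> false);
    for (int i = 0; i < 100_000; i++) {
      nothingCached.prefetch(i);
    }
    // prints "nothing cached: checks=100000 madvise=100000"
    System.out.println(
        "nothing cached: checks=" + nothingCached.residencyChecks
            + " madvise=" + nothingCached.madviseCalls);
  }
}
```

With everything cached, 100,000 calls trigger only 18 checks (counter values 0 and 2^0 through 2^16), while a cold cache still gets an madvise on every call.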

jpountz commented 1 month ago
I slightly modified the benchmark from #13337:

```java
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;

public class PrefetchBench {

  private static final int NUM_TERMS = 3;
  private static final long FILE_SIZE = 100L * 1024 * 1024 * 1024; // 100GB
  private static final int NUM_BYTES = 16;
  public static int DUMMY;

  public static void main(String[] args) throws IOException {
    Path filePath = Paths.get(args[0]);
    Path dirPath = filePath.getParent();
    String fileName = filePath.getFileName().toString();
    Random r = ThreadLocalRandom.current();

    try (Directory dir = new MMapDirectory(dirPath)) {
      if (Arrays.asList(dir.listAll()).contains(fileName) == false) {
        // Create the test file once, filled with random bytes
        try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
          byte[] buf = new byte[8192];
          for (long i = 0; i < FILE_SIZE; i += buf.length) {
            r.nextBytes(buf);
            out.writeBytes(buf, buf.length);
          }
        }
      }

      for (boolean dataFitsInCache : new boolean[] {false, true}) {
        try (IndexInput i0 = dir.openInput(fileName, IOContext.DEFAULT)) {
          byte[][] b = new byte[NUM_TERMS][];
          for (int i = 0; i < NUM_TERMS; ++i) {
            b[i] = new byte[NUM_BYTES];
          }
          IndexInput[] inputs = new IndexInput[NUM_TERMS];
          if (dataFitsInCache) {
            // 16MB slice that should easily fit in the page cache
            inputs[0] = i0.slice("slice", 0, 16 * 1024 * 1024);
          } else {
            inputs[0] = i0;
          }
          for (int i = 1; i < NUM_TERMS; ++i) {
            inputs[i] = inputs[0].clone();
          }

          final long length = inputs[0].length();
          // latencies[0]: with prefetching, latencies[1]: without prefetching
          @SuppressWarnings("unchecked")
          List<Long>[] latencies = new List[2];
          latencies[0] = new ArrayList<>();
          latencies[1] = new ArrayList<>();
          for (int iter = 0; iter < 100_000; ++iter) {
            // Alternate between iterations with and without prefetching
            final boolean prefetch = (iter & 1) == 0;
            final long start = System.nanoTime();
            for (IndexInput ii : inputs) {
              final long offset = r.nextLong(length - NUM_BYTES);
              ii.seek(offset);
              if (prefetch) {
                ii.prefetch(offset, 1);
              }
            }
            for (int i = 0; i < NUM_TERMS; ++i) {
              inputs[i].readBytes(b[i], 0, b[i].length);
            }
            final long end = System.nanoTime();
            // Prevent the JVM from optimizing away the reads
            DUMMY = Arrays.stream(b).mapToInt(Arrays::hashCode).sum();
            latencies[iter & 1].add(end - start);
          }

          latencies[0].sort(null);
          latencies[1].sort(null);
          System.out.println("Data " + (dataFitsInCache ? "fits" : "does not fit") + " in the page cache");
          long prefetchP50 = latencies[0].get(latencies[0].size() / 2);
          long prefetchP90 = latencies[0].get(latencies[0].size() * 9 / 10);
          long prefetchP99 = latencies[0].get(latencies[0].size() * 99 / 100);
          long noPrefetchP50 = latencies[1].get(latencies[1].size() / 2);
          long noPrefetchP90 = latencies[1].get(latencies[1].size() * 9 / 10);
          long noPrefetchP99 = latencies[1].get(latencies[1].size() * 99 / 100);
          System.out.println("  With prefetching:    P50=" + prefetchP50 + "ns P90=" + prefetchP90 + "ns P99=" + prefetchP99 + "ns");
          System.out.println("  Without prefetching: P50=" + noPrefetchP50 + "ns P90=" + noPrefetchP90 + "ns P99=" + noPrefetchP99 + "ns");
        }
      }
    }
  }
}
```

It gives the following results. Before the change:

Data does not fit in the page cache
  With prefetching:    P50=88080ns P90=122970ns P99=157420ns
  Without prefetching: P50=224040ns P90=242320ns P99=297470ns
Data fits in the page cache
  With prefetching:    P50=880ns P90=1060ns P99=1370ns
  Without prefetching: P50=190ns P90=280ns P99=580ns

After the change:

Data does not fit in the page cache
  With prefetching:    P50=89710ns P90=124780ns P99=159400ns
  Without prefetching: P50=224271ns P90=242940ns P99=297371ns
Data fits in the page cache
  With prefetching:    P50=210ns P90=300ns P99=630ns
  Without prefetching: P50=200ns P90=290ns P99=580ns
uschindler commented 1 month ago

P.S.: Looking at the code, the MemorySegment#load() method actually calls madvise(MADV_WILLNEED), so we could also implement prefetch using load(). You can follow that through the same chain of classes/calls as in my previous review comment.

The only problem is that, after calling madvise, it touches a byte in each page to trigger the load synchronously. So we have to stay with our direct native call here.
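To illustrate why that makes load() blocking: on a file mapping, reading one byte from a page that is not resident triggers a page fault, so a loop that touches every page only returns once the whole range is in memory. The sketch below shows such a loop; the class and method names are hypothetical, and the 4096-byte page size is an assumption (real code should query the OS). The demo runs on plain native memory, where no faults occur, purely to show which bytes get read.

```java
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;

public class TouchPagesSketch {

  /**
   * Reads one byte at the start of every page. On a file mapping, each read of a
   * non-resident page faults it in, so this blocks until every page is loaded.
   */
  static long touchEveryPage(MemorySegment segment, long pageSize) {
    long checksum = 0; // accumulate the bytes so the reads cannot be optimized away
    for (long offset = 0; offset < segment.byteSize(); offset += pageSize) {
      checksum += segment.get(ValueLayout.JAVA_BYTE, offset);
    }
    return checksum;
  }

  public static void main(String[] args) {
    try (Arena arena = Arena.ofConfined()) {
      // Three 4096-byte "pages" of zeroed memory with a 1 at the start of each page
      MemorySegment seg = arena.allocate(3 * 4096);
      for (long off = 0; off < seg.byteSize(); off += 4096) {
        seg.set(ValueLayout.JAVA_BYTE, off, (byte) 1);
      }
      System.out.println(touchEveryPage(seg, 4096)); // prints 3
    }
  }
}
```

A prefetch, by contrast, only issues the madvise hint and returns immediately, leaving the kernel to read the pages in the background.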

rmuir commented 1 month ago

Somewhat related: I was playing around with the new cachestat syscall (it isn't directly relevant to us here, since it takes an fd, etc.), but its background did bring up the opposite concern of this PR:

such an application can learn whether the pages it is prefetching into the cache are still there by the time it gets around to using them. If those pages are being evicted, the prefetching is overloading the page cache and causing more work overall; in such situations, the application can back off and get better performance.

https://lwn.net/Articles/917096/

You can easily play around with it on Linux 6.x from the command line:

$ fincore --output-all myindexdir/*
rmuir commented 1 month ago

Maybe if we didn't close the fd in MMapDirectory, we could eventually think about making use of this on modern Linux. It doesn't have a glibc wrapper yet, so here is minimal sample code; for a more functional example, look at fincore: https://github.com/util-linux/util-linux/blob/master/misc-utils/fincore.c

```c
#include <sys/syscall.h>
#include <linux/mman.h>   /* struct cachestat_range, struct cachestat */
#include <fcntl.h>
#include <stdio.h>
#include <unistd.h>

/* No glibc wrapper yet, so invoke the raw syscall */
static int
cachestat(int fd, struct cachestat_range *range, struct cachestat *stats, unsigned int flags) {
  return syscall(SYS_cachestat, fd, range, stats, flags);
}

int main(int argc, char **argv) {
  int fd;

  if (argc != 2) {
    fprintf(stderr, "usage: %s <file>\n", argv[0]);
    return 2;
  }

  if ((fd = open(argv[1], O_RDONLY)) < 0) {
    perror("couldn't open");
    return 1;
  }

  /* off = 0, len = 0: query the whole file, from the start to EOF */
  struct cachestat_range range = { 0, 0 };
  struct cachestat cstats;
  if (cachestat(fd, &range, &cstats, 0) != 0) {
    perror("couldn't cachestat");
    close(fd);
    return 1;
  }
  close(fd);

  /* All counters are numbers of pages */
  printf("cached: %llu\ndirty: %llu\nwriteback: %llu\nevicted: %llu\nrecently_evicted: %llu\n",
      (unsigned long long) cstats.nr_cache, (unsigned long long) cstats.nr_dirty,
      (unsigned long long) cstats.nr_writeback, (unsigned long long) cstats.nr_evicted,
      (unsigned long long) cstats.nr_recently_evicted);

  return 0;
}
```
jpountz commented 1 month ago

such an application can learn whether the pages it is prefetching into the cache are still there by the time it gets around to using them

This is an interesting idea!

I was discussing this potential problem with @tveasey the other day. With terms and postings, we're currently only looking into loading a few pages in parallel per search thread, and we then use them immediately. With GBs of page cache capacity, it would be extremely unlikely for these pages to get evicted in the meantime. But if/when we start using prefetch() for bigger regions (e.g. stored fields) and/or further ahead of needing the data (e.g. prefetching data for the next segment while we're scoring the current one), this could indeed become a problem. It would be nice if we could learn to disable prefetching when it's not working as intended; this would make the API safer to use.
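One way such a safeguard could look, sketched under assumptions: whenever a reader consumes data it had prefetched, it checks whether the data was still resident (for instance with a mincore()-style probe, or with cachestat's eviction counters); if too large a fraction of prefetched reads in a window found their pages evicted, prefetching is switched off. The class, the window size and the threshold are all hypothetical and not an existing Lucene API. For brevity the sketch never re-enables prefetching; a real version would need to retry periodically.

```java
/** Hypothetical sketch: disable prefetching when prefetched pages get evicted before use. */
public class PrefetchFeedback {

  private final int windowSize;         // prefetched reads per evaluation window
  private final double maxEvictedRatio; // disable prefetching above this ratio
  private int used = 0;                 // prefetched reads seen in the current window
  private int evicted = 0;              // of those, how many were no longer resident
  private boolean enabled = true;

  PrefetchFeedback(int windowSize, double maxEvictedRatio) {
    this.windowSize = windowSize;
    this.maxEvictedRatio = maxEvictedRatio;
  }

  boolean prefetchEnabled() {
    return enabled;
  }

  /** Call when reading data that was previously prefetched. */
  void onPrefetchedRead(boolean stillResident) {
    used++;
    if (stillResident == false) {
      evicted++;
    }
    if (used == windowSize) {
      // Prefetching is thrashing the page cache if too much of it was wasted
      enabled = (double) evicted / used <= maxEvictedRatio;
      used = 0;
      evicted = 0;
    }
  }

  public static void main(String[] args) {
    PrefetchFeedback feedback = new PrefetchFeedback(100, 0.1);
    for (int i = 0; i < 100; i++) {
      feedback.onPrefetchedRead(true);
    }
    System.out.println("after cache-friendly window: " + feedback.prefetchEnabled()); // true
    for (int i = 0; i < 100; i++) {
      feedback.onPrefetchedRead(i % 2 == 0); // half the prefetched pages were evicted
    }
    System.out.println("after thrashing window: " + feedback.prefetchEnabled()); // false
  }
}
```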

jpountz commented 1 month ago
I added "search" concurrency to the benchmark to make it a bit more realistic:

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.ThreadInterruptedException;

public class PrefetchBench {

  private static final int CONCURRENCY = 10;
  private static final int NUM_TERMS = 3;
  private static final long FILE_SIZE = 100L * 1024 * 1024 * 1024; // 100GB
  private static final int NUM_BYTES = 16;
  public static int DUMMY;

  public static void main(String[] args) throws Exception {
    Path filePath = Paths.get(args[0]);
    Path dirPath = filePath.getParent();
    String fileName = filePath.getFileName().toString();
    Random r = ThreadLocalRandom.current();

    try (Directory dir = new MMapDirectory(dirPath)) {
      if (Arrays.asList(dir.listAll()).contains(fileName) == false) {
        // Create the test file once, filled with random bytes
        try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
          byte[] buf = new byte[8192];
          for (long i = 0; i < FILE_SIZE; i += buf.length) {
            r.nextBytes(buf);
            out.writeBytes(buf, buf.length);
          }
        }
      }

      for (boolean dataFitsInCache : new boolean[] {false, true}) {
        try (IndexInput i0 = dir.openInput(fileName, IOContext.DEFAULT)) {
          final IndexInput input;
          if (dataFitsInCache) {
            // 16MB slice that should easily fit in the page cache
            input = i0.slice("slice", 0, 16 * 1024 * 1024);
          } else {
            input = i0;
          }

          // All reader threads wait on this latch so that they start at the same time
          final CountDownLatch latch = new CountDownLatch(1);
          RandomReader[] readers = new RandomReader[CONCURRENCY];
          for (int i = 0; i < readers.length; ++i) {
            IndexInput[] inputs = new IndexInput[NUM_TERMS];
            for (int j = 0; j < inputs.length; ++j) {
              inputs[j] = input.clone();
            }
            readers[i] = new RandomReader(inputs, latch);
            readers[i].start();
          }
          latch.countDown();

          List<Long> prefetchLatencies = new ArrayList<>();
          List<Long> noPrefetchLatencies = new ArrayList<>();
          for (RandomReader reader : readers) {
            reader.join();
            prefetchLatencies.addAll(reader.latencies[0]);
            noPrefetchLatencies.addAll(reader.latencies[1]);
          }
          prefetchLatencies.sort(null);
          noPrefetchLatencies.sort(null);

          System.out.println("Data " + (dataFitsInCache ? "fits" : "does not fit") + " in the page cache");
          long prefetchP50 = prefetchLatencies.get(prefetchLatencies.size() / 2);
          long prefetchP90 = prefetchLatencies.get(prefetchLatencies.size() * 9 / 10);
          long prefetchP99 = prefetchLatencies.get(prefetchLatencies.size() * 99 / 100);
          long noPrefetchP50 = noPrefetchLatencies.get(noPrefetchLatencies.size() / 2);
          long noPrefetchP90 = noPrefetchLatencies.get(noPrefetchLatencies.size() * 9 / 10);
          long noPrefetchP99 = noPrefetchLatencies.get(noPrefetchLatencies.size() * 99 / 100);
          System.out.println("  With prefetching:    P50=" + prefetchP50 + "ns P90=" + prefetchP90 + "ns P99=" + prefetchP99 + "ns");
          System.out.println("  Without prefetching: P50=" + noPrefetchP50 + "ns P90=" + noPrefetchP90 + "ns P99=" + noPrefetchP99 + "ns");
        }
      }
    }
  }

  private static class RandomReader extends Thread {
    private final IndexInput[] inputs;
    private final CountDownLatch latch;
    private final byte[][] b = new byte[NUM_TERMS][];
    // latencies[0]: with prefetching, latencies[1]: without prefetching
    @SuppressWarnings("unchecked")
    final List<Long>[] latencies = new List[2];

    RandomReader(IndexInput[] inputs, CountDownLatch latch) {
      this.inputs = inputs;
      this.latch = latch;
      latencies[0] = new ArrayList<>();
      latencies[1] = new ArrayList<>();
      for (int i = 0; i < NUM_TERMS; ++i) {
        b[i] = new byte[NUM_BYTES];
      }
    }

    @Override
    public void run() {
      try {
        latch.await();
        final ThreadLocalRandom r = ThreadLocalRandom.current();
        final long length = inputs[0].length();
        for (int iter = 0; iter < 100_000; ++iter) {
          // Alternate between iterations with and without prefetching
          final boolean prefetch = (iter & 1) == 0;
          final long start = System.nanoTime();
          for (IndexInput ii : inputs) {
            final long offset = r.nextLong(length - NUM_BYTES);
            ii.seek(offset);
            if (prefetch) {
              ii.prefetch(offset, 1);
            }
          }
          for (int i = 0; i < NUM_TERMS; ++i) {
            inputs[i].readBytes(b[i], 0, b[i].length);
          }
          final long end = System.nanoTime();
          // Prevent the JVM from optimizing away the reads
          DUMMY = Arrays.stream(b).mapToInt(Arrays::hashCode).sum();
          latencies[iter & 1].add(end - start);
        }
      } catch (IOException e) {
        throw new UncheckedIOException(e);
      } catch (InterruptedException e) {
        throw new ThreadInterruptedException(e);
      }
    }
  }
}
```

On the latest version of this PR, it reports:

Data does not fit in the page cache
  With prefetching:    P50=104260ns P90=159710ns P99=228880ns
  Without prefetching: P50=242580ns P90=315821ns P99=405901ns
Data fits in the page cache
  With prefetching:    P50=310ns P90=6700ns P99=12320ns
  Without prefetching: P50=290ns P90=6770ns P99=11610ns

vs. the following on main:

Data does not fit in the page cache
  With prefetching:    P50=97620ns P90=153050ns P99=220510ns
  Without prefetching: P50=226690ns P90=302530ns P99=392770ns
Data fits in the page cache
  With prefetching:    P50=6970ns P90=9380ns P99=12300ns
  Without prefetching: P50=290ns P90=5890ns P99=8560ns