ninia / jep

Embed Python in Java

Performance issues with Java multi-threads invoke the python script #468

Closed Dengsi closed 1 year ago

Dengsi commented 1 year ago

Please take a look at the code below:

import jep.JepConfig;
import jep.SharedInterpreter;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class JepUtils {
    public static void main(String[] args) throws Exception {
        JepConfig jepConfig = new JepConfig();
        jepConfig.redirectStdout(System.out);
        jepConfig.redirectStdErr(System.err);
        SharedInterpreter.setConfig(jepConfig);

        List<Double> data = IntStream.range(0, 10 * 10000).mapToObj(Double::valueOf).collect(Collectors.toList());
        for (int i = 0; i < 10; i++) {
            final int index = i;
            new Thread(() -> {
                try (SharedInterpreter interpreter = new SharedInterpreter()) {
                    long beginTime = System.currentTimeMillis();
                    interpreter.set("data", data);
                    interpreter.exec("import numpy as np");
                    interpreter.exec("fft_data = np.fft.fft(data).astype('float64')");
                    System.out.printf("Worker %s success: cost=>%sms, result=>%s%n", index, System.currentTimeMillis() - beginTime, interpreter.getValue("fft_data"));
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }).start();
        }
    }
}

My requirement is to run an FFT from Java by calling a Python script. With one thread it takes about 200 milliseconds, but with 10 threads running simultaneously it takes over 30 seconds. Could you please help me identify possible reasons for the slowdown? Thank you very much! This has been confusing me for a long time.

bsteffensmeier commented 1 year ago

That is a very prominent slowdown. Thank you for making a standalone example that I can run easily. I was able to recreate your problem with your example code.

After quite a bit of digging and experimentation, I have found that the majority of your problem comes from 6 lines of code in Jep. When I comment out these 6 lines, the execution time of 10 threads goes from 23 seconds to 1.3 seconds. It is still about 5 times slower than a single thread but much improved from the current state.

  1. https://github.com/ninia/jep/blob/v4.1.1/src/main/c/Jep/java_access/Iterator.c#L36
  2. https://github.com/ninia/jep/blob/v4.1.1/src/main/c/Jep/java_access/Iterator.c#L40
  3. https://github.com/ninia/jep/blob/v4.1.1/src/main/c/Jep/java_access/Iterator.c#L47
  4. https://github.com/ninia/jep/blob/v4.1.1/src/main/c/Jep/java_access/Iterator.c#L51
  5. https://github.com/ninia/jep/blob/v4.1.1/src/main/c/Jep/java_access/Number.c#L51
  6. https://github.com/ninia/jep/blob/v4.1.1/src/main/c/Jep/java_access/Number.c#L55

All of those are places where we use Py_BEGIN_ALLOW_THREADS and Py_END_ALLOW_THREADS. Ironically, the reason we do that is to increase Python concurrency; in this case, however, the GIL handling ends up taking more time than the Python code itself. More information about these macros can be found in the Python documentation on releasing the GIL. In summary, for long-running operations that don't involve Python objects we want to release the GIL, but for fast operations we shouldn't bother.
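As a rough Java analogy of the trade-off described above (not Jep's or CPython's actual code), the sketch below uses a ReentrantLock as a stand-in for the GIL and compares holding it for a whole loop against dropping and retaking it around every short element access. The class GilAnalogy and its methods are hypothetical names made up for illustration, and because CPython's GIL has its own handoff rules, the timings it prints will not match the numbers reported in this thread.

```java
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantLock;

class GilAnalogy {
    // Stand-in for the GIL: only the thread holding it may do "interpreter" work.
    static final ReentrantLock gil = new ReentrantLock();
    // Counts release/reacquire pairs across all workers.
    static final AtomicLong handoffs = new AtomicLong();

    /** Sums the array under the lock, optionally dropping the lock around each element access. */
    static double sum(double[] values, boolean releaseAroundEachCall) {
        double total = 0;
        gil.lock();
        try {
            for (double v : values) {
                if (releaseAroundEachCall) {
                    gil.unlock();        // plays the role of Py_BEGIN_ALLOW_THREADS
                    double fetched = v;  // the very short "Java call"
                    gil.lock();          // plays the role of Py_END_ALLOW_THREADS
                    handoffs.incrementAndGet();
                    total += fetched;
                } else {
                    total += v;
                }
            }
        } finally {
            gil.unlock();
        }
        return total;
    }

    /** Runs the given number of workers over the same data and returns elapsed milliseconds. */
    static long timeWorkers(int threads, int elements, boolean release) throws InterruptedException {
        double[] data = new double[elements];
        Thread[] workers = new Thread[threads];
        long start = System.nanoTime();
        for (int i = 0; i < threads; i++) {
            workers[i] = new Thread(() -> sum(data, release));
            workers[i].start();
        }
        for (Thread worker : workers) {
            worker.join();
        }
        return (System.nanoTime() - start) / 1_000_000;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println("hold lock for whole loop: " + timeWorkers(10, 100_000, false) + " ms");
        System.out.println("release around each call: " + timeWorkers(10, 100_000, true) + " ms");
        System.out.println("release/reacquire pairs:  " + handoffs.get());
    }
}
```

The point of the sketch is only that the per-element variant does one lock handoff per element, which is pure overhead when the guarded call is as cheap as reading a number.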

Unfortunately, I don't think the solution is as simple as just removing those 6 lines. In this case those particular operations are very fast and it would make more sense to skip the GIL manipulations; the problem is that we just don't know whether that is generally true. As a Java developer, I have used iterators that are backed by a database connection, where hasNext() or next() could be a long operation that reaches across the network. This is exactly the type of scenario where it is better to release the GIL than to hold onto it, so I really don't think it is safe to assume that any given Java operation is going to be fast.

Our general philosophy in Jep is to release the GIL whenever we execute any Java code; we essentially assume all Java operations are going to take long enough to justify releasing the GIL. After seeing how significant the impact is in this case, we may want to revisit that philosophy, but that has implications for how maintainable and reusable the Jep code is, so I don't think we will be making changes quickly.

My current recommendation is to change your code so that it does not cross back and forth between Java and Python as often. In this case, a List<Double> requires Python to go back into Java to access each element. That means each thread goes back and forth over 100,000 times, and since there are 3 places where we release the GIL each time, that is a lot of GIL operations.
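To put a rough number on that, the back-of-the-envelope calculation below multiplies the element count from the example by the 3 release sites and the 10 threads. It is only an estimate under the assumption that every element passes through each release site exactly once; GilOpEstimate and releasePairs are hypothetical names used for illustration.

```java
class GilOpEstimate {
    /** Estimated release/reacquire pairs: one per element, per release site, per thread. */
    static long releasePairs(long elements, int releaseSitesPerElement, int threads) {
        return elements * releaseSitesPerElement * threads;
    }

    public static void main(String[] args) {
        long elements = 10 * 10000; // list size used in the example
        System.out.println("pairs per thread: " + releasePairs(elements, 3, 1));   // 300000
        System.out.println("pairs, 10 threads: " + releasePairs(elements, 3, 10)); // 3000000
    }
}
```

Each pair is both a release and a reacquire, so under these assumptions the 10 threads together contend for the GIL roughly three million times to read a single list each.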

For numeric data, you should see the best performance if you can use a direct buffer, because then each language can access the same memory without any copying or coordination for individual elements. I changed your example to use a DoubleBuffer as data, using the code below to populate it with the same information. On my system it took 140ms for one thread, and each of the 10 threads took about 180ms, which is much better performance.

DoubleBuffer data = ByteBuffer.allocateDirect(8 * 10 * 10000).asDoubleBuffer();
IntStream.range(0, 10 * 10000).asDoubleStream().forEach(data::put);

You should be able to see similar performance by using NDArray or DirectNDArray. Even using a double[] instead of a List<Double> eliminated the multi-thread penalty for me.
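If the data already exists as a List<Double>, one way to follow this advice is to convert it once on the Java side before handing it to the interpreter. The sketch below shows both conversions mentioned above; BulkData, toArray and toDirectBuffer are hypothetical helper names, and the Jep call is left as a comment so the example runs without Jep on the classpath.

```java
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class BulkData {
    /** Copies a boxed list into a primitive array in a single pass on the Java side. */
    static double[] toArray(List<Double> values) {
        return values.stream().mapToDouble(Double::doubleValue).toArray();
    }

    /** Copies a boxed list into a direct buffer, allocated the same way as in the snippet above. */
    static DoubleBuffer toDirectBuffer(List<Double> values) {
        DoubleBuffer buffer = ByteBuffer.allocateDirect(Double.BYTES * values.size()).asDoubleBuffer();
        for (double value : values) {
            buffer.put(value);
        }
        return buffer;
    }

    public static void main(String[] args) {
        // Same values as the original example: 0.0, 1.0, ..., 99999.0
        List<Double> boxed = IntStream.range(0, 10 * 10000)
                .asDoubleStream()
                .boxed()
                .collect(Collectors.toList());

        double[] array = toArray(boxed);
        DoubleBuffer buffer = toDirectBuffer(boxed);

        System.out.println("array length: " + array.length);
        System.out.println("buffer capacity: " + buffer.capacity() + ", direct: " + buffer.isDirect());

        // In the original example, either value would then replace the list, e.g.:
        //     interpreter.set("data", array);
    }
}
```

The conversion costs one extra pass over the data in Java, but that pass never touches the GIL, so it does not contend with the other interpreter threads.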

Dengsi commented 1 year ago

Thank you very much for the thorough reply; it is very clear to me. I now know how to improve things in this scenario.

rayduan commented 1 year ago

Hi, is there a version that fixes this issue?