Status: Open — opened by serge-sans-paille 5 years ago.
The following xtensor code, built as a pybind11 extension module with xtensor-python:
#include <complex>
#include <cstdint>
#include <tuple>
#include <utility>

#include "xtensor/xnoalias.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xarray.hpp"
#include "xtensor/xrandom.hpp"
#include "xtensor/xview.hpp"
#include "xtensor/xfixed.hpp"
#include "xtensor/xindex_view.hpp"

// Defined before the xtensor-python headers so they can see it; presumably
// this enables the NumPy C-API import performed by import_numpy() below
// (see xtensor-python docs).
#define FORCE_IMPORT_ARRAY
#include "xtensor-python/pyarray.hpp"
#include "xtensor-python/pytensor.hpp"

#include "pybind11/stl.h"

using namespace xt;

// Mandelbrot escape data on an int(yn)-by-int(xn) grid spanning
// [xmin, xmax] (columns) x [ymin, ymax] (rows), mirroring the NumPy version.
//
// Returns (Z, N):
//   Z - final iterate of each point; a point stops being updated as soon as
//       |Z| >= horizon at the start of an iteration.
//   N - last iteration index n at which |Z| was still below `horizon`;
//       entries equal to maxiter-1 are reset to 0.
auto mandelbrot(double xmin, double xmax, double ymin, double ymax,
                double xn, double yn, long maxiter, double horizon)
{
    auto X = linspace<double>(xmin, xmax, int(xn));
    auto Y = linspace<double>(ymin, ymax, int(yn));
    // Y becomes a column vector, so C(i, j) = X(j) + 1i * Y(i).
    auto C = eval(X + view(Y, all(), newaxis()) * std::complex<double>(0, 1));
    xtensor<int64_t, 2> N = zeros<int64_t>(C.shape());
    xtensor<std::complex<double>, 2> Z = zeros<std::complex<double>>(C.shape());
    for (long n = 0; n < maxiter; ++n)
    {
        // Mask of the points that have not escaped yet.
        auto I = eval(abs(Z) < horizon);
        filter(N, I) = n;
        // NOTE(review): filter(Z, I) is built three times here (plus
        // filter(C, I) once), each a separate masked view over the whole
        // grid; presumably a large part of the gap with NumPy — confirm
        // with a profiler.
        filter(Z, I) = filter(Z, I) * filter(Z, I) + filter(C, I);
    }
    filter(N, equal(N, maxiter - 1)) = 0;
    // Move the result tensors into the tuple rather than copying them.
    return std::make_tuple(std::move(Z), std::move(N));
}

// Python-facing wrapper around mandelbrot(): converts each returned xtensor
// into an xtensor-python pytensor so the pair can be handed back to Python.
std::tuple<pytensor<std::complex<double>, 2>, pytensor<int64_t, 2>>
py_mandelbrot(double xmin, double xmax, double ymin, double ymax,
              double xn, double yn, long maxiter, double horizon)
{
    return mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon);
}

// Extension module `xtensor_mandelbrot`, exposing py_mandelbrot as `mandelbrot`.
PYBIND11_MODULE(xtensor_mandelbrot, m)
{
    import_numpy();
    m.def("mandelbrot", py_mandelbrot);
}
runs much slower than its equivalent NumPy version:
import numpy as np

#setup: N=100
#run: mandelbrot(0., 100., 0., 100., 100., 100., N, 50.)

#pythran export mandelbrot(float, float, float, float, float, float, int, float)
def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon):
    """Mandelbrot escape data on an (int(yn), int(xn)) grid.

    Columns span [xmin, xmax] and rows span [ymin, ymax].

    Returns ``(Z, N)``:

    * ``Z`` (complex128) holds each point's final iterate. A point stops
      being updated once ``|Z| >= horizon`` at the start of an iteration.
    * ``N`` (int64) holds the last iteration index at which ``|Z| < horizon``.
      Entries equal to ``maxiter - 1`` are reset to 0.
    """
    X = np.linspace(xmin, xmax, int(xn))
    Y = np.linspace(ymin, ymax, int(yn))
    # Broadcast: rows follow Y, columns follow X.
    C = X + Y[:, None] * 1j
    N = np.zeros(C.shape, dtype=np.int64)
    Z = np.zeros(C.shape, np.complex128)
    for n in range(maxiter):
        # Mask of the points that have not escaped yet.
        I = np.less(np.abs(Z), horizon)
        N[I] = n
        Z[I] = Z[I]**2 + C[I]
    N[N == maxiter - 1] = 0
    return Z, N
# name        engine    best    average   std
mandelbrot    xtensor   16991   17178     150
mandelbrot    python    9607    10102     505
In summary, the xtensor code above runs much slower than its equivalent NumPy version (see the benchmark table).