silx-kit / pyFAI

Fast Azimuthal Integration in Python
Other
107 stars 96 forks source link

Implement `medfilt1d_ng` #2285

Open kif opened 2 months ago

kif commented 2 months ago

After preliminary work in #2261, we validated that this is feasible. Performance is expected to be about 5 s for a 2 Mpix image using Python.

kif commented 1 month ago

Quick & dirty implementation in Python:

def azimuthal_median(q_start=0.2, q_stop=0.8, csr=None, work=None):
    """Quantile-filtered azimuthal integration (pure-Python reference).

    For every bin (row of the CSR matrix) the contributions are sorted with
    ``numpy.sort`` and their normalization ``n`` is cumulated. A contribution
    is kept when its cumulated-normalization interval ``[lower, upper]``
    overlaps ``[q_start * total, q_stop * total]``. ``total`` is the bin's
    total normalization.

    NOTE(review): the Cython prototype keeps only contributions lying fully
    inside the window. The two selections differ for quantiles other than
    (0, 1). Confirm which one is intended.

    :param q_start: lower quantile, as a fraction of the bin's total normalization
    :param q_stop: upper quantile, as a fraction of the bin's total normalization
    :param csr: CSR matrix as a sequence whose item 2 is ``indptr`` (a numpy
        array). Defaults to the module-level (notebook) global ``csr``.
    :param work: structured array with at least the fields ``"s"``, ``"v"``
        and ``"n"``, one entry per non-zero. ``numpy.sort`` orders structured
        records field by field, so presumably the first field is the sort key
        (TODO confirm). Defaults to the module-level (notebook) global ``work1``.
    :return: ``(signal, variance, norm, norm2)``, float64 arrays with one
        value per bin. Empty bins stay at 0.
    """
    if csr is None:
        csr = globals()["csr"]
    if work is None:
        work = globals()["work1"]

    indptr = csr[2]
    nbins = indptr.size - 1
    signal = numpy.zeros(nbins, dtype="float64")
    norm = numpy.zeros(nbins, dtype="float64")
    norm2 = numpy.zeros(nbins, dtype="float64")
    variance = numpy.zeros(nbins, dtype="float64")

    for i, (start, stop) in enumerate(zip(indptr[:-1], indptr[1:])):
        if stop <= start:
            # Empty bin: nothing to accumulate (upper[-1] below would raise).
            continue
        tmp = numpy.sort(work[start:stop])
        upper = numpy.cumsum(tmp["n"])           # cumulated weight after each element
        total = upper[-1]
        lower = numpy.concatenate(([0], upper[:-1]))  # cumulated weight before each element
        mask = numpy.logical_and(upper >= q_start * total, lower <= q_stop * total)
        tmp = tmp[mask]
        signal[i] = tmp["s"].sum(dtype="float64")
        variance[i] = tmp["v"].sum(dtype="float64")
        norm[i] = tmp["n"].sum(dtype="float64")
        norm2[i] = (tmp["n"] ** 2).sum(dtype="float64")
    return signal, variance, norm, norm2
kif commented 1 month ago

Non-regression test:

# Non-regression check of cython_medfilt3 against pyFAI's CSR integrator.
# With quant=(0.0, 1.0) every contribution is presumably kept, so the four
# accumulators should match a plain integration. Small differences are
# expected because the work values are stored as float32.
# Time a single call, then benchmark repeated calls.
%time res = cython_medfilt3(prep, csr[0], csr[1], csr[2], (0.0,1.0))
%timeit res = cython_medfilt3(prep, csr[0], csr[1], csr[2], (0.0,1.0))

# Reference integration: 2500 bins in r_mm, method ("full", "csr", "cython"),
# Poisson error model. Assumes csr/prep were built for the same geometry and
# binning (TODO confirm).
ref = ai.integrate1d(img, 2500, method=("full", "csr", "cython"), unit="r_mm", error_model="poisson")
# Largest relative deviation for signal, variance, normalization and
# normalization². The cell displays the tuple.
# NOTE(review): bins where the reference is 0 give nan/inf here and .max()
# then returns nan. Consider numpy.nanmax, or confirm that no such bins exist.
abs((ref.sum_signal -res[0])/ref.sum_signal).max(), \
abs((ref.sum_variance -res[1])/ref.sum_variance).max(),\
abs((ref.sum_normalization -res[2])/ref.sum_normalization).max(),\
abs((ref.sum_normalization2 -res[3])/ref.sum_normalization2).max()
kif commented 1 month ago

Quick & dirty implementation in Cython:

%%cython -a -c-fopenmp --link-args=-fopenmp

# distutils: language = c++
#cython: boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False

import numpy
from libcpp cimport bool
from libcpp.algorithm cimport sort
from cython cimport floating
from cython.parallel import prange

cdef struct float4:
    # One weighted contribution of a pixel to a bin (one per CSR non-zero).
    float a  # sort key signal/normalization; later overwritten with the bin's cumulated normalization
    float s  # signal * CSR coefficient
    float v  # variance * CSR coefficient**2
    float n  # normalization * CSR coefficient

cdef bool cmp(float4 a, float4 b) noexcept nogil:
    # Strict "less than" on the sort key `a`: the comparator used by std::sort.
    return a.a < b.a

cdef void sort_3(float4[::1] ary) noexcept nogil:
    # Sort `ary` in place by its `a` field (std::sort with `cmp`).
    # Views with fewer than two elements are already sorted. Returning early
    # also avoids taking &ary[0] on an empty view, since boundscheck is off.
    cdef:
        int size
    size = ary.shape[0]
    if size < 2:
        return
    sort(&ary[0], &ary[0] + size, cmp)

def cython_medfilt3(float[:, ::1] prep, 
                    float[::1] data, 
                    int[::1]indices, 
                    int[::1] indptr, 
                    quant):
    """Quantile-filtered azimuthal integration over a CSR matrix (OpenMP).

    For each bin (CSR row) the contributions are sorted by signal/normalization
    and their normalization is cumulated. A contribution is kept only when its
    cumulated-normalization interval [lower, upper] lies inside
    [q_start * total, q_stop * total], where total is the bin's total
    normalization.

    NOTE(review): the pure-Python prototype keeps every contribution that
    *overlaps* the window instead. The two differ for quantiles other than
    (0, 1); with a containment test a narrow window such as (0.5, 0.5)
    selects almost nothing. Confirm the intended selection.

    :param prep: one row per pixel. Columns 0, 1 and 2 are read as signal,
        variance and normalization. Presumably built with
        preproc_cython(img, ..., split_result=3).reshape((-1, 3)); TODO confirm.
    :param data: CSR coefficients, one per non-zero
    :param indices: CSR column indices, i.e. the pixel index of each non-zero
    :param indptr: CSR row pointers, one row per output bin
    :param quant: (q_start, q_stop) as fractions of the bin's total normalization
    :return: signal, variance, norm, norm2 (float64, one value per bin) and the
        work array. In the work array, the `a` field holds the cumulated
        normalization within each bin.
    """
    cdef:
        int i, nbins, npix, idx, j, start, stop
        float sum_, lower, upper
        double s, v, n, n2, d
        float4 w
        float4[::1] work
        float qmin, qmax, q_start, q_stop
        double[::1] signal, norm, norm2, variance

    q_start = float(quant[0])
    q_stop  = float(quant[1])

    nbins = indptr.shape[0] - 1
    npix = indices.shape[0]  # number of non-zeros in the CSR matrix
    signal = numpy.zeros(nbins, dtype="float64")
    norm = numpy.zeros(nbins, dtype="float64")
    norm2 = numpy.zeros(nbins, dtype="float64")
    variance = numpy.zeros(nbins, dtype="float64")

    work = numpy.zeros(npix, dtype=numpy.dtype([('a','f4'),('s','f4'),('v','f4'),('n','f4')]))
    # Build one weighted contribution per non-zero.
    for i in prange(npix, nogil=True):
        j = indices[i]
        s = prep[j,0]
        v = prep[j,1]
        n = prep[j,2]
        d = data[i]
        # With cdivision=True, n == 0 would give a nan key. nan breaks the
        # strict weak ordering std::sort requires (undefined behavior).
        w.a = s / n if n != 0.0 else 0.0
        w.s = s * d
        w.v = v * d * d
        w.n = n * d
        work[i] = w

    # Each bin owns work[start:stop] exclusively, so bins run in parallel.
    for idx in prange(nbins, nogil=True, schedule="guided"):
        start = indptr[idx] 
        stop = indptr[idx+1]

        v = s = n = n2 = 0.0
        sort_3(work[start:stop])
        # Replace the sort key with the running normalization of this bin.
        sum_ = 0.0
        for i in range(start, stop):
            sum_ = sum_ + work[i].n
            work[i].a = sum_
        qmin = q_start * sum_
        qmax = q_stop * sum_
        # The cumulated weight before the bin's first element is 0. Never read
        # work[start-1]: it belongs to another bin, possibly on another thread.
        lower = 0.0
        for i in range(start, stop):
            upper = work[i].a
            if lower >= qmin and upper <= qmax:
                w = work[i]
                s = s + w.s
                v = v + w.v
                n = n + w.n
                n2 = n2 + w.n*w.n
            lower = upper
        signal[idx] = s
        variance[idx] = v
        norm[idx] = n
        norm2[idx] = n2

    return numpy.asarray(signal), numpy.asarray(variance), numpy.asarray(norm), numpy.asarray(norm2), numpy.asarray(work)
kif commented 1 month ago

Sort algorithm for variable-size ensembles, intended for OpenCL (Python prototype of a parallel comb sort):

# Gap sequence for the parallel comb sort. It starts from fixed small gaps,
# then grows geometrically (factor `raison`, rounded up with `ceil`) until a
# gap reaches `size`. That last gap is then discarded.
# NOTE(review): `size` and `ceil` are not defined in this snippet; they are
# assumed to come from the surrounding notebook.
raison = 1.3
pas = [1, 2, 3, 4, 6, 8, 11]
last = pas[-1]
while last < size:
    last = ceil(last * raison)
    pas.append(last)
del pas[-1]

def swap(lst, i, j):
    """Order ``lst[i]`` and ``lst[j]`` in place so that ``lst[i] <= lst[j]``.

    :return: 1 when the two items were exchanged, 0 otherwise.
    """
    if not lst[i] > lst[j]:
        return 0
    lst[i], lst[j] = lst[j], lst[i]
    return 1

def passe(lst, pas=1):
    """Run one comb-sort pass with gap ``pas`` over ``lst``, in place.

    The compared index pairs are grouped into phases. No two pairs of one
    phase share an index, so a phase could run in parallel.

    * ``2*pas >= len(lst)``: a single phase containing every pair ``(i, i+pas)``.
    * ``pas == 1``: odd-even transposition, even pairs first, then odd pairs.
    * otherwise: blocks of ``pas`` consecutive pairs, even blocks first, then
      odd blocks.

    :return: ``(swaps, phases)``, the number of swaps performed and the list
        of phases. Each phase is a list of ``(i, j)`` pairs in execution order.
    """
    length = len(lst)
    if 2 * pas >= length:
        phases = [[(a, a + pas) for a in range(length - pas)]]
    elif pas == 1:
        phases = [[(a, a + 1) for a in range(parity, length - 1, 2)]
                  for parity in (0, 1)]
    else:
        phases = [
            [(a, a + pas)
             for block in range(offset, length - pas, 2 * pas)
             for a in range(block, block + pas)
             if a + pas < length]
            for offset in (0, pas)
        ]
    # Pair lists do not depend on the data, so applying them afterwards
    # performs exactly the same swaps in the same order.
    swaps = sum(swap(lst, a, b) for phase in phases for a, b in phase)
    return swaps, phases

def ParallelCombSort(lst):
    """Sort ``lst`` in place with the parallel comb-sort prototype.

    Runs one ``passe`` per gap of the module-level ``pas`` sequence, taken in
    reverse order and skipping gaps ``>= len(lst)``. It then keeps running
    gap-1 passes while the previous pass swapped anything.

    :return: ``(perm, extra)``. ``perm`` lists, phase by phase, the index pairs
        compared; ``extra`` is the number of additional gap-1 passes needed
        after the gap sequence was exhausted.
    """
    schedule = []
    last_swaps = 0
    for gap in reversed(pas):
        if gap < len(lst):
            last_swaps, phases = passe(lst, gap)
            if phases:
                schedule.extend(phases)
    extra_passes = 0
    while last_swaps:
        last_swaps, phases = passe(lst, 1)
        schedule.extend(phases)
        extra_passes += 1
    return schedule, extra_passes

Scales as about $5\log(n)$ passes.

kif commented 1 month ago

OpenCL version of the comb-sort algorithm:

%%cl_kernel

// Orders elements[i] and elements[j] so that elements[i] <= elements[j].
// Returns 1 when the two values were exchanged, 0 otherwise.
int compare_and_swap(global volatile float* elements, int i, int j)
{
    float first = elements[i];
    float second = elements[j];
    if (!(first > second))
        return 0;
    elements[i] = second;
    elements[j] = first;
    return 1;
}

// One comb-sort pass with gap `step` over elements[0:size], shared by all
// work-items of the work-group.
// Returns the total number of swaps performed by the whole work-group. Every
// work-item receives the same value, so loops driven by it stay uniform.
// `shared` must provide at least get_local_size(0) ints of scratch space.
int passe(global volatile float* elements, 
          int size,
          int step,
          local volatile int* shared)
{
    int wg = get_local_size(0);
    int tid = get_local_id(0);
    int cnt = 0;
    int i, j, k;

    if (2*step>=size)
    {
        // Pairs (i, i+step) never overlap when 2*step >= size: one phase.
        for (i=tid;i<size-step;i+=wg)
            cnt += compare_and_swap(elements, i, i+step);
    }
    else if (step == 1)
    {
        // Odd-even transposition: even pairs first, then odd pairs.
        for (i=2*tid; i<size-step; i+=2*wg)
            cnt+=compare_and_swap(elements, i, i+step);
        barrier(CLK_GLOBAL_MEM_FENCE);
        for (i=2*tid+1; i<size-step; i+=2*wg)
            cnt+=compare_and_swap(elements, i, i+step);
    }
    else
    {
        // Blocks of `step` consecutive pairs: even-numbered blocks first,
        // then odd-numbered ones, so that no two pairs of a phase overlap.
        for (i=tid*2*step; i<size-step; i+=2*step*wg)
        {
            for (j=i; j<i+step; j++)
            {
                k  = j + step;
                if (k<size)
                    cnt += compare_and_swap(elements, j, k);
            }
        }
        barrier(CLK_GLOBAL_MEM_FENCE);
        for (i=tid*2*step+step; i<size-step; i+=2*step*wg)
        {
            for (j=i; j<i+step; j++)
            {
                k  = j + step;
                if (k<size)
                    cnt += compare_and_swap(elements, j, k);
            }
        }
    }

    // Publish this work-item's count, and make all swaps visible to the
    // work-group before the next pass reads `elements`.
    shared[tid] = cnt;
    barrier(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);

    // Tree reduction, valid for any work-group size. At stride i, work-items
    // whose id is a multiple of 2*i add the partial sum stored at tid+i.
    // That slot is never written during the same stride, so there is no race.
    for (i=1; i<wg; i*=2)
    {
        if (((tid % (2*i)) == 0) && ((tid+i) < wg))
            shared[tid] += shared[tid+i];
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    cnt = shared[0];
    // Every work-item must read the total before a later call overwrites
    // `shared`.
    barrier(CLK_LOCAL_MEM_FENCE);
    return cnt;
}

// Next (larger) comb-sort gap: step * raison, rounded toward +infinity.
inline int next_step(int step, float raison)
{
    float grown = (float)step * raison;
    return convert_int_rtp(grown);
}

// Previous (smaller) comb-sort gap: step / raison, rounded toward -infinity.
inline int previous_step(int step, float raison)
{
    float shrunk = (float)step / raison;
    return convert_int_rtn(shrunk);
}

// Sorts elements[0:size] in ascending order with a parallel comb sort.
// All work-items of the work-group cooperate on the same array.
// The work-group size must not exceed 1024, the capacity of `shared`.
// NOTE(review): every work-group sorts the same buffer, so launching more
// than one group would race. Presumably enqueue with global size == local
// size; TODO confirm.
kernel void combsort(global volatile float* elements, 
                     int size)
{
    local volatile int shared[1024];
    int step = 11;       // starting gap, adjusted below to the largest gap < size
    float raison = 1.3f; // shrink/grow factor between consecutive gaps
    int cnt = 0;

    // Nothing to sort. This also avoids an endless loop below, because
    // previous_step(0) == 0, and reading `cnt` before any pass has run.
    // `size` is a kernel argument, so this early exit is uniform across the
    // work-group.
    if (size < 2)
        return;

    // Grow the gap past `size`, then shrink it back to the largest gap < size.
    while (step < size)
        step = next_step(step, raison);
    while (step >= size)
        step = previous_step(step, raison);

    // One pass per gap, decreasing down to 1.
    for (; step > 0; step = previous_step(step, raison))
        cnt = passe(elements, size, step, shared);

    // Finish with gap-1 passes until a pass performs no swap.
    while (cnt)
        cnt = passe(elements, size, 1, shared);
}