erc-compact / skyweaver

Implementation of an offline FBFUSE beamformer
MIT License
0 stars 1 forks source link

Enable alternative stokes modes #14

Closed ewanbarr closed 2 months ago

ewanbarr commented 2 months ago

Currently we support I, Q, U, V and IQUV. There are requests to be able to produce subsets of the stokes parameters, e.g. I,V and Q,U.

Rather than handling these all individually, it would make sense to consolidate SingleStokesBeamformerTraits and FullStokesBeamformerTraits into a generic StokesTraits<...> which can take up to 4 template arguments defining the ordering and types of Stokes parameters to be generated.

A prototype version of this feature is below.

Adding this feature will require changes to the beamformer_utils.cuh file, and the pipeline setup for the beamformer will have to be updated. Additionally, as the underlying types will be float2, float3, etc., generic math operators for these types will need to be added to the types.cuh file to make them compatible with downstream code.

#include <iostream>
#include <type_traits>
#include <limits>

// AT(var, idx): shorthand for accessor<decltype(var)>::template at<idx>(var),
// i.e. compile-time indexed access to component `idx` of `var`.
//  - `idx` is used as a template argument, so it must be a constant expression.
//  - `var` appears twice in the expansion (inside decltype and as the call
//    argument); decltype(var) on a plain variable name yields its declared
//    type, including any reference/const qualification.
#define AT(var, idx) accessor<decltype(var)>::template at<idx>(var)

// Two-component float aggregate (members x, y).
struct float2 {
    float x;
    float y;
};

// Three-component float aggregate (members x, y, z).
struct float3 {
    float x;
    float y;
    float z;
};

// Four-component float aggregate (members x, y, z, w).
struct float4 {
    float x;
    float y;
    float z;
    float w;
};

// Two-component signed 8-bit aggregate (members x, y).
struct char2 {
    int8_t x;
    int8_t y;
};

// Three-component signed 8-bit aggregate (members x, y, z).
struct char3 {
    int8_t x;
    int8_t y;
    int8_t z;
};

// Four-component signed 8-bit aggregate (members x, y, z, w).
struct char4 {
    int8_t x;
    int8_t y;
    int8_t z;
    int8_t w;
};

/**
 * @brief Calculate the squared magnitude of a complex number
 *
 * @param x Complex value stored as (x.x, x.y)
 * @return x.x * x.x + x.y * x.y (no square root is taken)
 */
static inline float cuCmagf(float2 x)
{
    float const re = x.x;
    float const im = x.y;
    return re * re + im * im;
}

/**
 * @brief Complex conjugate: keeps x.x and negates x.y
 */
static inline float2 cuConjf(float2 x)
{
    float2 conjugate = x;
    conjugate.y = -x.y;
    return conjugate;
}

/**
 * @brief Imaginary part of a complex value (the y member)
 */
static inline float cuCimagf(float2 x)
{
    float const imaginary = x.y;
    return imaginary;
}

/**
 * @brief Real part of a complex value (the x member)
 */
static inline float cuCrealf(float2 x)
{
    float const real = x.x;
    return real;
}

/**
 * @brief Return the smaller of two floats
 *
 * Written as a plain comparison so the function does not depend on
 * std::min, whose header (<algorithm>) this file never includes.
 * The selection rule is the same as std::min(a, b): b is returned only
 * when b < a, otherwise a is returned.
 *
 * @param a First value
 * @param b Second value
 * @return b if b < a, else a
 */
static inline float fminf (float a, float b)
{
    return (b < a) ? b : a;
}

/**
 * @brief Return the larger of two floats
 *
 * Written as a plain comparison so the function does not depend on
 * std::max, whose header (<algorithm>) this file never includes.
 * The selection rule is the same as std::max(a, b): b is returned only
 * when a < b, otherwise a is returned.
 *
 * @param a First value
 * @param b Second value
 * @return b if a < b, else a
 */
static inline float fmaxf (float a, float b)
{
    return (a < b) ? b : a;
}

/**
 * @brief Round a float to the nearest integer, ties to even
 *
 * The previous body truncated toward zero (2.7 -> 2), which is not what
 * a function named rintf is expected to do. This version rounds to the
 * nearest integer and resolves exact halves to the even neighbour
 * (2.5 -> 2, 3.5 -> 4), without pulling in <cmath>.
 *
 * @param b Value to round; must be representable as int after rounding
 *          (out-of-range values are undefined, as in the original cast)
 * @return The rounded value as int
 */
static inline int rintf (float b)
{
    int const truncated = static_cast<int>(b);
    // Exact: the integer part shares the exponent range of b
    float const remainder = b - static_cast<float>(truncated);
    bool const truncated_is_odd = (truncated % 2) != 0;

    if (remainder > 0.5f || (remainder == 0.5f && truncated_is_odd)) {
        return truncated + 1;
    }
    if (remainder < -0.5f || (remainder == -0.5f && truncated_is_odd)) {
        return truncated - 1;
    }
    return truncated;
}

/**
 * @brief Multiply two complex numbers
 *
 * (a.x + i a.y) * (b.x + i b.y)
 *   = (a.x * b.x - a.y * b.y) + i (a.x * b.y + a.y * b.x)
 *
 * The previous body assigned result.x twice (the second time with
 * a.x * b.x + a.y * b.y), so the imaginary part was always zero and the
 * real part was wrong.
 *
 * @param a Left operand, real part in x and imaginary part in y
 * @param b Right operand, real part in x and imaginary part in y
 * @return The complex product a * b
 */
static inline float2 cuCmulf(float2 a, float2 b)
{
    float2 product{};
    product.x = a.x * b.x - a.y * b.y; // real part
    product.y = a.x * b.y + a.y * b.x; // imaginary part
    return product;
}

// Identifiers for the four Stokes parameters; used as non-type template
// arguments to select which parameter is computed.
enum StokesParameter {
    I = 0,
    Q = 1,
    U = 2,
    V = 3
};

// Compile-time "unreachable" marker: instantiating it with the default
// argument (false) trips the static_assert with the message "no match".
template <bool Matched = false>
void static_no_match()
{
    static_assert(Matched, "no match");
}

/**
 * @brief Calculate the requested Stokes parameter from two polarisations
 *
 * The parameter is selected at compile time via the template argument:
 * I = |P0|^2 + |P1|^2
 * Q = |P0|^2 - |P1|^2
 * U = 2 * Re(P0 * conj(P1))
 * V = 2 * Im(P0 * conj(P1))
 *
 * @tparam Stokes Which Stokes parameter to compute
 * @param p0 Complex sample of the first polarisation
 * @param p1 Complex sample of the second polarisation
 * @return The value of the selected Stokes parameter
 */
template <StokesParameter Stokes>
static inline float
calculate_stokes(float2 const& p0, float2 const& p1)
{
    if constexpr(Stokes == StokesParameter::I) {
        float const power0 = cuCmagf(p0);
        float const power1 = cuCmagf(p1);
        return power0 + power1;
    } else if constexpr(Stokes == StokesParameter::Q) {
        float const power0 = cuCmagf(p0);
        float const power1 = cuCmagf(p1);
        return power0 - power1;
    } else if constexpr(Stokes == StokesParameter::U) {
        float2 const cross = cuCmulf(p0, cuConjf(p1));
        return 2 * cuCrealf(cross);
    } else if constexpr(Stokes == StokesParameter::V) {
        float2 const cross = cuCmulf(p0, cuConjf(p1));
        return 2 * cuCimagf(cross);
    } else {
        // Any other enumerator value is rejected at compile time
        static_no_match();
    }
}

// Maps a component count N (1..4) to the pair of storage types used for
// that many values: RawPowerType (float based) and QuantisedPowerType
// (int8_t based). The primary template is empty, so any other N has no
// nested types.
template <int N>
struct stokes_storage_type {
};

template <>
struct stokes_storage_type<1> {
    using RawPowerType       = float;
    using QuantisedPowerType = int8_t;
};

template <>
struct stokes_storage_type<2> {
    using RawPowerType       = float2;
    using QuantisedPowerType = char2;
};

template <>
struct stokes_storage_type<3> {
    using RawPowerType       = float3;
    using QuantisedPowerType = char3;
};

template <>
struct stokes_storage_type<4> {
    using RawPowerType       = float4;
    using QuantisedPowerType = char4;
};

// element_type<T>::type is the scalar type of one component of T.
// The primary template is intentionally empty: only the types listed
// below have a nested `type`.
template <typename T>
struct element_type {
};

// Scalar and vector float types -> float
template <> struct element_type<float>  { using type = float; };
template <> struct element_type<float2> { using type = float; };
template <> struct element_type<float3> { using type = float; };
template <> struct element_type<float4> { using type = float; };

// Scalar and vector 8-bit types -> int8_t
template <> struct element_type<int8_t> { using type = int8_t; };
template <> struct element_type<char2>  { using type = int8_t; };
template <> struct element_type<char3>  { using type = int8_t; };
template <> struct element_type<char4>  { using type = int8_t; };

template <typename T>
struct accessor {
    // Function to access members based on index
    using base_type = typename element_type<std::decay_t<T>>::type;
    using return_type = std::conditional_t<
                        std::is_const_v<std::remove_reference_t<T>>,
                        base_type const&,
                        base_type&>;

    template <int N>
    static inline return_type at(T in)
    {
        if constexpr (std::is_same_v<std::decay_t<T>, float> || std::is_same_v<std::decay_t<T>, int8_t>) {
            // Handle the case for float and int8_t
            static_assert(N == 0, "Index out of bounds for float or int8_t");
            return in;
        } else {
            // Handle the case for types with x, y, z, w members
            if constexpr (N == 0) {
                return static_cast<return_type>(in.x);
            } else if constexpr (N == 1) {
                return static_cast<return_type>(in.y);
            } else if constexpr (N == 2) {
                return static_cast<return_type>(in.z);
            } else if constexpr (N == 3) {
                return static_cast<return_type>(in.w);
            } else {
                static_assert(N < 4, "Index out of bounds for type with x, y, z, w");
            }
        }
    }
};

// Applies one operator to one (Index, StokesParameter) pair: forwards all
// arguments to Op::apply<Index, S>(...).
template <int Index, StokesParameter S>
struct Process {
    template <typename Op, typename... Params>
    static void apply(Params&&... params)
    {
        Op::template apply<Index, S>(std::forward<Params>(params)...);
    }
};

// Primary template: selected when the StokesParameter pack is empty, which
// terminates the recursion. Accepts any arguments and does nothing.
template <int Index, StokesParameter... S>
struct Iterate {
    template <typename Op, typename... Params>
    static void apply(Params&&...)
    {
        // Nothing left to process
    }
};

// Recursive case: apply Operator to the first Stokes parameter at position
// Index, then recurse on the remaining parameters at Index + 1.
//
// The arguments are handed to Process as lvalues and only forwarded on the
// final (recursive) use. Forwarding the same pack twice, as before, would
// let the first callee move from an rvalue argument that the recursive
// call still needs.
template <int Index, StokesParameter First, StokesParameter... Rest>
struct Iterate<Index, First, Rest...> {
    template <typename Operator, typename... Args>
    static void apply(Args&&... args)
    {
        // Not forwarded: `args` are used again below
        Process<Index, First>::template apply<Operator>(args...);
        // Last use of `args` at this level, so forwarding is safe
        Iterate<Index + 1, Rest...>::template apply<Operator>(
            std::forward<Args>(args)...);
    }
};

// Operator: accumulates Stokes parameter S of (p0, p1) into component
// Slot of `power`.
struct IntegrateStokes {
    template <int Slot, StokesParameter S, typename T>
    static void
    apply(float2 const& p0, float2 const& p1, T& power)
    {
        auto const contribution = calculate_stokes<S>(p0, p1);
        AT(power, Slot) += contribution;
    }
};

// Operator: accumulates Stokes parameter S of (p0, p1), multiplied by
// `weight`, into component Slot of `power`.
struct IntegrateWeightedStokes {
    template <int Slot, StokesParameter S, typename T>
    static void apply(float2 const& p0,
                      float2 const& p1,
                      T& power,
                      float const& weight)
    {
        auto const weighted = calculate_stokes<S>(p0, p1) * weight;
        AT(power, Slot) += weighted;
    }
};

// Operator: for component Slot computes
//   result = rintf((power - ib_power * ib_mutliplier) / scale_factor)
// The same formula is used for every Stokes parameter S.
struct IncoherentBeamSubtract {
    template <int Slot, StokesParameter S, typename T>
    static void apply(T const& power,
                      T const& ib_power,
                      float const& ib_mutliplier, // original note: 127^2 as default
                      float const& scale_factor,
                      T& result)
    {
        AT(result, Slot) = rintf(
            (AT(power, Slot) - AT(ib_power, Slot) * ib_mutliplier) /
            scale_factor);
    }
};

// Operator: divides component Slot of `power` by `scale_factor` and rounds
// with rintf. The offset is subtracted first for Stokes I only; Q, U and V
// are scaled without an offset.
struct Rescale {
    template <int Slot, StokesParameter S, typename T>
    static void apply(T const& power,
                      float const& offset,
                      float const& scale_factor,
                      T& result)
    {
        constexpr bool subtract_offset = (S == StokesParameter::I);
        if constexpr(subtract_offset) {
            AT(result, Slot) = rintf((AT(power, Slot) - offset) / scale_factor);
        } else {
            AT(result, Slot) = rintf(AT(power, Slot) / scale_factor);
        }
    }
};

struct Clamp {
    template <int I, StokesParameter S, typename T, typename X>
    static void apply(T const& power, X& result)
    {
        using EType = typename element_type<X>::type;
        AT(result, I) = static_cast<EType>(
                            fmaxf(static_cast<float>(
                                      std::numeric_limits<EType>::lowest()),
                                  fminf(static_cast<float>(
                                            std::numeric_limits<EType>::max()),
                                        AT(power, I))));
    }
};

// Wrapper struct to start the iteration with index 0
template <int... Values>
struct Test {
    template <typename Operator, typename... Args>
    static void apply(Args&&... args)
    {
        Iterate<0, Values...>::template apply<Operator>(
            std::forward<Args>(args)...); // Start with index 0
    }
};

template <StokesParameter... Stokes>
struct StokesTraits
{
    using RawPowerType =  stokes_storage_type<sizeof...(Stokes)>::RawPowerType;
    using QuantisedPowerType =  stokes_storage_type<sizeof...(Stokes)>::QuantisedPowerType;
    constexpr static const RawPowerType zero_power = RawPowerType{};

    static inline void
    integrate_stokes(float2 const& p0,
                     float2 const& p1,
                     RawPowerType& power) {
        Iterate<0, Stokes...>::template apply<IntegrateStokes>(p0, p1, power);
    }

    static inline void
    integrate_weighted_stokes(float2 const& p0,
                              float2 const& p1,
                              RawPowerType& power,
                              float const& weight) {
        Iterate<0, Stokes...>::template apply<IntegrateWeightedStokes>(p0, p1, power, weight);
    }

    static inline RawPowerType
    ib_subtract(RawPowerType const& power,
                RawPowerType const& ib_power,
                float const& ib_mutliplier,
                float const& scale_factor) {
        RawPowerType result{};
        Iterate<0, Stokes...>::template apply<IncoherentBeamSubtract>(power, ib_power, ib_mutliplier, scale_factor, result);
        return result;
    }

    static inline RawPowerType
    rescale(RawPowerType const& power,
            float const& offset,
            float const& scale_factor) {
        RawPowerType result{};
        Iterate<0, Stokes...>::template apply<Rescale>(power, offset, scale_factor, result);
        return result;
    }

    static inline QuantisedPowerType
    clamp(RawPowerType const& power) {
        QuantisedPowerType result{};
        Iterate<0, Stokes...>::template apply<Clamp>(power, result);
        return result;
    }

};

int main()
{
    float2 p0{1.0f, 2.0f};
    float2 p1{3.0f, 4.0f};

    {
        using traits = StokesTraits<I, Q>;
        typename traits::RawPowerType power{};
        typename traits::RawPowerType ib_power{};
        typename traits::RawPowerType rescale_result = traits::rescale(power, 10.0, 20.0);
        typename traits::RawPowerType ib_subtract_result = traits::ib_subtract(power, ib_power, 127.0, 10.0);
        traits::integrate_stokes(p0, p1, power);
        traits::integrate_weighted_stokes(p0, p1, power, 1.0f);
        typename traits::QuantisedPowerType result = traits::clamp(power);
        std::cout << (int)result.x << ", " << (int)result.y <<"\n";
    }
    {
        using traits = StokesTraits<I, Q, U, V>;
        typename traits::RawPowerType power{};
        typename traits::RawPowerType ib_power{};
        typename traits::RawPowerType rescale_result = traits::rescale(power, 10.0, 20.0);
        typename traits::RawPowerType ib_subtract_result = traits::ib_subtract(power, ib_power, 127.0, 10.0);
        traits::integrate_stokes(p0, p1, power);
        traits::integrate_weighted_stokes(p0, p1, power, 1.0f);
        typename traits::QuantisedPowerType result = traits::clamp(power);
        std::cout << (int)result.x << ", " << (int)result.y << ", " << (int)result.z << ", " << (int)result.w <<"\n";
    }
    {
        using traits = StokesTraits<I>;
        typename traits::RawPowerType power{};
        typename traits::RawPowerType ib_power{};
        typename traits::RawPowerType rescale_result = traits::rescale(power, 10.0, 20.0);
        typename traits::RawPowerType ib_subtract_result = traits::ib_subtract(power, ib_power, 127.0, 10.0);
        traits::integrate_stokes(p0, p1, power);
        traits::integrate_weighted_stokes(p0, p1, power, 1.0f);
        typename traits::QuantisedPowerType result = traits::clamp(power);
        std::cout << (int)result << "\n";
    }

    return 0;
}
ewanbarr commented 2 months ago

PR #15 is now submitted. @vivekvenkris It needs to be reviewed: even though the tests pass, it changes a lot of the code's base functionality (like math operators), so it is potentially risky.

ewanbarr commented 2 months ago

@vivekvenkris I said to hold off on this because there was a simpler solution, but it turns out that the simpler solution doesn't really change much. Hence, I think we should stick with this PR as it is, since it compiles, passes tests and satisfies the requirement of enabling easy additional Stokes modes.

vivekvenkris commented 2 months ago

Done and merged to pipeline_dev