Open jackiechen19940708 opened 1 year ago
Those CUDA kernels convert an FP32 input to MX. The output is restricted to the representable values of the MX data type, and this conversion is bit-accurate with respect to the float-to-MX conversion performed in hardware.
Can you specify what mismatches you are referring to? Is there a specific value or bit pattern which is not handled correctly?
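For illustration, here is a minimal host-side sketch of what "restricted to the representable values" means for a single element, assuming simple round-to-nearest and clamping to the element format's max norm. The function name `quantize_elem` and the simplifications (no subnormal, NaN/Inf, or rounding-mode handling) are mine for this sketch, not the repo's code:

```cpp
#include <cmath>
#include <cstdio>

// Hypothetical helper (not from this repo): snap one FP32 value onto the grid of a
// low-precision element format with `mbits` explicit mantissa bits, clamping to
// `max_norm`. Subnormal, NaN/Inf, and rounding-mode handling are omitted.
float quantize_elem(float x, int mbits, float max_norm) {
    if (x == 0.0f || !std::isfinite(x)) return x;
    int exp;
    float mant  = std::frexp(std::fabs(x), &exp);   // mant in [0.5, 1)
    float steps = std::ldexp(1.0f, mbits + 1);      // grid points per binade in this form
    float q     = std::round(mant * steps) / steps; // round to the nearest grid point
    float y     = std::copysign(std::ldexp(q, exp), x);
    return std::fmax(-max_norm, std::fmin(max_norm, y)); // clamp to the format's range
}

int main() {
    // 2.7 snaps to 3.0 for an FP4 (E2M1)-like element: 1 mantissa bit, max norm 6.0
    printf("%f\n", quantize_elem(2.7f, 1, 6.0f));
}
```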
It's nice work, but I have some questions. We see that the code uses the template parameter T = float (see the code below) (1) to represent the MX data format in FP32 and (2) to simulate MX-format computation using FP32. I think this may not be equivalent to a real MX data-format representation and its operations. Did you use an FPGA to evaluate how much error there is between the FP32 simulation and the real MX data format?
```cpp
template <typename T>
__global__ void quantize_mx_cuda_kernel(
    const T* __restrict__ input,
    const int scale_bits,
    const int elem_ebits,
    const int elem_mbits,
    const float elem_max_norm,
    const float* __restrict__ max_values,
    const long total_size,
    const int axis_size,
    const int post_axis_size,
    const bool flush_fp32_subnorms,
    const RoundingMode rounding_mode,
    T* __restrict__ output
) {
```
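To make the question concrete, here is an illustrative host-side reference of the kind of conversion the kernel's parameters suggest, assuming the shared scale is a power of two derived from the block maximum and the element max norm. The names `mx_block_reference` and `emax_elem`, and the exact scale/rounding/subnormal rules, are assumptions for illustration, not the repo's implementation:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative reference (not the repo's kernel): convert a block of FP32 values
// to an MX-style block and return the result held in FP32. Each output is the
// shared power-of-two scale times an element-format value.
std::vector<float> mx_block_reference(const std::vector<float>& v,
                                      int elem_mbits, float elem_max_norm) {
    // Shared power-of-two scale chosen from the block maximum so that the largest
    // element lands near the top of the element format's range.
    float max_abs = 0.0f;
    for (float x : v) max_abs = std::fmax(max_abs, std::fabs(x));
    int emax_elem  = (int)std::floor(std::log2(elem_max_norm));
    int shared_exp = (max_abs > 0.0f)
                       ? (int)std::floor(std::log2(max_abs)) - emax_elem
                       : 0;
    float scale = std::ldexp(1.0f, shared_exp);

    std::vector<float> out(v.size());
    for (size_t i = 0; i < v.size(); ++i) {
        float s = v[i] / scale;                       // scale into the element range
        int exp;
        float m     = std::frexp(std::fabs(s), &exp); // m in [0.5, 1)
        float steps = std::ldexp(1.0f, elem_mbits + 1);
        float q = (s == 0.0f) ? 0.0f
                : std::copysign(std::ldexp(std::round(m * steps) / steps, exp), s);
        q = std::fmax(-elem_max_norm, std::fmin(elem_max_norm, q)); // clamp
        out[i] = q * scale;                           // back to an FP32-held MX value
    }
    return out;
}

int main() {
    // FP4 (E2M1)-like elements: 1 mantissa bit, max norm 6.0
    std::vector<float> block = {2.2f, -5.0f, 9.5f, 30.0f};
    for (float y : mx_block_reference(block, 1, 6.0f)) printf("%g ", y);
    printf("\n");
}
```

In this style of simulation the FP32 output only ever holds values that an MX block could encode (shared scale times an element value), which is presumably what the bit-accuracy statement above refers to; the open question raised here is whether the snapping itself matches the hardware conversion exactly.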