This pull request updates several Python files in the bitblas library, with the primary goal of improving support for different data types and making the code more robust. The affected files are hint.py, tensorcore.py, lop3.py, general_matmul.py, and matmul_dequantize_impl.py. The changes fall into three main categories: updates to hint.py and tensorcore.py to handle different data types, improvements to lop3.py to better handle different bit sizes, and changes to general_matmul.py and matmul_dequantize_impl.py to add assertions and handle different bit sizes.
Handling different data types:
python/bitblas/base/roller/hint.py: Updated the __repr__ method in the TensorCoreExtraConfig class to handle float32 and int32 data types.
python/bitblas/base/roller/policy/tensorcore.py: Modified the _score function to set shared_scope to "shared.dyn" if the out_dtype is float32.
Improvements to handle different bit sizes:
python/bitblas/gpu/intrin/lop3.py: Reformatted the get_fast_decode_intrin function calls in the LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN, LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN, and LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN registrations for better readability. Also added new LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN and LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN registrations. [1][2][3]
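The new INT1-to-INT8 decode intrinsics unpack sub-byte weights that are stored eight per byte. The actual intrinsics do this on the GPU with LOP3 bit tricks; the following is only a plain-Python sketch of the unpacking they perform, with an assumed LSB-first bit order:

```python
def decode_int1_to_int8(packed_bytes):
    """Unpack 1-bit weights (8 per byte, LSB-first) into a flat list of ints.

    Simplified reference only: the real LOP3 intrinsics operate on packed
    registers in CUDA, not elementwise in Python.
    """
    out = []
    for byte in packed_bytes:
        for i in range(8):  # extract bits 0..7 of each storage byte
            out.append((byte >> i) & 1)
    return out

# 0b10110001 unpacks LSB-first to [1, 0, 0, 0, 1, 1, 0, 1]
print(decode_int1_to_int8([0b10110001]))
```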
Adding assertions and handling different bit sizes:
python/bitblas/ops/general_matmul.py: Added the is_not_fast_decoding_supported function in the __initialize_fast_decoding method and updated the condition in the transform_weight method to check if bit is less than 8. [1][2]
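The shape of such a guard can be sketched as follows; the parameter names and the exact conditions here are assumptions for illustration, not the library's real logic, beyond the fact that 8-bit weights need no sub-byte fast decode:

```python
def is_not_fast_decoding_supported(bit: int, source_format: str) -> bool:
    """Return True when the fast LOP3 decode path cannot be used.

    Hypothetical sketch: argument names and rules are assumptions.
    """
    # 8-bit values already fill a storage byte, so there is nothing
    # to fast-decode; sub-byte fast decoding is assumed to exist only
    # for low-bit integer source formats.
    if bit >= 8:
        return True
    return source_format not in ("int", "uint")
```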
python/bitblas/ops/impl/matmul_dequantize_impl.py: Added assertions to check if bit is in [1, 2, 4, 8] in the matmul_nt_dequantize_b, matmul_nt_dequantize_b_propagate_b, and matmul_nt_dequantize_b_propagate_a_propagate_b functions. Also updated the decode_func function in these methods to handle the case where bit is 8. [1][2][3][4][5][6][7][8]
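The combination of the new assertion and the bit == 8 special case can be illustrated with a small sketch. The real decode_func is a TIR expression inside these kernel builders; this pure-Python stand-in only shows the control flow (assert the bit width, and skip unpacking entirely when a value already fills a storage byte):

```python
def make_decode_func(bit: int):
    """Build a per-element decode for b-bit weights packed into int8 words.

    Illustrative sketch only, mirroring the assertion and the 8-bit
    passthrough described above.
    """
    assert bit in [1, 2, 4, 8], "unsupported bit width"

    if bit == 8:
        # An 8-bit value occupies a full storage element: no unpacking.
        return lambda word, idx: word

    mask = (1 << bit) - 1
    per_word = 8 // bit  # how many b-bit fields fit in one int8 word

    def decode(word, idx):
        # Extract the idx-th b-bit field from a packed 8-bit word.
        return (word >> (bit * (idx % per_word))) & mask

    return decode
```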
Other changes:
3rdparty/tvm: Updated the subproject commit.