Cambricon / mlu-ops

Efficient operation implementation based on the Cambricon Machine Learning Unit (MLU).
MIT License
102 stars · 102 forks

[New operator] - linalg.lu operator development #1007

Open PetrelYy opened 5 months ago

PetrelYy commented 5 months ago

The development plan can follow these milestones:

  1. Design document written, xx.xx~xx.xx
  2. Development and self-testing, xx.xx~xx.xx
  3. PR/MR submitted, xx.xx~xx.xx
  4. Review (3 approvals), xx.xx~xx.xx
  5. Merged by maintainer
PetrelYy commented 4 months ago

@Chuancysun please update the progress.

Chuancysun commented 4 months ago

Currently doing targeted optimization for tall-and-skinny matrix shapes such as (65536, 30); the code logic of the decomposition kernel has been restructured. For the other shapes the performance target of staying within 10x is already met.

Chuancysun commented 4 months ago

Currently doing targeted optimization for tall-and-skinny matrix shapes such as (65536, 30); the code logic of the decomposition kernel has been restructured. For the other shapes the performance target of staying within 10x is already met.

Addendum: after optimizing the addition operator in my own matrix multiply-add implementation, performance for some matrix sizes still did not meet the target. Profiling showed the bottleneck is the innermost decomposition kernel, whose code logic has now been restructured. The profiler also revealed a large amount of redundant data movement between GDRAM and on-chip memory inside that kernel. The optimization keeps all frequently used data on chip and trades extra computation for fewer memory accesses; if on-chip memory is insufficient, the data is stored in segments. A substantial performance improvement is expected.
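
To make the structure concrete, here is a minimal NumPy sketch of a blocked, right-looking LU factorization without pivoting (an illustration of the idea described above, not the actual MLU kernel code; the function name and block size are placeholders). The narrow column panel that the inner decomposition kernel factors is the high-reuse data worth keeping on chip, while the trailing-matrix update is the GEMM-like part that dominates the arithmetic:

import numpy as np

def blocked_lu_nopivot(A, nb=32):
    # In-place LU of a square float matrix, no pivoting: L (unit lower) and U
    # end up stored in A. Assumes no zero pivots, e.g. a diagonally dominant input.
    n = A.shape[0]
    for k in range(0, n, nb):
        b = min(nb, n - k)
        # 1) panel factorization: unblocked LU on the b-column-wide panel
        for j in range(k, k + b):
            A[j+1:, j] /= A[j, j]
            A[j+1:, j+1:k+b] -= np.outer(A[j+1:, j], A[j, j+1:k+b])
        # 2) forward substitution to form the U block to the right of the panel
        for j in range(k, k + b):
            A[j+1:k+b, k+b:] -= np.outer(A[j+1:k+b, j], A[j, k+b:])
        # 3) trailing-matrix update (GEMM-like; dominates the work)
        A[k+b:, k+b:] -= A[k+b:, k:k+b] @ A[k:k+b, k+b:]
    return A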

Chuancysun commented 4 months ago

The PR link is as follows: https://github.com/Cambricon/mlu-ops/pull/1019

Chuancysun commented 4 months ago

Schedule: the work plan is shown in the attached image.

Chuancysun commented 4 months ago

The real-valued single-batch non-pivoted LU decomposition is now complete; the design document and code have been posted to the PR: https://github.com/Cambricon/mlu-ops/pull/1019

Chuancysun commented 4 months ago

The real-valued multi-batch non-pivoted LU decomposition is now complete; performance is shown in the figure. For the (4,5,65536,3000) size, neither MLU nor Magma can allocate enough memory (about 15 GB), and cusolver has no batched LU interface, so Magma is used as the comparison baseline here.

Chuancysun commented 4 months ago

The real-valued multi-batch non-pivoted LU decomposition is now complete; performance is shown in the figure. For the (4,5,65536,3000) size, neither MLU nor Magma can allocate enough memory (about 15 GB), and cusolver has no batched LU interface, so Magma is used as the comparison baseline here.

The code will be posted as a PR after cleanup.
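
For reference, a minimal sketch (assumptions mine, not the PR code) of how a batched baseline can be produced with PyTorch for such multi-batch shapes: the leading batch dimensions are flattened and torch.linalg.lu_factor is applied to the whole batch at once, which is also how the compute.py posted later in this thread handles 4-D inputs.

import torch

A = torch.randn(4, 5, 256, 300)          # hypothetical (batch, depth, m, n) input
A = A.view(4 * 5, 256, 300)              # flatten the leading batch dims
LU, pivots = torch.linalg.lu_factor(A)   # batched factorization with partial pivoting
print(LU.shape, pivots.shape)            # torch.Size([20, 256, 300]) torch.Size([20, 256])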

PetrelYy commented 4 months ago

Discussed with @Chuancysun: the design document and code in the current PR only cover the non-pivoted case; the pivoted-case document and code are still in development.

PetrelYy commented 4 months ago

I suggest moving the schedule up by half a week to leave a few more days for review and revisions; otherwise the 7.15 deadline is at significant risk.

Chuancysun commented 4 months ago

Schedule: the updated work plan is shown in the attached image. Currently debugging the performance of the complex single-batch case. The non-pivoted performance optimization is expected to finish on schedule. Building on the optimization experience from the non-pivoted part, performance tuning of the pivoted part should go relatively quickly, but that can only be judged once it is implemented.

Chuancysun commented 3 months ago

Working on functionality and performance tuning for the complex single-batch case; the current focus is optimizing performance for tall-and-skinny shapes.

Chuancysun commented 3 months ago

Correctness tests for the complex single-batch and multi-batch cases are complete; performance results are shown in the figure. Currently doing targeted optimization for tall-and-skinny matrix shapes.

Chuancysun commented 3 months ago

Performance optimization for tall-and-skinny matrices is complete. The non-pivoted LU decomposition is now finished, and the pivoted LU decomposition is under development.

Chuancysun commented 3 months ago

Functionality and performance of the pivoted LU decomposition are complete for smaller sizes; the blocked implementation for larger sizes is under development and debugging.

Chuancysun commented 2 months ago

As shown in the attached screenshot, the kernel tests the maximum sizes of NRAM and SRAM, but allocations start reporting out-of-memory errors at roughly 512 KB and 2048 KB respectively, whereas according to the documentation the sizes should be 640 KB and 4096 KB. The error is shown in the second screenshot.

PetrelYy commented 2 months ago

__nram__ float uint8_t test[512*1024];  // this line is wrong: it declares two types

Writing code like the above is not recommended, because NRAM_SIZE differs between boards; on the 590, the NRAM space available for operator development is less than 512 KB.

ArtIntAI commented 1 month ago

Could you post the test code and the json?

Chuancysun commented 1 week ago

mannul_shape_1.json — the test case attached above can be used as a reference.

Chuancysun commented 1 week ago

compute.py is as follows:

import torch
from nonmlu_ops.base import *
import numpy as np  # numpy is used below (np.real / np.imag)
import logging
import copy
import os

@registerTensorList("sgetrf2")
class sgetrf2TensorList(TensorList):
    pass

def castDataNode(self):
    pass  # the original body is kept commented out below

#     '''
#     cast input data to onchip data by input_dtype and input_onchip_dtype.
#     '''
#     # compute baseline output
#     #          Qcast             Force Cast            FFT
#     # x_fp -----------> x_int --------------> x_fp -------> y_fp
#     for input_tensor in self.input_tensors_:
#         input_datanode = input_tensor.getDataNode()
#         input_onchip_datanode = input_tensor.onchip_datanode_
#         if input_onchip_datanode.dtype_.isQuantType():
#             bitnum = input_onchip_datanode.dtype_.getDataBits()
#             if_scale = input_tensor.if_scale_
#             if_offset = input_tensor.if_offset_
#             if input_datanode.dtype_.isFloatPoint():
#                 # has nan or has inf, return
#                 if np.isnan(input_datanode.data_).any() or np.isinf(input_datanode.data_).any():
#                     return
#                 # Qcast
#                 position, scale, offset = quantize_utils.compute_quant_param(input_datanode.data_, bitnum, if_scale, if_offset)
#                 input_onchip_datanode.setData(quantize_utils.float2fix(input_datanode.data_, bitnum, position, scale, offset))
#                 input_onchip_datanode.setQuantParam(position, scale, offset)
#                 # Force cast
#                 input_datanode.setData(quantize_utils.fix2float(bitnum, input_onchip_datanode.data_, position, scale, offset))
#             elif input_datanode.dtype_.isComplex():
#                 real_data, imag_data = input_datanode.getComplexData()
#                 complex_data = np.concatenate((real_data, imag_data), axis=-1)
#                 # has nan or has inf, return
#                 if np.isnan(complex_data).any() or np.isinf(complex_data).any():
#                     return
#                 # Qcast
#                 position, scale, offset = quantize_utils.compute_quant_param(complex_data, bitnum, if_scale, if_offset)
#                 real_quant_data = quantize_utils.float2fix(real_data, bitnum, position, scale, offset)
#                 imag_quant_data = quantize_utils.float2fix(imag_data, bitnum, position, scale, offset)
#                 input_onchip_datanode.setQuantParam(position, scale, offset)
#                 # Force cast
#                 real_dequant_data = quantize_utils.fix2float(bitnum, real_quant_data, position, scale, offset)
#                 imag_dequant_data = quantize_utils.fix2float(bitnum, imag_quant_data, position, scale, offset)
#                 input_datanode.setComplexData(real_dequant_data, imag_dequant_data)

def print_matrix(A):
    if A.ndim == 3:
        batch = A.shape[0]
        size = A.shape[1]
        for i in range(batch):
            for j in range(size):
                for k in range(size):
                    print("{:.3}".format(A[i][j][k]), end=" ")
                print("\n")
            print("\n")
    elif A.ndim == 2:
        size = A.shape[0]
        for i in range(size):
            for j in range(size):
                print("{:.3}".format(A[i][j]), end=" ")
            print("\n")
    elif A.ndim == 1:
        size = A.shape[0]
        for i in range(size):
            print("{}".format(A[i]), end=" ")
        print("\n")

def set_complex_data(data_node, complex_tensor):
    cpu_array = complex_tensor.cpu().numpy()
    cpu_real = np.real(cpu_array)
    cpu_imag = np.imag(cpu_array)
    data_node.setComplexData(cpu_real, cpu_imag)

def set_values_below_threshold(input_tensor, threshold=1e-3, new_value=1e-6):
    # get the data type
    dtype = input_tensor.dtype
    if dtype == torch.float32 or dtype == torch.complex64:
        new_value_pos = torch.tensor(new_value, dtype=torch.float32, device=input_tensor.device)
        new_value_neg = torch.tensor(-new_value, dtype=torch.float32, device=input_tensor.device)
    elif dtype == torch.float64 or dtype == torch.complex128:
        new_value_pos = torch.tensor(new_value, dtype=torch.float64, device=input_tensor.device)
        new_value_neg = torch.tensor(-new_value, dtype=torch.float64, device=input_tensor.device)
    else:
        raise ValueError("Unsupported tensor dtype")

    # for complex tensors, handle the real and imaginary parts separately
    if torch.is_complex(input_tensor):
        real_part = input_tensor.real
        imag_part = input_tensor.imag

        real_part[(real_part.abs() < threshold) & (real_part >= 0)] = new_value_pos
        real_part[(real_part.abs() < threshold) & (real_part < 0)] = new_value_neg
        imag_part[(imag_part.abs() < threshold) & (imag_part >= 0)] = new_value_pos
        imag_part[(imag_part.abs() < threshold) & (imag_part < 0)] = new_value_neg

        input_tensor = torch.complex(real_part, imag_part)
    else:
        # for real tensors
        input_tensor[(input_tensor.abs() < threshold) & (input_tensor >= 0)] = new_value_pos
        input_tensor[(input_tensor.abs() < threshold) & (input_tensor < 0)] = new_value_neg

    return input_tensor

def set_diag_imag_one(input_tensor):
    if input_tensor.dim() == 2:
        diag_indices = torch.arange(input_tensor.size(0), device=input_tensor.device)
        input_tensor[diag_indices, diag_indices] += 1j - input_tensor[diag_indices, diag_indices].imag * 1j

    elif input_tensor.dim() == 3:
        batch_size, n, _ = input_tensor.size()
        for i in range(batch_size):
            diag_indices = torch.arange(n, device=input_tensor.device)
            input_tensor[i, diag_indices, diag_indices] += 1j - input_tensor[i, diag_indices, diag_indices].imag * 1j

def matrix_multiply(A, B):
    # get the matrix dimensions
    rows_A, cols_A = A.shape
    rows_B, cols_B = B.shape

    # check that the dimensions match
    if cols_A != rows_B:
        raise ValueError("The number of columns of A must equal the number of rows of B")

    # create the result matrix C, initialized to zeros
    C = torch.zeros((rows_A, cols_B))

    # triple loop implementing the matrix product
    for i in range(rows_A):
        for j in range(cols_B):
            for k in range(cols_A):
                C[i][j] += A[i][k] * B[k][j]

    return C

def extract_LU(LU, pivots):
    if LU.dim() == 2:
        # single-matrix case
        m, n = LU.size()
        if torch.is_complex(LU):
            L = torch.tril(LU, diagonal=-1) + torch.eye(m, n, dtype=LU.dtype, device=LU.device) * (1 + 0j)
        else:
            L = torch.tril(LU, diagonal=-1) + torch.eye(m, n, dtype=LU.dtype, device=LU.device)
        U = torch.triu(LU)
        if m < n:
            L = L[:, :m]  # crop to m * m
        elif m > n:
            U = U[:n, :]

    elif LU.dim() == 3:
        # batched case
        batch_size, m, n = LU.size()
        if torch.is_complex(LU):
            L = torch.tril(LU, diagonal=-1) + torch.eye(m, n, dtype=LU.dtype, device=LU.device).expand(batch_size, -1, -1) * (1 + 0j)
        else:
            L = torch.tril(LU, diagonal=-1) + torch.eye(m, n, dtype=LU.dtype, device=LU.device).expand(batch_size, -1, -1)
        U = torch.triu(LU)
        if m < n:
            L = L[:, :, :m]  # crop to batch * m * m
        elif m > n:
            U = U[:, :n, :]

    elif LU.dim() == 4:
        # flatten the leading batch dims
        batch_size, depth, m, n = LU.size()
        LU = LU.view(batch_size * depth, m, n)
        # batched case
        batch_size, m, n = LU.size()
        if torch.is_complex(LU):
            L = torch.tril(LU, diagonal=-1) + torch.eye(m, n, dtype=LU.dtype, device=LU.device).expand(batch_size, -1, -1) * (1 + 0j)
        else:
            L = torch.tril(LU, diagonal=-1) + torch.eye(m, n, dtype=LU.dtype, device=LU.device).expand(batch_size, -1, -1)
        U = torch.triu(LU)
        if m < n:
            L = L[:, :, :m]  # crop to batch * m * m
        elif m > n:
            U = U[:, :n, :]

    else:
        raise ValueError("Unsupported number of dimensions for LU tensor")

    return L, U

def make_diagonally_dominant(input_data):
    if input_data.dim() == 2:
        # single-matrix case
        m, n = input_data.size()
        min_mn = min(m, n)
        for i in range(min_mn):
            if torch.is_complex(input_data):
                # complex matrix
                real_part = input_data[i, i].real
                imag_part = input_data[i, i].imag
                input_data[i, i] = torch.complex(real_part + n, imag_part)
            else:
                # real matrix
                input_data[i, i] += n
    elif input_data.dim() == 3:
        # batched case
        batchCount, m, n = input_data.size()
        min_mn = min(m, n)
        for s in range(batchCount):
            for i in range(min_mn):
                if torch.is_complex(input_data):
                    # complex matrix
                    real_part = input_data[s, i, i].real
                    imag_part = input_data[s, i, i].imag
                    input_data[s, i, i] = torch.complex(real_part + n, imag_part)
                else:
                    # real matrix
                    input_data[s, i, i] += n
    elif input_data.dim() == 4:
        # flatten the leading batch dims
        batch_size, depth, m, n = input_data.size()
        input_data = input_data.view(batch_size * depth, m, n)
        # batched case
        batchCount, m, n = input_data.size()
        min_mn = min(m, n)
        for s in range(batchCount):
            for i in range(min_mn):
                if torch.is_complex(input_data):
                    # complex matrix
                    real_part = input_data[s, i, i].real
                    imag_part = input_data[s, i, i].imag
                    input_data[s, i, i] = torch.complex(real_part + n, imag_part)
                else:
                    # real matrix
                    input_data[s, i, i] += n
    return input_data

# Function to swap two rows of a matrix

def swap_rows(matrix, row1, row2):
    # print("swap row1 row2", row1, row2)

    # print("row1 ")
    # size = matrix[row1-1,:].shape[0]
    # for i in range(size):
    #     print("{:.3}".format(matrix[row1-1,i]), end=" ")
    # print("row2 ")
    # size = matrix[row2-1,:].shape[0]
    # for i in range(size):
    #     print("{:.3}".format(matrix[row2-1,i]), end=" ")

    matrix[[row1-1, row2-1], :] = matrix[[row2-1, row1-1], :]

# Function to apply row swaps to matrix A using ipiv

def apply_row_swaps(A, ipiv):
    if A.dim() == 2:
        batch_size = 1
        m, n = A.size()
    elif A.dim() == 3:
        batch_size, m, n = A.size()
    elif A.dim() == 4:
        batch_size, depth, m, n = A.size()
        batch_size = batch_size * depth
    else:
        raise ValueError("Unsupported number of dimensions for A tensor")
    m = min(m, n)
    if batch_size > 1:
        for b in range(batch_size):
            for i in range(m - 1, -1, -1):  # Iterate backwards
                if ipiv[b, i] - 1 != i:
                    swap_rows(A[b], i + 1, ipiv[b, i])
                    ipiv[b, ipiv[b, i] - 1], ipiv[b, i] = ipiv[b, i], ipiv[b, ipiv[b, i] - 1]
    else:
        for i in range(m - 1, -1, -1):  # Iterate backwards
            if ipiv[i] - 1 != i:
                # print("i",i)
                # print("ipiv[i] ",ipiv[i])
                swap_rows(A, i + 1, ipiv[i])
                # ipiv[ipiv[i] - 1], ipiv[i] = ipiv[i], ipiv[ipiv[i] - 1]

@registerOp("sgetrf2") class sgetrf2Op(OpTest): def init(self,tensorlist,params): super().init(tensorlist,params) self.mode = self.params.get("mode")

compute_cout = 0

def compute(self):
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'
    gpu_count = torch.cuda.device_count()
    print("gpu_count:")
    print(torch.cuda.device_count())
    print("os visible:",os.environ.get('CUDA_VISIBLE_DEVICES'))
    cuda_count = sgetrf2Op.compute_cout % gpu_count
    print('now gpu:',cuda_count)
    sgetrf2Op.compute_cout += 1

    result_mul = True

    input_tensor = self.tensor_list_.getInputTensor(0)
    output_tensor = self.tensor_list_.getOutputTensor(0)
    input_is_complex = input_tensor.getDataType().isComplex()
    mode = self.mode_

    if not input_is_complex:
        # input_data = torch.tensor(input_tensor.getData()).cuda(cuda_count)
        # input_data_fp64 = input_data.type(torch.float64).cuda(cuda_count)
        # upper_triangle = torch.triu(input_data, diagonal=1)
        # L_matrix = input_data - upper_triangle

        # del input_data
        # del upper_triangle
        # torch.cuda.empty_cache()

        # batch = 1
        # size = L_matrix.size(1)
        # if L_matrix.dim() == 2:
        #     U_matrix = L_matrix.transpose(0, 1)
        #     A = torch.mm(L_matrix, U_matrix)
        #     del L_matrix
        #     torch.cuda.empty_cache()
        #     eye = torch.eye(size, dtype=torch.float32).cuda(cuda_count)
        #     A = A + eye
        # elif L_matrix.dim() == 3:
        #     U_matrix = L_matrix.transpose(1, 2)
        #     A = torch.bmm(L_matrix, U_matrix)
        #     batch = L_matrix.size(0)
        #     del L_matrix
        #     torch.cuda.empty_cache()
        #     eye = torch.eye(size, dtype=torch.float32).expand(batch, -1, -1).cuda(cuda_count)
        #     A = A + eye
        # else:
        #     exit()

        # del eye
        # del U_matrix
        # torch.cuda.empty_cache()

        # input matrix A
        flag = (mode == 1)
        print("pivot ",flag)

        input_data = torch.tensor(input_tensor.getData()).cuda(cuda_count)
        if flag == False:
            input_data = make_diagonally_dominant(input_data)
        input_data_fp64 = input_data.type(torch.float64).cuda(cuda_count)
        input_tensor.setData(input_data.cpu().numpy())

        # print("input:")
        # print_matrix(input_data.cpu().numpy())

        # input_tensor.setData(input_data)
        torch.backends.cuda.preferred_linalg_library(backend='cusolver')
        LU, pivots = torch.linalg.lu_factor(input_data, pivot=flag)
        # print("LU ",LU)

        # batch = 1
        # size = L_matrix.size(1)

        L_matrix, U_matrix = extract_LU(LU, pivots)
        print("L U size",L_matrix.size(),U_matrix.size())
        # print("L",L_matrix)
        # print("U",U_matrix)
        if result_mul or mode == 1:
            if L_matrix.dim() == 2:
                result = torch.mm(L_matrix, U_matrix)
            else:
                result = torch.bmm(L_matrix, U_matrix)

        else:
            result = LU

        # print("ipiv fp32")
        # print_matrix(pivots.cpu().numpy())
        if mode == 1:
            # Apply pivots to the result to restore the original matrix
            apply_row_swaps(result, pivots)

        # print("result fp32 LU:")
        # print_matrix(result.cpu().numpy())

        output_result = result.cpu().numpy()
        output_tensor.setData(output_result)

        del LU
        del pivots
        del result
        del L_matrix
        del U_matrix
        del output_result
        del input_data
        torch.cuda.empty_cache()

        # upper_triangle_fp64 = torch.triu(input_data_fp64, diagonal=1)
        # L_matrix_fp64 = input_data_fp64 - upper_triangle_fp64
        # del upper_triangle_fp64
        # del input_data_fp64
        # torch.cuda.empty_cache()

        # if L_matrix_fp64.dim() == 2:
        #     U_matrix_fp64 = L_matrix_fp64.transpose(0, 1)
        #     A_fp64 = torch.mm(L_matrix_fp64, U_matrix_fp64)
        #     del L_matrix_fp64
        #     del U_matrix_fp64
        #     A_fp64 = A_fp64 + torch.eye(size, dtype=torch.float64).cuda(cuda_count)
        # elif L_matrix_fp64.dim() == 3:
        #     U_matrix_fp64 = L_matrix_fp64.transpose(1, 2)
        #     A_fp64 = torch.bmm(L_matrix_fp64, U_matrix_fp64)
        #     del L_matrix_fp64
        #     del U_matrix_fp64
        #     A_fp64 = A_fp64 + torch.eye(size, dtype=torch.float64).expand(batch, -1, -1).cuda(cuda_count)

        # torch.cuda.empty_cache()
        A_fp64 = input_data_fp64.double()

        result_LU_fp64, pivots = torch.linalg.lu_factor(A_fp64, pivot=flag)

        result_L_fp64, result_U_fp64 = extract_LU(result_LU_fp64, pivots)

        # del A_fp64
        torch.cuda.empty_cache()

        base_node = DataNode("double")

        if result_mul or mode == 1:
            if result_L_fp64.dim() == 2:
                result = torch.matmul(result_L_fp64, result_U_fp64)
            else:
                result = torch.bmm(result_L_fp64, result_U_fp64)

        else:
            result = result_LU_fp64

        if mode == 1:
            # Apply pivots to the result to restore the original matrix
            apply_row_swaps(result, pivots)

            # print("orign fp64 result:")
            # print_matrix(result_L_fp64.cpu().numpy())

            # result = result_L_fp64
        # print("result fp64 LU:")
        # print_matrix(result.cpu().numpy())
        # print("ipiv fp64")
        # print_matrix(pivots.cpu().numpy())
        output_result = result.cpu().numpy()
        base_node.setData(output_result)
        del result
        del pivots
        del result_L_fp64
        del result_U_fp64
        del result_LU_fp64
        del input_data_fp64
        del A_fp64
        del output_result
        torch.cuda.empty_cache()

        # print_matrix(output_result_fp64)

        half_dynamic_threshold = 1e-3
        float_dynamic_threshold = 1e-5
        eva = diff_utils.Evaluator(base_node, output_tensor.getDataNode(), half_dynamic_threshold, float_dynamic_threshold)
        diff1 = eva.computeDiff1()
        diff2 = eva.computeDiff2()
        diff_3_2 = eva.computeDiff3_2(10.0)
        print("diff1: ", diff1)
        print("diff2: ", diff2)
        print("diff_3_2: ", diff_3_2)
        output_tensor.setDiff(diff1, diff2, -1, diff_3_2, -1)

    else:
        torch.backends.cuda.preferred_linalg_library(backend='cusolver')
        flag = (mode == 1)
        print("pivot ",flag)
        input_real_data, input_imag_data = input_tensor.getComplexData()
        # combine into a complex tensor
        input_complex_data = torch.complex(torch.from_numpy(input_real_data), torch.from_numpy(input_imag_data))
        # upper_triangle_real = np.tril(input_real_data)
        # upper_triangle_imag = np.tril(input_imag_data, k=-1)
        # complex_numpy_array = upper_triangle_real + 1j * upper_triangle_imag
        # del upper_triangle_real
        # del upper_triangle_imag

        input_data = torch.tensor(input_complex_data, dtype=torch.complex64).cuda(cuda_count)
        if flag== False:
            input_data = make_diagonally_dominant(input_data)
        set_complex_data(input_tensor, input_data)
        # del complex_numpy_array
        # print("origin input:")
        # print_matrix(input_data.cpu().numpy())
        # input_real_data = np.expand_dims(input_real_data, -1)
        # print("real data:",input_real_data)
        # input_imag_data = np.expand_dims(input_imag_data, -1)
        # print("imag data:",input_imag_data)
        # input_complex_data = np.concatenate((input_real_data, input_imag_data), axis=-1)

        # input_data = torch.tensor(input_complex_data).cuda(cuda_count)

        # print("complex data:",input_data)

        input_data_complex128 = input_data.type(torch.complex128).cuda(cuda_count)
        # upper_triangle_complex64 = torch.triu(input_data, diagonal=1)
        # L_matrix_complex64 = input_data - upper_triangle_complex64

        # del input_data
        # del upper_triangle_complex64
        # # free GPU memory
        # torch.cuda.empty_cache()

        # batch = 1
        # size = L_matrix_complex64.size(1)
        # print("Tensor shape:", L_matrix_complex64.shape)
        # if L_matrix_complex64.dim() == 2:
        #     U_matrix_complex64 = L_matrix_complex64.transpose(0, 1).conj()
        #     A_complex64 = torch.mm(L_matrix_complex64, U_matrix_complex64)
        #     del L_matrix_complex64
        #     torch.cuda.empty_cache()
        #     eye_complex64 = torch.eye(size, dtype=torch.complex64).cuda(cuda_count)
        #     A_complex64 = A_complex64 + eye_complex64
        # elif L_matrix_complex64.dim() == 3:
        #     U_matrix_complex64 = L_matrix_complex64.transpose(1, 2).conj()
        #     A_complex64 = torch.bmm(L_matrix_complex64, U_matrix_complex64)
        #     batch = L_matrix_complex64.size(0)
        #     del L_matrix_complex64
        #     torch.cuda.empty_cache()
        #     eye_complex64 = torch.eye(size, dtype=torch.complex64).expand(batch, -1, -1).cuda(cuda_count)
        #     A_complex64 = A_complex64 + eye_complex64
        # else:
        #     exit()

        # del eye_complex64
        # del U_matrix_complex64
        # torch.cuda.empty_cache()

        # set_complex_data(input_tensor, A_complex64)

        # print("input A:")
        # print_matrix(A_complex64.cpu().numpy())

        result_LU_complex64, pivots = torch.linalg.lu_factor(input_data,pivot=flag)

        result_L_complex64, result_U_complex64 = extract_LU(result_LU_complex64, pivots)
        # del A_complex64
        torch.cuda.empty_cache()

        if result_mul or mode == 1 :
            if result_L_complex64.dim() == 2:
                result = torch.mm(result_L_complex64, result_U_complex64)
            else:
                result = torch.bmm(result_L_complex64, result_U_complex64)
        else:
            result = result_LU_complex64

        if mode == 1:
            # Apply pivots to the result to restore the original matrix
            apply_row_swaps(result, pivots)

        # set_values_below_threshold(result)
        # set_diag_imag_one(result)
        set_complex_data(output_tensor, result)
        # print("result complex64 result:")
        # print_matrix(result.cpu().numpy())
        # print("ipiv complex64")
        # print_matrix(pivots.cpu().numpy())
        del pivots
        del input_real_data
        del input_imag_data
        del result_L_complex64
        del result_U_complex64
        del result_LU_complex64
        del result
        torch.cuda.empty_cache()

            # output_result = result.cpu().numpy()
            # output_tensor.setData(output_result)

        # print("result1:")
        # print_matrix(result1.cpu().numpy())

        # output_result_complex64 = result_L_complex64.cpu().numpy()

        # output_real = np.real(output_result_complex64)

        # output_imag = np.imag(output_result_complex64)

        # output_tensor.setComplexData(output_real, output_imag)

        # del result_L_complex64
        # torch.cuda.empty_cache()

        # print_matrix(input_data_complex128)

        # upper_triangle_complex128 = torch.triu(input_data_complex128, diagonal=1)
        # L_matrix_complex128 = input_data_complex128 - upper_triangle_complex128
        # L_matrix_complex128 = input_data_complex128

        # del input_data_complex128
        # del upper_triangle_complex128
        # torch.cuda.empty_cache()

        # if L_matrix_complex128.dim() == 2:

        #     A_complex128 = torch.mm(L_matrix_complex128, L_matrix_complex128.transpose(0, 1).conj()) + torch.eye(size, dtype=torch.complex128).cuda(cuda_count)
        #     del L_matrix_complex128
        #     torch.cuda.empty_cache()  
        # elif L_matrix_complex128.dim() == 3:

        #     A_complex128 = torch.bmm(L_matrix_complex128, L_matrix_complex128.transpose(1, 2).conj()) 
        #     del L_matrix_complex128
        #     torch.cuda.empty_cache()  
        #     A_complex128 = A_complex128 + torch.eye(size, dtype=torch.complex128).expand(batch, -1, -1).cuda(cuda_count)

        # print_matrix(A_complex64.cpu().numpy())

        # print_matrix(result_L_complex64.cpu().numpy())

        result_LU_complex128, pivots = torch.linalg.lu_factor(input_data_complex128,pivot=flag)

        result_L_complex128, result_U_complex128 = extract_LU(result_LU_complex128, pivots)
        # del A_complex128
        torch.cuda.empty_cache()

        base_node = DataNode("complex128")

        if result_mul:
            if result_L_complex128.dim() == 2:
                result = torch.mm(result_L_complex128, result_U_complex128)
            else:
                result = torch.bmm(result_L_complex128, result_U_complex128)

        else:
           result = result_LU_complex128

        if mode == 1:
            # Apply pivots to the result to restore the original matrix
            apply_row_swaps(result, pivots)

        # set_diag_imag_one(result)
        # print("result 128:")
        # print_matrix(result.cpu().numpy())
        # print("ipiv complex128")
        # print_matrix(pivots.cpu().numpy())
        set_complex_data(base_node, result)
        del result_LU_complex128
        del result_L_complex128
        del result_U_complex128
        del pivots
        del result
        del input_data
        del input_data_complex128
        torch.cuda.empty_cache()

        # print("result:")
        # print_matrix(result_L_complex128.cpu().numpy())

        # output_result_complex128 = result_L_complex128.cpu().numpy()

        # print("A_complex128 result:")
        # print_matrix(A_complex128)

        # print("complex128 result:")
        # print_matrix(output_result_complex128)

        # print_matrix(output_result_complex128)

        # output_real_fp64 = np.real(output_result_complex128)

        # output_imag_fp64 = np.imag(output_result_complex128)

        # base_node.setComplexData(output_real_fp64, output_imag_fp64)

        half_dynamic_threshold = 1e-3
        float_dynamic_threshold = 1e-5
        eva = diff_utils.Evaluator(base_node, output_tensor.getDataNode(), half_dynamic_threshold, float_dynamic_threshold)
        diff_3_2 = eva.computeDiff3_2(1.0)
        diff1 = eva.computeDiff1()
        diff2 = eva.computeDiff2()
        print("diff1: ", diff1)
        print("diff2: ", diff2)
        print("diff_3_2: ", diff_3_2)
        output_tensor.setDiff(diff1, diff2, -1, diff_3_2, -1)

    # print("还存在的变量:", locals())
    local_vars = list(locals().keys())
    # delete all local variables
    for var in local_vars:
        del locals()[var]
    torch.cuda.empty_cache()

    # if output_is_complex:
    #     # set base node
    #     base_node = DataNode("complex128")
    #     base_node.setComplexData(output_real_fp64, output_imag_fp64)
    #     # set output tensor
    #     # if pytorch do not support half fft, convert output from float32 to float16
    #     if output_is_half:
    #         output_real = output_real.astype("float16")
    #         output_imag = output_imag.astype("float16")
    #     has_inf = np.isinf(output_real).any() or np.isinf(output_imag).any() or \
    #               np.isinf(output_real_fp64).any() or np.isinf(output_imag_fp64).any()
    #     has_nan = np.isnan(output_real).any() or np.isnan(output_imag).any() or \
    #               np.isnan(output_real_fp64).any() or np.isnan(output_imag_fp64).any()
    #     output_tensor.setComplexData(output_real, output_imag)
    # else:
    #     # set base node
    #     base_node = DataNode("double")
    #     base_node.setData(output_result_fp64)
    #     # set output tensor
    #     # if pytorch do not support half fft, convert output from float32 to float16
    #     if output_is_half:
    #         output_result = output_result.astype("float16")
    #     has_inf = np.isinf(output_result_fp64).any() or np.isinf(output_result).any()
    #     has_nan = np.isnan(output_result_fp64).any() or np.isnan(output_result).any()
    #     output_tensor.setData(output_result)

@registerProtoWriter("sgetrf2")
class sgetrf2ProtoWriter(MluOpProtoWriter):
    def dumpOpParam2Node(self):
        sgetrf2_param_node = self.proto_node_.sgetrf2_param
        sgetrf2_param_node.mode = self.op_params_.get("mode")
        sgetrf2_param_node.n.extend(self.op_params_.get("n"))

        # sgetrf2_param_node.direction = self.op_params_.get("direction")
        # sgetrf2_param_node.scale_factor = self.op_params_.get("scale_factor")