Jittor / jittor

Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators.
https://cg.cs.tsinghua.edu.cn/jittor/
Apache License 2.0

Is there an efficient implementation of block diag function? #547

Closed liylo closed 3 months ago

liylo commented 3 months ago

Here is a naive implementation. Do you have a better one?

```python
import jittor as jt


def block_diag(*tensors):
    # 0-D inputs are treated as 1x1 blocks, 1-D inputs of length n as 1xn blocks.
    def block_shape(tensor):
        if tensor.ndim == 0:
            return 1, 1
        if tensor.ndim == 1:
            return 1, tensor.shape[0]
        return tensor.shape[0], tensor.shape[1]

    shapes = [block_shape(tensor) for tensor in tensors]

    # Calculate the total shape of the block diagonal matrix
    total_rows = sum(rows for rows, _ in shapes)
    total_cols = sum(cols for _, cols in shapes)

    # Initialize the block diagonal matrix with zeros
    block_matrix = jt.zeros((total_rows, total_cols), dtype=tensors[0].dtype)

    current_row = 0
    current_col = 0

    # Place each tensor in the block diagonal matrix
    for tensor, (rows, cols) in zip(tensors, shapes):
        if tensor.ndim == 0:
            block_matrix[current_row, current_col] = tensor
        elif tensor.ndim == 1:
            for i in range(cols):
                block_matrix[current_row, current_col + i] = tensor[i]
        else:
            for i in range(rows):
                for j in range(cols):
                    block_matrix[current_row + i, current_col + j] = tensor[i, j]

        current_row += rows
        current_col += cols

    return block_matrix
```
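
One common way to avoid the per-element Python loops is to copy each block with a single slice assignment. The sketch below uses NumPy so that it runs without Jittor installed. The name `block_diag_sliced` is hypothetical, and it assumes the same convention as `torch.block_diag`: 0-D inputs become 1x1 blocks, 1-D inputs become single-row blocks, and no input has more than two dimensions. Whether the same slice-assignment pattern carries over to Jittor by swapping `np.zeros` for `jt.zeros` is an assumption and is not verified here.

```python
import numpy as np


def block_diag_sliced(*arrays):
    """Build a block diagonal matrix with one slice assignment per block."""
    # np.atleast_2d maps 0-D inputs to (1, 1) and 1-D inputs of length n to (1, n).
    blocks = [np.atleast_2d(np.asarray(a)) for a in arrays]
    if not blocks:
        return np.zeros((0, 0))

    total_rows = sum(b.shape[0] for b in blocks)
    total_cols = sum(b.shape[1] for b in blocks)
    out = np.zeros((total_rows, total_cols), dtype=np.result_type(*blocks))

    row = col = 0
    for b in blocks:
        rows, cols = b.shape
        # Copy the whole block at once instead of element by element.
        out[row:row + rows, col:col + cols] = b
        row += rows
        col += cols
    return out


result = block_diag_sliced(np.array([[1, 2], [3, 4]]), np.array([5, 6]), np.array(7))
print(result)
```

This version runs one Python-level loop iteration per block rather than one per element, so the number of indexing operations scales with the number of blocks instead of the number of entries.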