wutiejun / workspace

My workspace.
7 stars 3 forks source link

[Introduction to Algorithms] Matrix multiplication #35

Open wutiejun opened 7 years ago

wutiejun commented 7 years ago
SQUARE-MATRIX-MULTIPLY(A, B)
// requires A.cols == B.rows
rows = A.rows
cols = B.cols
let C be a new rows X cols matrix
for i = 1 to A.rows
    for j = 1 to B.cols
        C[i,j] = 0
        for k = 1 to A.cols
            C[i,j] = C[i,j] + A[i,k]*B[k,j]
return C
wutiejun commented 7 years ago

/*
 * Dense matrix of ints stored in row-major order: element (row, col) is
 * pValues[Cols * row + col], which is the indexing GetMaxtrix/SetMaxtrix use.
 */
typedef struct Matrix_
{
    int Cols;       // number of columns
    int Rows;       // number of rows
    int * pValues;  // row-major element storage; CreateMatrix points it just past this struct, in the same allocation
} Matrix;

inline int GetMaxtrix(Matrix * pM, int row, int col)
{
    return *(pM->pValues + (pM->Cols * row + col));
}

/*
 * Store val at element (row, col) of a row-major matrix.
 * No bounds checking: the caller must ensure 0 <= row < pM->Rows and
 * 0 <= col < pM->Cols.
 *
 * The original definition had no return type. Implicit int is invalid since
 * C99, and a non-void function falling off its end yields no value. The
 * function now returns void explicitly. It is `static inline` because a plain
 * C99 `inline` definition provides no external definition to link against.
 */
static inline void SetMaxtrix(Matrix * pM, int row, int col, int val)
{
    pM->pValues[pM->Cols * row + col] = val;
}

/*
 * Write pM to stdout between two banner lines, the first of which carries
 * pInfo as a title. Each matrix row is printed on its own line with every
 * element right-aligned in a 6-character field. All lines end in "\r\n".
 */
void PrintMatrix(Matrix * pM, const char * pInfo)
{
    int row = 0;

    printf("==================%s==================\r\n", pInfo);
    while (row < pM->Rows)
    {
        int col;

        for (col = 0; col < pM->Cols; ++col)
        {
            printf("%6d", GetMaxtrix(pM, row, col));
        }
        printf("\r\n");
        ++row;
    }
    printf("====================================\r\n");
}

/*
 * Return a pseudo-random integer in the range [-15, 16].
 *
 * The generator is seeded from the microsecond part of the wall clock on the
 * first call only. Re-seeding on every call would make each result a function
 * of the current microsecond, so calls landing in the same microsecond would
 * return identical values and the output would be far from uniform.
 * The one-time seeding flag is not synchronized between threads.
 */
int RandomScore(void)
{
    static int seeded = 0;

    if (!seeded)
    {
        struct timeval tp;

        gettimeofday(&tp, NULL);
        srandom((unsigned int)tp.tv_usec);
        seeded = 1;
    }
    // random() % 32 is in [0, 31], so the result is in [-15, 16].
    return 16 - (int)(random() % 32);
}

Matrix * CreateMatrix(int rows, int cols)
{
    Matrix * pM = malloc(sizeof(Matrix) + (sizeof(int) * rows * cols));
    //assert(pM!=NULL);
    pM->Cols = cols;
    pM->Rows = rows;
    pM->pValues = (int *)((char *)pM + sizeof(Matrix));
    return pM;
}

/*
 * Allocate a rows x cols matrix and fill every element with RandomScore().
 * Returns NULL if CreateMatrix returns NULL (allocation failure), instead of
 * dereferencing it.
 */
Matrix * CreateRandomMatrix(int rows, int cols)
{
    Matrix * pM = CreateMatrix(rows, cols);
    int r, c;

    if (pM == NULL)
    {
        return NULL;
    }
    for (r = 0; r < pM->Rows; r++)
    {
        for (c = 0; c < pM->Cols; c++)
        {
            SetMaxtrix(pM, r, c, RandomScore());
        }
    }
    return pM;
}

/*
 * Iterative (triple-loop) matrix product C = A * B.
 *
 * pA is Rows x N and pB must be N x Cols. The result is a newly allocated
 * pA->Rows x pB->Cols matrix.
 * Returns NULL if either argument is NULL, if the inner dimensions disagree
 * (pA->Cols != pB->Rows), or if allocation fails. The original summed over
 * pB->Rows without this check, so mismatched inputs read past the end of
 * pA's rows.
 *
 * Sums are accumulated in int, so very large entries can overflow, which is
 * undefined behavior for signed int.
 */
Matrix * QUARE_MATRIX_MULTIPLY(Matrix * pA, Matrix * pB)
{
    Matrix * pC;
    int i, j, k;

    if (pA == NULL || pB == NULL || pA->Cols != pB->Rows)
    {
        return NULL;
    }
    pC = CreateMatrix(pA->Rows, pB->Cols);
    if (pC == NULL)
    {
        return NULL;
    }
    for (i = 0; i < pA->Rows; i++)
    {
        for (j = 0; j < pB->Cols; j++)
        {
            int sum = 0;

            for (k = 0; k < pA->Cols; k++)
            {
                sum += GetMaxtrix(pA, i, k) * GetMaxtrix(pB, k, j);
            }
            SetMaxtrix(pC, i, j, sum);
        }
    }
    return pC;
}
wutiejun commented 7 years ago
SQUARE-MATRIX-MULTIPLY-RECURSIVE(A, B)
// requires A.cols == B.rows
let C be a new A.rows X B.cols matrix with every entry 0
if A.rows == 0 or B.cols == 0 or A.cols == 0
    return C    // empty result, or every entry is an empty sum
else if A.rows == 1 and A.cols == 1 and B.cols == 1
    C[1,1] = A[1,1] X B[1,1]
else
    partition A, B, and C as in equations (4.9),
        splitting each dimension d into parts of size floor(d/2) and d - floor(d/2)
    C11 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11, B11)
               + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12, B21)
    C12 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11, B12)
               + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12, B22)
    C21 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21, B11)
               + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22, B21)
    C22 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21, B12)
               + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22, B22)
return C
wutiejun commented 7 years ago

Wrote the code correctly on the first try, feels great!


// Fill pT with the pT->Rows x pT->Cols window of pRaw whose top-left corner
// is at (StartRow, StartCol).
// No bounds checking is done; the caller must guarantee
//   pRaw->Rows >= pT->Rows + StartRow
//   pRaw->Cols >= pT->Cols + StartCol
void MatrixCopy(Matrix * pT, Matrix * pRaw, int StartRow, int StartCol)
{
    int r = 0;

    while (r < pT->Rows)
    {
        int c = 0;

        while (c < pT->Cols)
        {
            SetMaxtrix(pT, r, c, GetMaxtrix(pRaw, StartRow + r, StartCol + c));
            c++;
        }
        r++;
    }
}

// Write every element of pRaw into pT so that pRaw's (0, 0) lands on pT's
// (StartRow, StartCol). This is the reverse direction of MatrixCopy.
// No bounds checking is done; the caller must guarantee
//   pT->Rows >= pRaw->Rows + StartRow
//   pT->Cols >= pRaw->Cols + StartCol
void MatrixCopyBack(Matrix * pT, int StartRow, int StartCol, Matrix * pRaw)
{
    const int nRows = pRaw->Rows;
    const int nCols = pRaw->Cols;

    for (int row = 0; row < nRows; ++row)
    {
        for (int col = 0; col < nCols; ++col)
        {
            SetMaxtrix(pT, StartRow + row, StartCol + col, GetMaxtrix(pRaw, row, col));
        }
    }
}

Matrix * QUARE_MATRIX_MULTIPLY_RECURSIVE(Matrix * pA, Matrix * pB);

// Allocate a Rows x Cols matrix holding the block of pRaw whose top-left
// corner is (StartRow, StartCol). Returns NULL if allocation fails.
static Matrix * RmmExtractBlock(Matrix * pRaw, int StartRow, int StartCol,
                                int Rows, int Cols)
{
    Matrix * pBlock = CreateMatrix(Rows, Cols);

    if (pBlock != NULL)
    {
        MatrixCopy(pBlock, pRaw, StartRow, StartCol);
    }
    return pBlock;
}

// Release a matrix pointer that may be NULL.
static void RmmRelease(Matrix * pM)
{
    if (pM != NULL)
    {
        DeleteMatrix(pM);
    }
}

// Compute pX1 * pY1 + pX2 * pY2 recursively and store the sum into pC with
// its top-left corner at (StartRow, StartCol).
// Returns 0 on success, -1 if either recursive product could not be built.
static int RmmMultiplyAddInto(Matrix * pC, int StartRow, int StartCol,
                              Matrix * pX1, Matrix * pY1,
                              Matrix * pX2, Matrix * pY2)
{
    Matrix * pP1 = QUARE_MATRIX_MULTIPLY_RECURSIVE(pX1, pY1);
    Matrix * pP2 = QUARE_MATRIX_MULTIPLY_RECURSIVE(pX2, pY2);
    int rc = -1;
    int i, j;

    if (pP1 != NULL && pP2 != NULL)
    {
        // Both products have the dimensions of the destination quadrant.
        for (i = 0; i < pP1->Rows; i++)
        {
            for (j = 0; j < pP1->Cols; j++)
            {
                int sum = GetMaxtrix(pP1, i, j) + GetMaxtrix(pP2, i, j);
                SetMaxtrix(pC, StartRow + i, StartCol + j, sum);
            }
        }
        rc = 0;
    }
    RmmRelease(pP1);
    RmmRelease(pP2);
    return rc;
}

/*
 * Divide-and-conquer matrix product C = A * B (CLRS 4.2, equation 4.9),
 * generalized to rectangular operands of any size.
 *
 * pA is Rows x Inner and pB must be Inner x Cols. Each dimension d is split
 * into parts of size d/2 and d - d/2. The largest dimension therefore shrinks
 * on every recursion until all three dimensions reach 1 or one reaches 0.
 *
 * Returns a newly allocated pA->Rows x pB->Cols matrix, or NULL if an argument
 * is NULL, the inner dimensions disagree, or an allocation fails.
 *
 * Fixes over the earlier version:
 *  - The base cases tested the halved sizes. A 1x1 or 1xN product returned an
 *    uninitialized result, and a 2x2 product filled only C[0,0].
 *  - A and B were partitioned using the other operand's dimensions.
 *
 * Sums are accumulated in int; very large entries can overflow (UB).
 */
Matrix * QUARE_MATRIX_MULTIPLY_RECURSIVE(Matrix * pA, Matrix * pB)
{
    Matrix * pC = NULL;
    Matrix * pResult = NULL;
    Matrix * pA11 = NULL, * pA12 = NULL, * pA21 = NULL, * pA22 = NULL;
    Matrix * pB11 = NULL, * pB12 = NULL, * pB21 = NULL, * pB22 = NULL;
    int Rows, Inner, Cols;
    int RowMid, InnerMid, ColMid;
    int i, j;

    if (pA == NULL || pB == NULL || pA->Cols != pB->Rows)
    {
        return NULL;
    }
    Rows = pA->Rows;
    Inner = pA->Cols;
    Cols = pB->Cols;

    pC = CreateMatrix(Rows, Cols);
    if (pC == NULL)
    {
        return NULL;
    }

    if (Rows == 0 || Cols == 0)
    {
        // Empty result: there is nothing to fill.
        return pC;
    }
    if (Inner == 0)
    {
        // Every entry is an empty sum.
        for (i = 0; i < Rows; i++)
        {
            for (j = 0; j < Cols; j++)
            {
                SetMaxtrix(pC, i, j, 0);
            }
        }
        return pC;
    }
    if (Rows == 1 && Inner == 1 && Cols == 1)
    {
        SetMaxtrix(pC, 0, 0, GetMaxtrix(pA, 0, 0) * GetMaxtrix(pB, 0, 0));
        return pC;
    }

    // Partition A (Rows x Inner) and B (Inner x Cols) into quadrants.
    RowMid = Rows / 2;
    InnerMid = Inner / 2;
    ColMid = Cols / 2;

    pA11 = RmmExtractBlock(pA, 0, 0, RowMid, InnerMid);
    pA12 = RmmExtractBlock(pA, 0, InnerMid, RowMid, Inner - InnerMid);
    pA21 = RmmExtractBlock(pA, RowMid, 0, Rows - RowMid, InnerMid);
    pA22 = RmmExtractBlock(pA, RowMid, InnerMid, Rows - RowMid, Inner - InnerMid);
    pB11 = RmmExtractBlock(pB, 0, 0, InnerMid, ColMid);
    pB12 = RmmExtractBlock(pB, 0, ColMid, InnerMid, Cols - ColMid);
    pB21 = RmmExtractBlock(pB, InnerMid, 0, Inner - InnerMid, ColMid);
    pB22 = RmmExtractBlock(pB, InnerMid, ColMid, Inner - InnerMid, Cols - ColMid);
    if (pA11 == NULL || pA12 == NULL || pA21 == NULL || pA22 == NULL ||
        pB11 == NULL || pB12 == NULL || pB21 == NULL || pB22 == NULL)
    {
        goto cleanup;
    }

    // C11 = A11*B11 + A12*B21,  C12 = A11*B12 + A12*B22,
    // C21 = A21*B11 + A22*B21,  C22 = A21*B12 + A22*B22.
    // The quadrants of C are disjoint, so each is written in place.
    if (RmmMultiplyAddInto(pC, 0, 0, pA11, pB11, pA12, pB21) != 0 ||
        RmmMultiplyAddInto(pC, 0, ColMid, pA11, pB12, pA12, pB22) != 0 ||
        RmmMultiplyAddInto(pC, RowMid, 0, pA21, pB11, pA22, pB21) != 0 ||
        RmmMultiplyAddInto(pC, RowMid, ColMid, pA21, pB12, pA22, pB22) != 0)
    {
        goto cleanup;
    }

    pResult = pC;
    pC = NULL;  // ownership passes to the caller

cleanup:
    RmmRelease(pA11);
    RmmRelease(pA12);
    RmmRelease(pA21);
    RmmRelease(pA22);
    RmmRelease(pB11);
    RmmRelease(pB12);
    RmmRelease(pB21);
    RmmRelease(pB22);
    RmmRelease(pC);
    return pResult;
}