exo-lang / exo

Exocompilation for productive programming of hardware accelerators
https://exo-lang.dev
MIT License

Unknown Z3 error from remove_loop #652

Open yamaguchi1024 opened 1 month ago

yamaguchi1024 commented 1 month ago

Object code:

def matmul_on_gemmini(N: size, M: size, scale: f32 @ DRAM, act: bool @ DRAM,
                      A: i8[N, 512] @ DRAM, B: i8[512, M] @ DRAM,
                      C: i8[N, M] @ DRAM):
    assert N % 256 == 0
    assert M % 256 == 0
    config_st_acc_i8(scale, stride(C, 0), act)
    config_matmul()
    config_ld_i8_id1(stride(A, 0))
    config_ld_i8_id2(stride(B, 0))
    config_zero()
    A_tmp: i8[16, 32, 16, 16] @ GEMM_SCRATCH
    B_tmp: i8[32, 16, 16, 16] @ GEMM_SCRATCH
    res: i32[16, 16, 16] @ GEMM_ACCUM
    for ioo in seq(0, N / 256):
        for joo in seq(0, M / 256):
            for ioi in seq(0, 16):
                for joio in seq(0, 4):
                    for joii in seq(0, 4):
                        do_zero_acc_i32(16, 16, res[joii + 4 * joio, 0:16,
                                                    0:16])
    for ioo in seq(0, N / 256):
        for joo in seq(0, M / 256):
            for ioi in seq(0, 16):
                for joio in seq(0, 4):
                    for koo in seq(0, 8):
                        if joio == 0:
                            do_ld_i8_block_id1(
                                16, 4, A[16 * ioi + 256 * ioo:16 + 16 * ioi +
                                         256 * ioo, 64 * koo:64 + 64 * koo],
                                A_tmp[ioi, 4 * koo:4 + 4 * koo, 0:16, 0:16])
        for joo in seq(0, M / 256):
            for ioi in seq(0, 16):
                for koo in seq(0, 8):
                    for koi in seq(0, 4):
                        for joio in seq(0, 4):
                            if ioi == 0:
                                do_ld_i8_block_id2(
                                    16, 4,
                                    B[16 * koi + 64 * koo:16 + 16 * koi +
                                      64 * koo, 64 * joio + 256 * joo:64 +
                                      64 * joio + 256 * joo],
                                    B_tmp[koi + 4 * koo, 4 * joio:4 + 4 * joio,
                                          0:16, 0:16])
            for ioi in seq(0, 16):
                for joio in seq(0, 4):
                    for joii in seq(0, 4):
                        for koo in seq(0, 8):
                            for koi in seq(0, 4):
                                do_matmul_acc_i8(
                                    16, 16, 16, A_tmp[ioi, koi + 4 * koo, 0:16,
                                                      0:16],
                                    B_tmp[koi + 4 * koo, joii + 4 * joio, 0:16,
                                          0:16], res[joii + 4 * joio, 0:16,
                                                     0:16])
                for joio in seq(0, 4):
                    for joii in seq(0, 4):
                        do_st_acc_i8(
                            16, 16, res[joii + 4 * joio, 0:16, 0:16],
                            C[16 * ioi + 256 * ioo:16 + 16 * ioi + 256 * ioo,
                              16 * joii + 64 * joio + 256 * joo:16 +
                              16 * joii + 64 * joio + 256 * joo])

Schedule:

joo_loop = gemmini.find("for joo in _:_ #1")
gemmini = remove_loop(gemmini, joo_loop)

Error:

            if result == Z3.sat:
                is_valid = False
            elif result == Z3.unsat:
                is_valid = True
            else:
>               raise TypeError("unknown result from z3")
E               TypeError: unknown result from z3

../../src/exo/new_analysis_core.py:780: TypeError
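
As a sanity check on the expectation that this removal should be accepted, the following minimal sketch brute-forces one condition that a loop removal plausibly needs: that the `joo` loop runs at least once. It assumes `size` values are strictly positive and that `M / 256` is integer division. `trip_count` and `loop_always_runs` are hypothetical helper names, and this is not how Exo's analysis or Z3 is actually invoked.

```python
# Brute-force sanity check, NOT Exo's analysis: under the procedure's
# assertion `M % 256 == 0`, does `for joo in seq(0, M / 256)` always
# execute at least one iteration?

TILE = 256  # divisor taken from the asserts and loop bounds in the object code


def trip_count(m: int) -> int:
    """Trip count of the joo loop, assuming `/` is integer division."""
    return m // TILE


def loop_always_runs(max_m: int = 4096) -> bool:
    """Check every admissible M in [1, max_m]; return False on a counterexample."""
    for m in range(1, max_m + 1):      # assumption: `size` means M >= 1
        if m % TILE != 0:
            continue                   # excluded by `assert M % 256 == 0`
        if trip_count(m) < 1:
            return False               # the loop would be skipped entirely
    return True


if __name__ == "__main__":
    print(loop_always_runs())
```

Under those assumptions no sampled `M` gives a zero trip count, so for this part of the query an `unsat` (valid) answer would be the expected outcome. The `unknown` result may therefore come from the solver giving up on the division and modulo terms rather than from the schedule being invalid.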