openucx / ucc

Unified Collective Communication Library
https://openucx.github.io/ucc/
BSD 3-Clause "New" or "Revised" License
177 stars 85 forks source link

add new alltoallv alg field #983

Closed pangjuan8848 closed 1 week ago

pangjuan8848 commented 3 weeks ago

I previously read a paper that proposed the bruck2phase algorithm for alltoallv communication. In this paper, it mentioned that the bruck2phase algorithm outperforms the SLOVX algorithm, which is the ucp_alltoallv_hybrid algorithm in UCC. Therefore, I want to port it to UCC and compare the performance. However, when running the ucc_perftest program for benchmarking, the program got stuck and did not return any latency. Why is this happening?

paper: https://dl.acm.org/doi/10.1145/3502181.3531468

Here is the source code: `static void ucc_tl_ucp_alltoallv_bruck2phase_progress(ucc_coll_task_t coll_task) { ucc_tl_ucp_task_t task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); ucc_tl_ucp_team_t team = TASK_TEAM(task); ucc_rank_t grank = UCC_TL_TEAM_RANK(team); ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); ptrdiff_t sbuf = (ptrdiff_t)TASK_ARGS(task).src.info_v.buffer; ptrdiff_t rbuf = (ptrdiff_t)TASK_ARGS(task).dst.info_v.buffer; ucc_memory_type_t smem = TASK_ARGS(task).src.info_v.mem_type; ucc_memory_type_t rmem = TASK_ARGS(task).dst.info_v.mem_type; size_t sdt_size = ucc_dt_size(TASK_ARGS(task).src.info_v.datatype); size_t rdt_size = ucc_dt_size(TASK_ARGS(task).dst.info_v.datatype); int s_disps = (int)TASK_ARGS(task).src.info_v.displacements; int r_disps = (int)TASK_ARGS(task).dst.info_v.displacements; int scounts = (int)TASK_ARGS(task).src.info_v.counts; int rcounts = (int*)TASK_ARGS(task).dst.info_v.counts;

void *extra_buffer=task->alltoallv_bruck2phase.extra_buffer->addr;
void *temp_send_buffer=task->alltoallv_bruck2phase.temp_send_buffer->addr;
void *temp_recv_buffer=task->alltoallv_bruck2phase.temp_recv_buffer->addr;

//1.find max send count
int max_send_count=task->alltoallv_bruck2phase.max_send_count;
int max_send_elements=task->alltoallv_bruck2phase.max_send_elements;

//2.create local index array
int index_array[gsize];
for(int i=0; i<gsize; i++){
     index_array[i] = (2*grank-i+gsize)%gsize;
}

//3.exchange data with log(P) steps

int pos_status[gsize];
memset(pos_status, 0, gsize*sizeof(int));

memcpy(PTR_OFFSET(rbuf, r_disps[grank]*rdt_size),
                       PTR_OFFSET(sbuf, s_disps[grank]*sdt_size),
                       rcounts[grank]*rdt_size);

for (int k = 1; k < gsize; k <<= 1) {
    // 1) find which data blocks to send
    int send_indexes[max_send_elements];
    int sendb_num = 0;
    for (int i = k; i < gsize; i++) {
        if (i & k)
            send_indexes[sendb_num++] = (grank+i)%gsize;
    }

    // 2) prepare metadata and send buffer
    int metadata_send[sendb_num];
    int sendCount = 0;
    int offset = 0;
    for (int i = 0; i < sendb_num; i++) {
        int send_index = index_array[send_indexes[i]];
        metadata_send[i] = scounts[send_index];
        if (pos_status[send_index] == 0){
            memcpy(PTR_OFFSET(temp_send_buffer, offset),
                       PTR_OFFSET(sbuf, s_disps[send_index]*sdt_size),
                       scounts[send_index]*sdt_size);
        }
        else
        {
            memcpy(PTR_OFFSET(temp_send_buffer, offset),
                       PTR_OFFSET(extra_buffer, send_indexes[i]*max_send_count*sdt_size),
                       scounts[send_index]*sdt_size);
        }
        offset += scounts[send_index]*sdt_size;
    }

    // 3) exchange metadata
    int sendrank = (grank - k + gsize) % gsize;
    int recvrank = (grank + k) % gsize;
    int metadata_recv[sendb_num];

    UCPCHECK_GOTO(ucc_tl_ucp_send_nb((void *)(metadata_send),sendb_num*sdt_size, smem, sendrank, team, task),task, out);

    UCPCHECK_GOTO(ucc_tl_ucp_recv_nb((void *)(metadata_recv),sendb_num*rdt_size, rmem, recvrank, team, task),task, out);
    for(int i = 0; i < sendb_num; i++)
        sendCount += metadata_recv[i];

    // 4) exchange data

    UCPCHECK_GOTO(ucc_tl_ucp_send_nb((void *)(temp_send_buffer),offset, smem, sendrank, team, task),task, out);

    UCPCHECK_GOTO(ucc_tl_ucp_recv_nb((void *)(temp_recv_buffer),sendCount*rdt_size, rmem, recvrank, team, task),task, out);

    // 5) replace
    offset = 0;
    for (int i = 0; i < sendb_num; i++) {
        int send_index = index_array[send_indexes[i]];

        if ((send_indexes[i] - grank + gsize) % gsize < (k << 1))
        {
            memcpy(PTR_OFFSET(rbuf, r_disps[send_indexes[i]]*sdt_size),
                       PTR_OFFSET(temp_recv_buffer,offset),
                       metadata_recv[i]*rdt_size);
        }
        else
        {
            memcpy(PTR_OFFSET(extra_buffer, send_indexes[i]*max_send_count*sdt_size),
                       PTR_OFFSET(temp_recv_buffer,offset),
                       metadata_recv[i]*rdt_size);
        }

        offset += metadata_recv[i]*rdt_size;
        pos_status[send_index] = 1;
        scounts[send_index] = metadata_recv[i];
    }
}

task->super.status = UCC_OK;   

out: return; }`