SymbioticLab / FedScale

FedScale is a scalable and extensible open-source federated learning (FL) platform.
https://fedscale.ai
Apache License 2.0

[Async simulation] Implementation idea for task scheduling #174

Status: Open. ewenw opened this issue 2 years ago.

ewenw commented 2 years ago

Description

Hi FedScale team, here's my suggestion on how to implement the async simulation mode using device traces without needing a constant arrival parameter (related to #162):

sort device traces by start time
queue = initialize min priority queue   # events: (event_time, event_type, client_id)
issue_task(0)                           # seed the simulation with the first task
while tasks_issued < buffer_size:
    event_time, event_type, client_id = queue.get()
    if event_type == 'start':
        current_concurrency += 1
        if current_concurrency < MAX_CONCURRENCY:   # still below the limit
            issue_task(event_time)
    else:
        current_concurrency -= 1
        if current_concurrency == MAX_CONCURRENCY - 1:   # a slot just freed up
            issue_task(event_time)

def issue_task(event_time):
    client, trace_start, trace_end = sample next available client at event_time
    add client task to individual executor's queue
    queue.put((trace_start, 'start', client))
    queue.put((trace_end, 'end', client))
    tasks_issued += 1
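
The "sample next available client at event_time" step is where the device traces come in. Below is a minimal sketch, not FedScale code, assuming each trace row is a hypothetical `(start, end, client_id)` availability window, that each window is handed out at most once, and that only windows starting at or after `event_time` qualify (so simulated time never runs backwards). `TraceSampler` and `next_available` are illustrative names.

```python
import bisect
from typing import List, Optional, Tuple

# Hypothetical trace row: (start, end, client_id), one availability window per row.
Trace = Tuple[int, int, int]


class TraceSampler:
    """Returns the earliest unused availability window starting at or after a given time."""

    def __init__(self, traces: List[Trace]) -> None:
        # "sort device traces by start time", done once up front
        self.traces = sorted(traces)
        self.starts = [t[0] for t in self.traces]
        self.cursor = 0  # windows before this index were already used or skipped

    def next_available(self, event_time: int) -> Optional[Trace]:
        # Binary search for the first window whose start is >= event_time.
        idx = bisect.bisect_left(self.starts, event_time, lo=self.cursor)
        if idx >= len(self.traces):
            return None  # no windows left in the trace
        self.cursor = idx + 1
        return self.traces[idx]


if __name__ == "__main__":
    sampler = TraceSampler([(5, 9, 2), (0, 4, 1), (3, 6, 3)])
    print(sampler.next_available(0))  # (0, 4, 1)
    print(sampler.next_available(1))  # (3, 6, 3)
    print(sampler.next_available(4))  # (5, 9, 2)
    print(sampler.next_available(6))  # None
```

Because events are popped in time order, `event_time` never decreases between calls, so windows that started earlier can be skipped permanently and each lookup stays a single `bisect_left` over the remaining suffix.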

This works well in my own implementation, but it may be harder to integrate into FedScale, so I'm opening this issue to document it. Let me know if you have any questions or concerns.

Below is Python code for this scheduling algorithm; feel free to run it and check the output:

import random
from queue import PriorityQueue

TOTAL_TASKS = 10       # stop issuing once this many tasks have been issued
MAX_CONCURRENCY = 1    # maximum number of clients running at the same time

next_client_id = 0
tasks_issued = 0
current_concurrency = 0
start_times = {}           # client_id -> simulated start time
min_pq = PriorityQueue()   # (time, type, client_id); on equal times 'end' sorts before 'start'


def generate_start_end(event_time):
    # Stand-in for "next available client": random start delay and duration
    global next_client_id
    start_time = event_time + random.randint(0, 1)
    duration = random.randint(1, 3)
    next_client_id += 1
    return start_time, start_time + duration, next_client_id


def new_task(event_time):
    global tasks_issued
    new_start, new_end, client_id = generate_start_end(event_time)
    min_pq.put((new_start, 'start', client_id))
    min_pq.put((new_end, 'end', client_id))
    start_times[client_id] = new_start
    tasks_issued += 1


new_task(0)  # seed the simulation with the first task
while not min_pq.empty():
    event_time, event_type, client_id = min_pq.get()
    if event_type == 'start':
        current_concurrency += 1
        # Still below the limit: issue another task right away
        if tasks_issued < TOTAL_TASKS and current_concurrency < MAX_CONCURRENCY:
            new_task(event_time)
    else:
        current_concurrency -= 1
        # A slot just opened up: issue a replacement task
        if tasks_issued < TOTAL_TASKS and current_concurrency == MAX_CONCURRENCY - 1:
            new_task(event_time)
        print(f"processing event starting at {start_times[client_id]} and ending at {event_time}")
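
One way to validate the output, as invited above, is to wrap the same loop in a function and measure how many tasks ever overlap. This sketch keeps the random start/duration stand-in from the script but swaps `queue.PriorityQueue` for the standard-library `heapq` module (a plain min-heap; `PriorityQueue` adds thread-safety locking that a single-threaded simulation does not need). `simulate` and `max_overlap` are hypothetical helper names, not part of FedScale.

```python
import heapq
import random
from typing import Dict, List, Tuple

Interval = Tuple[int, int, int]  # (start, end, client_id)


def simulate(total_tasks: int, max_concurrency: int, seed: int = 0) -> List[Interval]:
    """Same event-driven scheduling as the script above, built on heapq."""
    rng = random.Random(seed)
    events: List[Tuple[int, str, int]] = []  # (time, 'start' | 'end', client_id)
    start_times: Dict[int, int] = {}
    finished: List[Interval] = []
    next_client_id = 0
    tasks_issued = 0
    concurrency = 0

    def issue_task(event_time: int) -> None:
        nonlocal next_client_id, tasks_issued
        # Stand-in for "sample next available client": random start delay and duration.
        next_client_id += 1
        start = event_time + rng.randint(0, 1)
        end = start + rng.randint(1, 3)
        heapq.heappush(events, (start, "start", next_client_id))
        heapq.heappush(events, (end, "end", next_client_id))
        start_times[next_client_id] = start
        tasks_issued += 1

    issue_task(0)
    while events:
        # On equal times "end" < "start" as strings, so freed slots are seen first.
        event_time, event_type, client_id = heapq.heappop(events)
        if event_type == "start":
            concurrency += 1
            if tasks_issued < total_tasks and concurrency < max_concurrency:
                issue_task(event_time)
        else:
            concurrency -= 1
            if tasks_issued < total_tasks and concurrency == max_concurrency - 1:
                issue_task(event_time)
            finished.append((start_times[client_id], event_time, client_id))
    return finished


def max_overlap(intervals: List[Interval]) -> int:
    """Peak number of tasks running at once, treating intervals as half-open [start, end)."""
    # Sorting (time, delta) puts -1 (end) before +1 (start) at equal times.
    deltas = sorted([(s, 1) for s, _, _ in intervals] + [(e, -1) for _, e, _ in intervals])
    running = peak = 0
    for _, delta in deltas:
        running += delta
        peak = max(peak, running)
    return peak


if __name__ == "__main__":
    for limit in (1, 2, 4):
        done = simulate(total_tasks=10, max_concurrency=limit, seed=limit)
        print(limit, len(done), max_overlap(done))
```

Treating intervals as half-open matches the scheduler's tie-breaking: a slot freed by an `'end'` event at time t can be reused by a task whose `'start'` is also at t without counting as an overlap.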

fanlai0990 commented 2 years ago

Great, thanks a lot! Amber actually pushed a similar idea yesterday in #173, which of course still needs more work. We will work on this once we have more bandwidth.