spotify / luigi

Luigi is a Python module that helps you build complex pipelines of batch jobs. It handles dependency resolution, workflow management, visualization etc. It also comes with Hadoop support built in.
Apache License 2.0

int values to FloatParameter crashes worker #3243

Open krabo0om opened 1 year ago

krabo0om commented 1 year ago

Summary

Passing an int and a float with the same numeric value (for example 5 and 5.0) to a FloatParameter produces different task_ids. This, in turn, crashes the assistant worker.

Detailed description

When an assistant worker receives a new task in _get_work, the task is put into _scheduled_tasks using the server-provided task_id. However, for the rest of the code, the locally computed task_id is used, e.g., for the _running_tasks queue. This locally computed task_id is used to access _scheduled_tasks, which results in a KeyError.
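The failure mode described above can be sketched with a plain dictionary. This is only a toy model: `scheduled_tasks`, `server_task_id`, `local_task_id`, and both id strings are hypothetical placeholders, not luigi's actual worker code or real task_ids.

```python
# Toy model of the mismatch; these names and ids are hypothetical,
# not luigi's real internals.
scheduled_tasks = {}

server_task_id = "FloatTask_5_aaaaaaaaaa"    # id as handed out by the scheduler
local_task_id = "FloatTask_5_0_bbbbbbbbbb"   # id as recomputed by the worker

# The worker stores the task under the server-provided id ...
scheduled_tasks[server_task_id] = "task object"

# ... but later looks it up under the id it computed locally.
try:
    scheduled_tasks[local_task_id]
    lookup_failed = False
except KeyError:
    lookup_failed = True

print(lookup_failed)  # True
```

As long as both ids agree, the lookup succeeds; the crash only appears once the two computations diverge.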

The source of the problem is in the way the task_id is computed and/or how tasks are compared. Tasks with a FloatParameter of 5 or 5.0 are considered equal, but their task_ids are not, because the parameters are converted to str before the task_id is computed: 5 == 5.0, but str(5) != str(5.0).
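A minimal sketch of why stringified parameters break the equality, assuming an id that is derived from a hash of the serialized parameters. `toy_task_id` is a hypothetical stand-in written for this illustration and is not luigi's own task_id function.

```python
import hashlib
import json


def toy_task_id(family, params):
    """Hypothetical id builder: hash the str() form of every parameter."""
    as_str = {name: str(value) for name, value in params.items()}
    payload = json.dumps(as_str, sort_keys=True).encode("utf-8")
    return f"{family}_{hashlib.sha256(payload).hexdigest()[:10]}"


id_int = toy_task_id("FloatTask", {"a": 5})
id_float = toy_task_id("FloatTask", {"a": 5.0})

print(5 == 5.0)            # True: the values compare equal
print(str(5) == str(5.0))  # False: '5' vs '5.0'
print(id_int == id_float)  # False: the ids diverge

# Coercing to float before stringifying makes both ids agree.
id_coerced = toy_task_id("FloatTask", {"a": float(5)})
print(id_coerced == id_float)  # True
```

The last two lines show the idea behind coercing the value to float before serialization: once both sides stringify the same type, the ids match again.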

Example and Workaround

The example code below reproduces the bug. Setting the default to 5.0 and the instance value to 5 causes the same crash. I am not familiar enough with the luigi code to propose a good fix. My current workaround is the FixedFloatParameter shown below, which makes the task robust against int values.

import multiprocessing
import time

import luigi.server
import luigi.task_register

port = 11111
address = '127.0.0.1'

class FixedFloatParameter(luigi.FloatParameter):
    """Work around the bug by coercing the value to float before serializing."""
    def serialize(self, x):
        return super().serialize(float(x))

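# Swap which of the two assignments below is commented out to toggle
# between the bug (luigi.FloatParameter) and the workaround (FixedFloatParameter).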
float_cls = luigi.FloatParameter
# float_cls = FixedFloatParameter

class FloatTask(luigi.Task):
    a = float_cls(default=5)

def create_problem():
    time.sleep(1)
    luigi.build([FloatTask()], scheduler_url=f'http://{address}:{port}', workers=0)

def assistant():
    time.sleep(2)
    luigi.build([luigi.Task()], assistant=True, scheduler_url=f'http://{address}:{port}', workers=1)

if __name__ == '__main__':
    # showcase the problem: equal values yield different task_ids
    print('default 5: ' + FloatTask().task_id)
    print('set 5.0:   ' + FloatTask(5.0).task_id)
    print('clear register -> recreate task instances')
    luigi.task_register.Register.clear_instance_cache()
    print('set 5.0:   ' + FloatTask(5.0).task_id)

    luigi.task_register.Register.clear_instance_cache()
    # schedule the problematic task
    multiprocessing.Process(target=create_problem).start()
    # run an assistant
    multiprocessing.Process(target=assistant).start()

    luigi.server.run(api_port=port, address=address)