m3dev / gokart

Gokart solves reproducibility, task dependencies, constraints of good code, and ease of use for Machine Learning Pipeline.
https://gokart.readthedocs.io/en/latest/
MIT License
318 stars, 57 forks

Dict with length 1 loses its key after self.load() #293

Closed naokimaejima closed 1 month ago

naokimaejima commented 2 years ago

Thanks for this useful module. Let me report a strange behavior.

One of my tasks dumps a dict of length 1 (see InputTask in the example below). When I load the intermediate .pkl file dumped by that task directly, it is intact. But when I read it with self.load() in another task, the key is lost.

[pipeline example]

import gokart


class InputTask(gokart.TaskOnKart):

    def run(self):
        a_dict = {"a":1}
        # a_dict = {"a":1, "b":2}
        self.dump(a_dict)

    def output(self):
        return self.make_target("from_input_task.pkl")

class OutputTask(gokart.TaskOnKart):

    input_task = gokart.TaskInstanceParameter()

    def requires(self):
        return self.input_task

    def run(self):
        a_dict = self.load()
        self.dump(a_dict)

    def output(self):
        return self.make_target("from_output_task.pkl")

class MainTask(gokart.TaskOnKart):

    def requires(self):
        return OutputTask(input_task=InputTask())

When I inspect the pickle files, the dict key is missing from the second task's output, as shown below.

>> python -m pickle resources/from_input_task_40b1decdc02b4e434a354c04860ed95a.pkl 
{'a': 1}
>> python -m pickle resources/from_output_task_5ece2af2be7b8ef168b704b0dd233858.pkl
1
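
The same kind of check can be scripted instead of going through python -m pickle. This is a minimal sketch using only the standard library: it writes a one-key dict to a temporary pickle file and reads it back, which is what an intact dump should look like. The helper name read_pickle and the un-hashed file name are assumptions for illustration; real gokart outputs carry a hash suffix as in the commands above.

```python
import pickle
import tempfile
from pathlib import Path


def read_pickle(path):
    """Load and return the object stored in a pickle file."""
    with open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical file name; real gokart outputs carry a hash suffix.
    path = Path(tmp) / "from_input_task.pkl"
    with open(path, "wb") as f:
        pickle.dump({"a": 1}, f)
    restored = read_pickle(path)

print(restored)
```

Pointing read_pickle at both files under resources/ and comparing the results would show whether the dict survived each task.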

As far as I can tell, this happens only for dicts of length 1. I suspect the code below is the cause. I would appreciate it if you could check it.

https://github.com/m3dev/gokart/blob/c5376e64d1787130c072e3cbda2b8b6e2620e8e3/gokart/task.py#L235-L236

nersonu commented 2 years ago

@hirosassa @Hi-king When parsing JSON, it is often the case that the dict root key has a length of 1. It seems to me that if you want to include a process to distinguish between a dict that is the result of a task and a dict specified in requires, you would have to change luigi. Is it conceivable to eliminate the L235-236 and L265-266 process?

mski-iksm commented 2 years ago

@naokimaejima @nersonu

Thank you for raising this interesting issue.

Is it conceivable to eliminate the L235-236 and L265-266 process?

I agree with your idea.

As @naokimaejima mentioned, this behavior is caused by the following code.

if target is None and isinstance(data, dict) and len(data) == 1:
    return list(data.values())[0]

This was implemented to make loading easier when using the default requires().
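
To see why that shortcut collides with real payloads, here is a minimal sketch in plain Python. load_with_shortcut is a hypothetical stand-in, not gokart's actual load(); it only applies the condition quoted above to data that has already been loaded.

```python
def load_with_shortcut(data, target=None):
    """Hypothetical stand-in: apply the quoted single-key shortcut to loaded data."""
    if target is None and isinstance(data, dict) and len(data) == 1:
        return list(data.values())[0]
    return data


# A payload that happens to be a one-key dict is unwrapped to its value.
single = load_with_shortcut({"a": 1})
# A two-key payload does not match the condition and comes back unchanged.
double = load_with_shortcut({"a": 1, "b": 2})

print(single)
print(double)
```

The condition looks only at the shape of the data, so it cannot tell a dict built from requires() apart from a dict that a task dumped. That matches the report above, where the one-key dict loses its key and the two-key variant does not.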

class A(gokart.TaskOnKart):
    task = gokart.TaskInstanceParameter()

    def run(self):
        a = self.load()

In the above example, self.load() is possible because of the default requires() inherited from TaskOnKart. https://github.com/m3dev/gokart/blob/master/gokart/task.py#L92

The default requires() returns a dict of the task's gokart.TaskInstanceParameter() values. Without the snippet above, loading a single declared task would therefore have to be written as list(self.load().values())[0]. With the snippet, when only one task is declared, the single value is returned instead of the dict, so a plain self.load() works.
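
As a rough illustration of loading without the shortcut: the loaded value would stay keyed by parameter name, and the caller would unwrap it explicitly. The dict below is hand-built to imitate that shape; the key name task and the payload are assumptions for illustration, not output from gokart.

```python
# Hand-built imitation of what load() would return for one required task
# when no shortcut is applied: {parameter name: that task's payload}.
loaded = {"task": {"a": 1}}

# Explicit unwrapping keeps the payload, including its single key, intact.
payload = next(iter(loaded.values()))

# Iterating the dict itself yields keys, so this gives the parameter name.
first_key = list(loaded)[0]

print(payload)
print(first_key)
```

Unwrapping through .values() is the part that matters here, because indexing a list built from the dict itself returns the parameter name rather than the loaded data.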

However, I think the benefit of this implementation is very limited, and the confusion it causes is a more significant disadvantage. So I think removing it is a good idea.

PR is welcome!

hirosassa commented 1 month ago

@naokimaejima Hey! Sorry to have kept you waiting for so long. @mski-iksm fixed this issue and the fix will be released shortly. Thank you for your suggestion!