davidwilby / deepsensor

A Python package for tackling diverse environmental prediction tasks with NPs.
https://tom-andersson.github.io/deepsensor/
MIT License

Potential bug: TaskLoader does not run when patching and using the 'Gapfill' sampling strategy #9

Open MartinSJRogers opened 1 week ago

MartinSJRogers commented 1 week ago

Description

The `TaskLoader` seems to freeze when asked to generate patched tasks using the gapfill sampling strategy.

Reproduction steps

This first example runs fine:

```python
from tqdm import tqdm
from deepsensor.data.loader import TaskLoader

# modis_ds, amsr_ds, dates and progress are defined earlier in the script

# Instantiate task loader: MODIS and AMSR as context at five time offsets
task_loader = TaskLoader(
    context=[modis_ds, amsr_ds] * 5,
    target=modis_ds,
    context_delta_t=[-2, -2, -1, -1, 0, 0, 1, 1, 2, 2],
    target_delta_t=0,
    # links=[(4, 0)],
)

train_tasks = []
for date in tqdm(dates, disable=not progress):
    tasks_per_date = task_loader(
        date,
        context_sampling=["all", "all", "all", "all", "all",
                          "all", "all", "all", "all", "all"],
        target_sampling="all",
        patch_strategy="sliding",
        patch_size=0.5,
        stride=0.25,
    )
    for task in tasks_per_date:
        task.remove_context_nans().remove_target_nans()
    train_tasks.extend(tasks_per_date)
```
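The examples request sliding-window patches with `patch_size=0.5` and `stride=0.25`. The sketch below is only an illustration of what those two numbers imply, under assumptions not stated in the issue: spatial coordinates normalised to [0, 1] on both axes, and windows stepped from the lower edge by `stride` until they no longer fit. The helpers `sliding_patch_starts` and `sliding_patches` are hypothetical and are not the fork's patching code, which may handle the domain edges differently.

```python
def sliding_patch_starts(patch_size, stride, lo=0.0, hi=1.0):
    """Start coordinates of windows of width patch_size, stepped by stride,
    along one axis spanning [lo, hi]; each window must fit inside the axis.
    (Hypothetical helper, assuming normalised coordinates.)"""
    if stride <= 0:
        raise ValueError("stride must be positive")
    starts = []
    x = lo
    while x + patch_size <= hi + 1e-9:  # small tolerance for float rounding
        starts.append(x)
        x += stride
    return starts


def sliding_patches(patch_size, stride):
    """All 2-D patch bounding boxes as ((x1_min, x1_max), (x2_min, x2_max))."""
    starts = sliding_patch_starts(patch_size, stride)
    return [
        ((x1, x1 + patch_size), (x2, x2 + patch_size))
        for x1 in starts
        for x2 in starts
    ]


patches = sliding_patches(patch_size=0.5, stride=0.25)
print(len(patches), "patches per date")
print(patches[0], patches[-1])
```

Under these assumptions there are three window positions per axis and nine patches per date, which is consistent with the report's loop treating each `task_loader(...)` call as returning a list of tasks.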

But if you run this example, nothing happens:

```python
from tqdm import tqdm
from deepsensor.data.loader import TaskLoader

# modis_ds, amsr_ds, dates and progress are defined earlier in the script

# Instantiate task loader, this time linking context set 4 to target set 0
task_loader = TaskLoader(
    context=[modis_ds, amsr_ds] * 5,
    target=modis_ds,
    context_delta_t=[-2, -2, -1, -1, 0, 0, 1, 1, 2, 2],
    target_delta_t=0,
    links=[(4, 0)],
)

train_tasks = []
# Execution reaches this loop and then hangs:
for date in tqdm(dates, disable=not progress):
    tasks_per_date = task_loader(
        date,
        context_sampling=["all", "all", "all", "all", "gapfill",
                          "all", "all", "all", "all", "all"],
        target_sampling="gapfill",
        patch_strategy="sliding",
        patch_size=0.5,
        stride=0.25,
    )
    for task in tasks_per_date:
        task.remove_context_nans().remove_target_nans()
    train_tasks.extend(tasks_per_date)
```
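Apart from the sampling strategies, the second example differs from the first by `links=[(4, 0)]`. To check which sets that link pairs up, here is a plain-Python sketch using dataset names as strings purely for illustration; `describe_links` is a hypothetical helper, not part of DeepSensor. The report passes the target settings as single values (`target_delta_t=0`, `target_sampling="gapfill"`); the sketch writes them as one-element lists, one entry per target set.

```python
def describe_links(context, context_delta_t, context_sampling,
                   target, target_delta_t, target_sampling, links):
    """For each (context_idx, target_idx) link, pair the settings of the
    linked context set with those of the linked target set."""
    described = []
    for c_idx, t_idx in links:
        described.append((
            (context[c_idx], context_delta_t[c_idx], context_sampling[c_idx]),
            (target[t_idx], target_delta_t[t_idx], target_sampling[t_idx]),
        ))
    return described


# Settings from the second example, with dataset names as strings
context = ["modis_ds", "amsr_ds"] * 5
context_delta_t = [-2, -2, -1, -1, 0, 0, 1, 1, 2, 2]
context_sampling = ["all"] * 4 + ["gapfill"] + ["all"] * 5

for pair in describe_links(context, context_delta_t, context_sampling,
                           ["modis_ds"], [0], ["gapfill"], [(4, 0)]):
    print(pair)  # -> (('modis_ds', 0, 'gapfill'), ('modis_ds', 0, 'gapfill'))
```

So the link connects the MODIS context set at `delta_t = 0` (the one entry sampled with "gapfill") to the MODIS target at `delta_t = 0`.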

Version

Patchwise_train fork, monotonic errors branch.


OS

Windows

davidwilby commented 1 week ago

Could you send me your code (privately) so I can see whether this relates to the datasets used? I haven't been able to reproduce this behaviour so far, though I expect I'm using a different dataset.

Have you run this code with a debugger? I'd like to check: 1. whether it is really hanging indefinitely or just taking a long time, and, if it is hanging, 2. where the freezing happens. Is it actually in `TaskLoader.__call__`?

MartinSJRogers commented 6 days ago

Thanks @davidwilby, I'll send a private gist now. I haven't run this code with a debugger; I've just scattered my trusty print statements all over the place again. They show that the code hangs indefinitely (or at least takes a very long time) when generating tasks, but I guess a debugger is needed to identify where in the DeepSensor codebase it is hanging? In the code I sent you, everything runs fine if patching is set to False.
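On the questions of whether the call is really hung and where: one lightweight alternative to a full debugger (not something used in this thread) is Python's standard-library `faulthandler` module, whose `dump_traceback_later` prints the stacks of all threads to stderr after a timeout, optionally repeating. The helper `run_with_stack_dumps` below is hypothetical. Wrapping a single `task_loader(...)` call in it would show whether successive dumps keep pointing at the same lines inside DeepSensor, and if the call does eventually return, it reports how long it took.

```python
import faulthandler
import sys
import time


def run_with_stack_dumps(fn, *args, timeout=60.0, file=sys.stderr, **kwargs):
    """Call fn(*args, **kwargs), dumping every thread's stack to `file`
    every `timeout` seconds until the call returns or raises."""
    # Arm a watchdog thread that prints tracebacks repeatedly until cancelled
    faulthandler.dump_traceback_later(timeout, repeat=True, file=file)
    start = time.perf_counter()
    try:
        return fn(*args, **kwargs)
    finally:
        # Disarm the watchdog whether fn returned or raised
        faulthandler.cancel_dump_traceback_later()
        name = getattr(fn, "__name__", type(fn).__name__)
        elapsed = time.perf_counter() - start
        print(f"{name} finished after {elapsed:.1f} s", file=file)


# Hypothetical usage with the objects from the report (not run here):
# tasks_per_date = run_with_stack_dumps(
#     task_loader, date,
#     context_sampling=["all", "all", "all", "all", "gapfill",
#                       "all", "all", "all", "all", "all"],
#     target_sampling="gapfill",
#     patch_strategy="sliding", patch_size=0.5, stride=0.25,
#     timeout=30.0,
# )
```

Because the watchdog runs in its own thread, this works even when the main thread is stuck in a pure-Python loop, which is the kind of hang described in the rest of the thread.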

davidwilby commented 1 hour ago

Running your example, @MartinSJRogers, I've tracked the hanging behaviour down to this section of `deepsensor.data.TaskLoader.task_generation`:

https://github.com/davidwilby/deepsensor/blob/23733df8c0967f33bfabb056d7940df3e76a3f17/deepsensor/data/loader.py#L1367-L1395

I think that in this section `keep_searching` is never set to `False` for some or all of your patches (i.e. every `added_mask` that is tried leaves `target_mask` empty), so this `while` loop just runs indefinitely.

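```python
# Excerpt from the gapfill branch of TaskLoader.task_generation (loader.py),
# inside the while loop that searches for a usable `added_mask`
target_mask = added_mask & ~curr_mask
if isinstance(target_var, xr.Dataset):
    keep_searching = np.all(target_mask.to_array().data == False)
else:
    keep_searching = np.all(target_mask.data == False)
if keep_searching:
    continue  # No target points -- use a different `added_mask`
```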
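To make the failure mode concrete, here is a minimal, self-contained NumPy sketch of the same search pattern. It is not DeepSensor's code: `sample_gapfill_target` and its `max_attempts` guard are hypothetical, the masks are plain boolean arrays rather than xarray objects, and `candidate_masks` stands in for the missing-value masks of randomly chosen dates. If no candidate hides an observed value inside the patch (because the patch is already entirely missing, or because the candidate masks only cover pixels that are missing anyway), `target_mask` is always empty and the loop can only exit through the guard.

```python
import numpy as np


def sample_gapfill_target(curr_missing, candidate_masks, rng, max_attempts=None):
    """Hypothetical sketch of a gapfill-style search for target points.

    curr_missing: bool array, True where the context patch is already missing.
    candidate_masks: list of bool arrays, True where a candidate date's field
        is missing (plays the role of `added_mask`).
    Returns (added_mask, target_mask) with at least one target point.
    """
    attempts = 0
    keep_searching = True
    while keep_searching:
        # Guard not present in the excerpt: fail loudly instead of looping forever
        if max_attempts is not None and attempts >= max_attempts:
            raise RuntimeError(
                f"No gapfill target points found after {attempts} attempts"
            )
        attempts += 1
        added_mask = candidate_masks[rng.integers(len(candidate_masks))]
        # Targets are values hidden by added_mask that were observed before
        target_mask = added_mask & ~curr_missing
        # As in the excerpt: keep searching only while there are no targets
        keep_searching = not target_mask.any()
    return added_mask, target_mask


rng = np.random.default_rng(0)
candidates = [
    np.array([[True, False], [False, False]]),
    np.array([[False, False], [False, True]]),
]

# Top row observed: one candidate hides an observed value, so the search ends.
patch_missing = np.array([[False, False], [True, True]])
_, target = sample_gapfill_target(patch_missing, candidates, rng)
print(target)

# Whole patch already missing: target_mask is always empty, so only the
# max_attempts guard stops the loop (without it, this call would never return).
all_missing = np.ones((2, 2), dtype=bool)
try:
    sample_gapfill_target(all_missing, candidates, rng, max_attempts=100)
except RuntimeError as err:
    print(err)
```

A cap like `max_attempts` turns a silent hang into an explicit error; whether the fork should raise, skip such patches, or choose masks differently is a separate design decision.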