WorldCereal / presto-worldcereal


Augmentation sometimes fails with `ValueError: low >= high` in `np.random.randint` #110

Open kvantricht opened 1 week ago

kvantricht commented 1 week ago

While trying to compute embeddings for training data with augmentation enabled, I get failing cases in this section. It looks like some cases are not yet covered: the derived `min_center_point` can end up greater than the derived `max_center_point`. We need to investigate when and why this happens and how to solve it.

One case where I observe this:

The derived `min_center_point` is 11, while `max_center_point` is 9.

This probably wasn't detected before because no augmentation was previously performed on labelled datasets.
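For reference, a minimal reproduction of the error, assuming the augmentation draws the window's center point with `np.random.randint` over `[min_center_point, max_center_point]` (the exact call in `get_timestep_positions` may differ):

    import numpy as np

    # Values derived for the failing sample
    min_center_point = 11
    max_center_point = 9

    # With min > max, the lower bound handed to randint exceeds the upper
    # bound, so numpy raises: ValueError: low >= high
    center_point = np.random.randint(min_center_point, max_center_point + 1)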

kvantricht commented 1 week ago

The `valid_date` of the sample is 2019-11-17, while the `end_date` of the extraction is 2019-11-30. Somehow this sample did not get an appropriate extraction window (maybe the `valid_date` here is a true observation date, and for this sample it actually falls at the end of the season). However, my impression was that we already tried to deal with these cases in `process_parquet` by adding nodata timesteps to guarantee the `MIN_EDGE_BUFFER`. @cbutsko, is this something you can check to see if it worked as intended, with the example sample above as a test case? On my end, I'll remove these faulty samples during processing in the meantime.
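To illustrate why this sample is problematic, here is a sketch under assumptions (monthly timesteps and a `MIN_EDGE_BUFFER` of 2 are placeholders for illustration, not values taken from the repo):

    import pandas as pd

    MIN_EDGE_BUFFER = 2  # placeholder value

    valid_date = pd.Timestamp("2019-11-17")
    end_date = pd.Timestamp("2019-11-30")

    # Assuming monthly timesteps, the valid_date falls in the very last
    # timestep of the extraction window.
    timesteps_after_valid = (end_date.to_period("M") - valid_date.to_period("M")).n

    if timesteps_after_valid < MIN_EDGE_BUFFER:
        # process_parquet is supposed to pad such samples with nodata
        # timesteps; if that padding fails, min_center_point can end up
        # greater than max_center_point downstream.
        print(
            f"Only {timesteps_after_valid} timesteps after valid_date, "
            f"need at least {MIN_EDGE_BUFFER}"
        )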

cbutsko commented 1 week ago

This is something that I do in the `process_parquet` function, particularly these lines: https://github.com/WorldCereal/presto-worldcereal/blob/ce3fae1bb1054ba0f8c60edb7b0a0edc76dbf3b2/presto/utils.py#L223 Samples like the one you encountered should be identified and deleted before the whole thing even goes to pivoting, with a logger message about the number of such samples. But I guess it wouldn't hurt to be able to handle such cases in `get_timestep_positions` as well.

kvantricht commented 1 week ago

> This is something that I do in the `process_parquet` function, particularly these lines:
>
> https://github.com/WorldCereal/presto-worldcereal/blob/ce3fae1bb1054ba0f8c60edb7b0a0edc76dbf3b2/presto/utils.py#L223
>
> Samples like the one you encountered should be identified and deleted before the whole thing even goes to pivoting, with a logger message about the number of such samples. But I guess it wouldn't hurt to be able to handle such cases in `get_timestep_positions` as well.

As I mentioned, I also thought this was handled in `process_parquet`. That's interesting, because your `process_parquet` is my default and I still get the error. Could you check whether the case I described indeed gets handled by the method? If it is, something strange is happening.

By the way, can we already add a check in `process_parquet` to catch this edge case in case it's still there? Something like this piece of code at https://github.com/WorldCereal/presto-worldcereal/blob/ce3fae1bb1054ba0f8c60edb7b0a0edc76dbf3b2/presto/utils.py#L328:

    # -----------------
    import numpy as np

    from presto.dataops import NUM_TIMESTEPS

    # Note: df_pivot and MIN_EDGE_BUFFER come from the surrounding
    # process_parquet scope.

    # Earliest valid center point: the window must start at timestep 0 or
    # later, and valid_position must stay at least MIN_EDGE_BUFFER timesteps
    # away from the window's end edge.
    min_center_point = np.maximum(
        NUM_TIMESTEPS // 2,
        df_pivot["valid_position"] + MIN_EDGE_BUFFER - NUM_TIMESTEPS // 2,
    )
    # Latest valid center point: the window must end within the available
    # timesteps, and valid_position must stay at least MIN_EDGE_BUFFER
    # timesteps away from the window's start edge.
    max_center_point = np.minimum(
        df_pivot["available_timesteps"] - NUM_TIMESTEPS // 2,
        df_pivot["valid_position"] - MIN_EDGE_BUFFER + NUM_TIMESTEPS // 2,
    )

    # Samples for which no center point satisfies both constraints would
    # crash np.random.randint later on; drop them here instead.
    faulty_samples = min_center_point > max_center_point
    if faulty_samples.sum() > 0:
        print(f"Dropping {faulty_samples.sum()} faulty samples")
    df_pivot = df_pivot[~faulty_samples]

    # -----------------
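With this guard in place, every surviving sample satisfies `min_center_point <= max_center_point`, so (assuming the augmentation draws the center point with `np.random.randint` over that range) the `low >= high` error should no longer be reachable.
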
cbutsko commented 1 week ago

oh gosh, I finally found this bugger 🐜 It was happening because I forgot to re-initialize the `end_date` for those samples where `valid_date` is too close to the edge and I had to add dummy timesteps. For the `start_date`, this was done here: https://github.com/WorldCereal/presto-worldcereal/blob/ce3fae1bb1054ba0f8c60edb7b0a0edc76dbf3b2/presto/utils.py#L277 If I add the same procedure for `end_date`, the `available_timesteps` attribute gets computed correctly, and for this particular sample 2019_ESP_ESYRCE_POLY_1113799 we now have 17 available timesteps and no error is thrown. I was able to make a test run through the whole dataset without errors. But I will still add the piece of code you suggested, just as a sanity check.
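For the record, a minimal sketch of the kind of fix described above (not the actual repo code; the helper name `extend_end_date` and the monthly frequency are assumptions for illustration):

    import pandas as pd

    # The bug: when dummy timesteps were appended, start_date was
    # re-initialized but end_date was not, so available_timesteps was
    # derived from a stale end_date.
    def extend_end_date(end_date: pd.Timestamp, n_dummy: int) -> pd.Timestamp:
        # Shift the window's end forward by the number of dummy timesteps
        # added (monthly timesteps assumed here for illustration).
        return end_date + pd.DateOffset(months=n_dummy)

    new_end_date = extend_end_date(pd.Timestamp("2019-11-30"), n_dummy=2)
    # available_timesteps can now be recomputed from start_date and the
    # refreshed end_date, matching the padded series length.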

UPD: I added this fix in a separate branch: https://github.com/WorldCereal/presto-worldcereal/commit/f03649e5da8bdba0a507b97791d20117ca516f9f I will convert it to a PR tomorrow after further checks of the window-subsetting behavior.