openclimatefix / PVNet

PVnet main repo
MIT License

Fix shuffling and minor tweaks #118

Closed dfulu closed 6 months ago

dfulu commented 6 months ago

Pull Request

Apparently, if you pass a datapipe that already contains shuffle steps into the DataLoader but set shuffle=False, the DataLoader disables those shuffle steps and the samples come out unshuffled. This means we've been training without the additional shuffling on each epoch.
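
A minimal sketch of what this implies for how the training DataLoader should be built (illustrative only: train_datapipe below is a stand-in datapipe, not the real PVNet datamodule code):

from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

# Stand-in for the shuffled training datapipe the datamodule builds
train_datapipe = IterableWrapper(list(range(100))).shuffle(buffer_size=100).batch(4)

train_dataloader = DataLoader(
    train_datapipe,
    shuffle=True,     # with shuffle=False the DataLoader switches the datapipe's Shufflers off
    batch_size=None,  # batching already happens inside the datapipe
)

for batch in train_dataloader:
    print(batch)  # batches should now come out shuffled each epoch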

Checklist:

dfulu commented 6 months ago

A minimal example of the above:

from lightning.pytorch import seed_everything
from torch.utils.data import DataLoader
from torch.utils.data import IterDataPipe, functional_datapipe

from torch.utils.data.datapipes.iter import IterableWrapper

seed_everything(122)

@functional_datapipe("simple_unbatch")
class unbatcher(IterDataPipe):

    def __init__(self, source_datapipe: IterDataPipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for batch in self.source_datapipe:
            for example in batch:
                yield example

# This is like our datapipe yielding batches
batch_size = 5
total_samples = 100

dp_source = IterableWrapper(
    [[i for i in range(j, j+batch_size)] for j in range(0, total_samples, batch_size)] 
)

dp_shuffled = (
    dp_source
    # Shuffle the batches
    .shuffle(buffer_size=5)
    # Shard
    .sharding_filter()
    # Unbatch to samples
    .simple_unbatch()
    # Shuffle the samples
    .shuffle(buffer_size=100)
    # Rebatch
    .batch(batch_size)
)

print(
    "The datapipe should already be shuffled but when we put it in the dataloader with "
    "shuffle=False the samples come out unshuffled:\n"
)

dl = DataLoader(
    dp_shuffled,     
    shuffle=False, 
    batch_size=None,  # batched in datapipe step
    num_workers=2,
    pin_memory=False,
    worker_init_fn=None,
    prefetch_factor=2,
    persistent_workers=False,
)

for s in dl:
    print(s)

print("\nIf shuffle=True the samples come out shuffled:\n")
dl = DataLoader(
    dp_shuffled,     
    shuffle=True, 
    batch_size=None,  # batched in datapipe step
    num_workers=2,
    pin_memory=False,
    worker_init_fn=None,
    prefetch_factor=2,
    persistent_workers=False,
)

for s in dl:
    print(s)

print("\nIf not shuffled in the datapipe the data is also not shuffled in the dataloader:\n")

dp_unshuffled = (
    dp_source
    # Shard
    .sharding_filter()
    # Unbatch to samples
    .simple_unbatch()
    # Rebatch
    .batch(batch_size)
)

dl = DataLoader(
    dp_unshuffled,     
    shuffle=True, 
    batch_size=None,  # batched in datapipe step
    num_workers=2,
    pin_memory=False,
    worker_init_fn=None,
    prefetch_factor=2,
    persistent_workers=False,
)
for s in dl:
    print(s)

This produces the output:

Seed set to 122
The datapipe should already be shuffled but when we put it in the dataloader with shuffle=False the samples come out unshuffled:

[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[10, 11, 12, 13, 14]
[15, 16, 17, 18, 19]
[20, 21, 22, 23, 24]
[25, 26, 27, 28, 29]
[30, 31, 32, 33, 34]
[35, 36, 37, 38, 39]
[40, 41, 42, 43, 44]
[45, 46, 47, 48, 49]
[50, 51, 52, 53, 54]
[55, 56, 57, 58, 59]
[60, 61, 62, 63, 64]
[65, 66, 67, 68, 69]
[70, 71, 72, 73, 74]
[75, 76, 77, 78, 79]
[80, 81, 82, 83, 84]
[85, 86, 87, 88, 89]
[90, 91, 92, 93, 94]
[95, 96, 97, 98, 99]

If shuffle=True the samples come out shuffled:

[25, 62, 64, 54, 14]
[15, 67, 69, 79, 39]
[41, 85, 52, 40, 87]
[31, 90, 77, 30, 92]
[88, 63, 27, 49, 8]
[93, 68, 17, 99, 58]
[10, 46, 28, 12, 86]
[35, 96, 18, 37, 91]
[71, 89, 60, 53, 43]
[81, 94, 65, 78, 33]
[45, 20, 70, 73, 21]
[95, 0, 80, 83, 1]
[11, 61, 29, 13, 74]
[36, 66, 19, 38, 84]
[9, 44, 26, 72, 22]
[59, 34, 16, 82, 2]
[50, 42, 7, 23, 51]
[75, 32, 57, 3, 76]
[48, 24, 47, 6, 5]
[98, 4, 97, 56, 55]

If not shuffled in the datapipe the data is also not shuffled in the dataloader:

/home/jamesfulton/mambaforge/envs/pvnet0/lib/python3.10/site-packages/torch/utils/data/graph_settings.py:103: UserWarning: `shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. Be aware that the default buffer size might not be sufficient for your task.
  warnings.warn(
[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[10, 11, 12, 13, 14]
[15, 16, 17, 18, 19]
[20, 21, 22, 23, 24]
[25, 26, 27, 28, 29]
[30, 31, 32, 33, 34]
[35, 36, 37, 38, 39]
[40, 41, 42, 43, 44]
[45, 46, 47, 48, 49]
[50, 51, 52, 53, 54]
[55, 56, 57, 58, 59]
[60, 61, 62, 63, 64]
[65, 66, 67, 68, 69]
[70, 71, 72, 73, 74]
[75, 76, 77, 78, 79]
[80, 81, 82, 83, 84]
[85, 86, 87, 88, 89]
[90, 91, 92, 93, 94]
[95, 96, 97, 98, 99]
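
For what it's worth, my reading of why this happens (a sketch of the DataLoader internals, not something asserted in this PR): when a datapipe is passed to the DataLoader, the loader calls torch.utils.data.graph_settings.apply_shuffle_settings on the datapipe graph (the same module the warning above comes from), and with shuffle=False every Shuffler in the graph is switched off in place. Assuming that's right, it can be checked directly, reusing dp_shuffled from the example above:

from torch.utils.data.graph_settings import apply_shuffle_settings

# Switch every Shuffler in the graph off, in place
apply_shuffle_settings(dp_shuffled, shuffle=False)
print(next(iter(dp_shuffled)))  # expected: the first unshuffled batch, i.e. [0, 1, 2, 3, 4]

# Switch them back on
apply_shuffle_settings(dp_shuffled, shuffle=True)
print(next(iter(dp_shuffled)))  # expected: a shuffled batch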
codecov[bot] commented 6 months ago

Codecov Report

Attention: 6 lines in your changes are missing coverage. Please review.

Comparison is base (b2d0e30) 57.62% compared to head (eaa94f7) 57.63%. Report is 2 commits behind head on main.

Files                            Patch %    Lines
pvnet/data/wind_datamodule.py    0.00%      3 Missing :warning:
pvnet/data/datamodule.py         66.66%     1 Missing :warning:
pvnet/models/utils.py            50.00%     1 Missing :warning:
pvnet/training.py                0.00%      1 Missing :warning:
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main     #118      +/-   ##
==========================================
+ Coverage   57.62%   57.63%   +0.01%
==========================================
  Files          26       26
  Lines        1706     1702       -4
==========================================
- Hits          983      981       -2
+ Misses        723      721       -2
```

:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Have feedback on the report? Share it here.