openclimatefix / ocf_datapipes

OCF's DataPipe based dataloader for training and inference
MIT License

Add Visualization Options #231

Open jacobbieker opened 10 months ago

jacobbieker commented 10 months ago

Detailed Description

We want to be able to easily see what our batches look like and have utilities that plot them to help with debugging and ensuring that our pipelines are doing what we expect.

We have had multiple one-off visualization scripts before, but the goal of this is to build them into datapipes, and ideally keep them up to date, and possibly run them on PRs to give a quick, automatic view if any of the datapipes are changed or updated.
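To make the "build them into datapipes" part concrete, below is a minimal sketch of what a pass-through visualization datapipe could look like. Everything in it is an illustrative assumption rather than an existing ocf_datapipes API: the class name VisualizeBatchIterDataPipe, the plot_fn argument, and the output directory.

from pathlib import Path
from typing import Callable, Iterator

import matplotlib.pyplot as plt
from torch.utils.data import IterDataPipe


class VisualizeBatchIterDataPipe(IterDataPipe):
    """Save a figure for every batch as a side effect, then yield the batch unchanged."""

    def __init__(self, source_datapipe: IterDataPipe, plot_fn: Callable, output_dir: str = "batch_plots"):
        self.source_datapipe = source_datapipe
        self.plot_fn = plot_fn  # callable that takes a batch and returns a matplotlib Figure
        self.output_dir = Path(output_dir)

    def __iter__(self) -> Iterator:
        self.output_dir.mkdir(parents=True, exist_ok=True)
        for i, batch in enumerate(self.source_datapipe):
            fig = self.plot_fn(batch)  # delegate the actual plotting
            fig.savefig(self.output_dir / f"batch_{i}.png")
            plt.close(fig)
            yield batch  # the pipeline output is untouched

Wrapping the tail of an existing pipeline with something like this would give a folder of PNGs per run, which a PR workflow could then upload as artifacts.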

I think the steps would be

Possible Implementation

Satip used to have a step in its workflows that ran visualization code on the outputs of some processing steps on PRs. It was quite helpful for knowing whether changes broke the end-to-end processing pipelines and whether the images coming out still looked correct.

Notes

Goal:

jacobbieker commented 10 months ago

@dfulu @peterdudfield does this sound right to you? I was thinking matplotlib plots by default, as they can be saved out to disk easily, or opened in streamlit with st.pyplot for the dashboard.
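For illustration, a rough sketch of the matplotlib-by-default idea, where plot_batch is a hypothetical helper (not an existing ocf_datapipes function) returning a Figure that can either be saved to disk or handed to streamlit:

import matplotlib.pyplot as plt
from ocf_datapipes.batch import BatchKey


def plot_batch(batch) -> plt.Figure:
    """Hypothetical helper: plot the wind time series of the first example in a batch."""
    fig, ax = plt.subplots()
    ax.plot(batch[BatchKey.wind][0])  # assumes an (example, time, ...) layout
    ax.set_xlabel("time index")
    ax.set_ylabel("wind")
    return fig


# fig = plot_batch(batch)      # `batch` is a loaded NumpyBatch
# fig.savefig("batch_0.png")   # save to disk, e.g. as a PR artifact
# st.pyplot(fig)               # or show the same figure in the streamlit dashboard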

peterdudfield commented 10 months ago

> @dfulu @peterdudfield does this sound right to you? I was thinking matplotlib plots by default, as they can be saved out to disk easily, or opened in streamlit with st.pyplot for the dashboard.

This looks really great, just got a few comments

dfulu commented 10 months ago

Yeh this sounds good. I had also been wondering if it might be a generally good idea to save out batches as something like a netCDF. Do you think it would be slower to load or larger on disk to use a netCDF for each batch compared to a PyTorch tensor?

jacobbieker commented 10 months ago

It would be a bit slower to load, as you'd have to convert it to a PyTorch tensor before putting it into the model, but it would make the batches a lot easier to visualize, since we could mostly just call the inbuilt xarray plotting. I would probably lean towards saving them out as netCDFs and then just doing the conversion on the fly. I don't think they would be much larger; they'd still have the metadata, which might make a difference, but I think it should be fine.
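For reference, a minimal sketch of that round trip, using an illustrative stand-in Dataset and made-up variable/file names; the tensor conversion happens on the fly just before the model, and the plotting comes from xarray for free:

import numpy as np
import torch
import xarray as xr

# illustrative stand-in for a batch held as an xarray Dataset
batch_ds = xr.Dataset(
    {"wind": (("example", "time"), np.random.rand(4, 12))},
    coords={"time": np.arange(12)},
)

# save the batch as netCDF (the metadata comes along for free)
batch_ds.to_netcdf("batch_0.nc")

# load it back and convert to a tensor on the fly, just before it goes into the model
ds = xr.open_dataset("batch_0.nc")
wind_tensor = torch.from_numpy(ds["wind"].values).float()

# visualization can then lean on xarray's built-in (matplotlib-based) plotting
ds["wind"].isel(example=0).plot()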

jacobbieker commented 10 months ago

@peterdudfield having it just be a function sounds good. If we did move to saving NetCDF files to disk, I would probably stick with matplotlib, as that is what xarray uses in its in-built plotting methods, and it would reduce the work needed for doing this.
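As a concrete illustration of how little code that leaves, a couple of matplotlib-backed xarray calls are enough to facet and save a gridded variable; the DataArray below stands in for one NWP channel, and its dimension/variable names are assumptions:

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# stand-in for one NWP channel of a batch: (step, y, x); names are illustrative
da = xr.DataArray(np.random.rand(6, 32, 32), dims=("step", "y", "x"), name="t2m")

# xarray's built-in plotting is matplotlib under the hood: one call gives a facet grid
da.plot(col="step", col_wrap=3)
plt.savefig("nwp_t2m_steps.png")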

reticent-roklimber commented 5 months ago

Hi, I am quite familiar with plotly, and I am currently working with weather data, handling visualisation and building ML models for flood inundation. I came across this while looking at issues as part of GSoC. I am interested in contributing to this.

reticent-roklimber commented 5 months ago

Hi, I ended up not applying for GSoC due to the time constraints at work, but I am still interested in contributing here. Can you let me know how to proceed?

peterdudfield commented 2 months ago

Here's my very small attempt: it takes batches and spits out some sort of markdown file. It only does wind and NWP and is pretty delicate.

""" The idea is visualize one of the batches """
import pandas as pd
import sys

from ocf_datapipes.batch import NumpyBatch, BatchKey, NWPBatchKey
import torch
import plotly.graph_objects as go

def visualize_batch(batch: NumpyBatch):
    """Print a markdown-formatted summary of the wind and NWP parts of a batch."""

    # Wind
    print('# Batch visualization')
    print('## Wind \n')
    keys = [
        BatchKey.wind,
        BatchKey.wind_t0_idx,
        BatchKey.wind_time_utc,
        BatchKey.wind_id,
        BatchKey.wind_observed_capacity_mwp,
        BatchKey.wind_nominal_capacity_mwp,
        BatchKey.wind_latitude,
        BatchKey.wind_longitude,
        BatchKey.wind_solar_azimuth,
        BatchKey.wind_solar_elevation,
    ]
    for key in keys:
        if key in batch.keys():
            print('\n')
            value = batch[key]
            if isinstance(value, torch.Tensor):
                print(f"{key} {value.shape=}")
                print(f"Max {value.max()}")
                print(f"Min {value.min()}")
            else:
                # ints and anything else: just print the value
                print(f"{key} {value}")

    # NWP
    print('## NWP \n')

    keys = [
        NWPBatchKey.nwp,
        NWPBatchKey.nwp_target_time_utc,
        NWPBatchKey.nwp_channel_names,
        NWPBatchKey.nwp_step,
        NWPBatchKey.nwp_t0_idx,
        NWPBatchKey.nwp_init_time_utc,
    ]

    nwp = batch[BatchKey.nwp]

    nwp_providers = nwp.keys()
    for provider in nwp_providers:
        print('\n')
        print(f"Provider {provider}")
        nwp_provider = nwp[provider]

        # plot nwp main data
        nwp_data = nwp_provider[NWPBatchKey.nwp]
        # average over the two spatial dimensions, leaving (batch, time, channel)
        nwp_data = nwp_data.mean(dim=(3, 4))
        fig = go.Figure()
        for i in range(len(nwp_provider[NWPBatchKey.nwp_channel_names])):
            channel = nwp_provider[NWPBatchKey.nwp_channel_names][i]
            nwp_data_one_channel = nwp_data[0,:,i]
            time = nwp_provider[NWPBatchKey.nwp_target_time_utc][0]
            time = pd.to_datetime(time, unit='s')
            fig.add_trace(go.Scatter(x=time, y=nwp_data_one_channel, mode='lines', name=channel))

        fig.update_layout(title=f'{provider} NWP', xaxis_title='Time', yaxis_title='Value')
        fig.show(renderer='browser')
        name = f'{provider}_nwp.png'
        fig.write_image(name)
        print(f'![]({name})')
        print('\n')

        for key in keys:
            print('\n')
            value = nwp_provider[key]
            if 'time' in key.name:
                value = pd.to_datetime(value[0], unit='s')
                print(f"{key} {value.shape=}")
                print(f"Max {value.max()}")
                print(f"Min {value.min()}")
            elif isinstance(value, torch.Tensor):
                print(f"{key} {value.shape=}")
                print(f"Max {value.max()}")
                print(f"Min {value.min()}")
            else:
                # ints and anything else: just print the value
                print(f"{key} {value}")

if __name__ == "__main__":
    # redirect all of the prints above into a markdown file, restoring stdout afterwards
    with open('batch.md', 'w') as f, contextlib.redirect_stdout(f):
        d = torch.load("device_batch_0.pt")
        visualize_batch(d)