AI4S2S / s2spy

A high-level Python package integrating expert knowledge and artificial intelligence to boost (sub-)seasonal forecasting
https://ai4s2s.readthedocs.io/
Apache License 2.0

plot_correlation() requires recalculation of correlation map #91

Closed: semvijverberg closed this issue 2 years ago

semvijverberg commented 2 years ago

Issue: plot_correlation() requires the correlation map to be recalculated, which hampers the interpretation a user may want after fitting a complex pipeline. The RGDR method can be used as a 'brute-force' dimensionality reduction approach, after which a feature selection method or an explainable AI method might identify the important features. It would then be useful if the user could visualize, in hindsight, where these important features lie on the map. What I would like to be able to do is:

import xarray as xr
import matplotlib.pyplot as plt
import s2spy.time
import s2spy.rgdr
from s2spy import RGDR

file_path = '../tests/test_rgdr/test_data'
field = xr.open_dataset(f'{file_path}/sst_daily_1979-2018_5deg_Pacific_175_240E_25_50N.nc')
target = xr.open_dataset(f'{file_path}/tf5_nc5_dendo_80d77.nc')

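cal = s2spy.time.AdventCalendar((8, 31), freq="30d")
cal = cal.map_to_data(field)
field_resampled = s2spy.time.resample(cal, field)
target_resampled = s2spy.time.resample(cal, target)

# select the precursor field and the target time series
# (assumed variable names: "sst" in the field file, "ts" with a "cluster" dimension in the target file;
#  cluster 3 is only an example choice)
precursor_field = field_resampled.sst
target_timeseries = target_resampled.sel(cluster=3).ts

rgdr = RGDR()
rgdr.fit(precursor_field, target_timeseries)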

# fit a model on the RGDR features
# run feature selection or a feature importance analysis
# goal: inspect an important feature identified by that analysis, e.g. the one labelled "lag:1_cluster:-2"

rgdr.plot_correlation(lag=1)  # should reuse the correlation map computed during fit
rgdr.plot_labels(lag=1, label=-2)  # desired: show only cluster -2 at lag 1
plt.show()
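The request boils down to keeping what fit has already computed. The sketch below illustrates that idea with NumPy only; it is not s2spy's implementation, and the class `CachedMaps`, its attributes `corr_`, `labels_` and `n_computations_`, the helper `parse_feature_name`, and the threshold-based labelling (a crude stand-in for the spatial clustering RGDR performs) are all hypothetical. fit computes each lagged correlation map once and stores it together with a label map; afterwards a feature name such as "lag:1_cluster:-2" can be turned back into a boolean grid mask without recomputing anything.

```python
import re

import numpy as np

# Hypothetical feature-name format, modelled on "lag:1_cluster:-2" from the issue.
_FEATURE_NAME = re.compile(r"^lag:(\d+)_cluster:(-?\d+)$")


def parse_feature_name(name):
    """Split a feature name such as 'lag:1_cluster:-2' into (lag, label)."""
    match = _FEATURE_NAME.match(name)
    if match is None:
        raise ValueError(f"unrecognised feature name: {name!r}")
    return int(match.group(1)), int(match.group(2))


class CachedMaps:
    """Hypothetical stand-in for a fitted dimensionality reducer.

    fit() computes each lagged correlation map exactly once and stores it,
    together with a label map, so later inspection only reads cached arrays.
    """

    def __init__(self, lags=(1, 2), threshold=0.3):
        if any(lag < 1 for lag in lags):
            raise ValueError("lags must be positive integers")
        self.lags = tuple(lags)
        self.threshold = threshold
        self.n_computations_ = 0  # how many correlation maps have been computed

    @staticmethod
    def _correlate(field, target):
        # Pearson correlation of target (time,) with every grid cell of field (time, lat, lon).
        f = field - field.mean(axis=0)
        t = target - target.mean()
        num = np.tensordot(t, f, axes=(0, 0))
        den = np.sqrt((t ** 2).sum() * (f ** 2).sum(axis=0))
        return num / den

    def fit(self, field, target):
        self.corr_, self.labels_ = {}, {}
        for lag in self.lags:
            # The precursor leads the target: field[t] is paired with target[t + lag].
            corr = self._correlate(field[:-lag], target[lag:])
            self.n_computations_ += 1
            self.corr_[lag] = corr
            # Crude labelling by sign and threshold (+1 / -1, 0 = no cluster);
            # RGDR itself groups grid cells into spatial clusters instead.
            labels = np.zeros(corr.shape, dtype=int)
            labels[corr >= self.threshold] = 1
            labels[corr <= -self.threshold] = -1
            self.labels_[lag] = labels
        return self

    def correlation(self, lag):
        return self.corr_[lag]  # read from the cache, nothing is recalculated

    def label_mask(self, lag, label):
        # Boolean (lat, lon) mask of the grid cells carrying `label` at `lag`.
        return self.labels_[lag] == label

    def feature_mask(self, name):
        return self.label_mask(*parse_feature_name(name))


# Usage on synthetic data: grid cell (0, 0) anticorrelates with the target one step later.
rng = np.random.default_rng(0)
target = rng.normal(size=60)
field = rng.normal(size=(60, 3, 4))
field[:-1, 0, 0] = -target[1:]

maps = CachedMaps(lags=(1, 2)).fit(field, target)
corr_lag1 = maps.correlation(1)
mask = maps.feature_mask("lag:1_cluster:-1")
print(maps.n_computations_)  # 2: one map per lag, computed only during fit
```

The returned mask could then be drawn over the latitude/longitude grid, for example with matplotlib's pcolormesh. Storing the maps as fitted attributes trades some memory for not having to repeat the correlation step every time a plot is requested.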
geek-yang commented 2 years ago

The related PR is merged and this issue can be closed now.