timeseriesAI / tsai

State-of-the-art Deep Learning library for Time Series and Sequences in Pytorch / fastai
https://timeseriesai.github.io/tsai/

How to Save and Update Model Parameters with Tsai for Time Series Regression? #839

Closed by makinno 9 months ago

makinno commented 9 months ago

I would like to implement get_parameters and set_parameters functions for a gMLP (gated Multi-Layer Perceptron) model in tsai for time series regression. These two helper functions would update the local model with parameters received from the server and return the updated parameters of the local model.

The script below shows how to do this for a simple CNN model in PyTorch. Could you please guide me on how to adapt it to work with a gMLP model in tsai?

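from collections import OrderedDict
from typing import List

import numpy as np
import torch

def get_parameters(net) -> List[np.ndarray]:
    # Export every state_dict entry as a NumPy array, in state_dict order
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

def set_parameters(net, parameters: List[np.ndarray]):
    # Pair the arrays with the model's own keys, relying on the same order
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)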

I appreciate any assistance or insights on how to implement these functions effectively for a gMLP model using Tsai. Thank you!
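
For reference, here is a minimal sketch of how these helpers might be written so that they do not depend on the architecture. It assumes the tsai gMLP model is an ordinary torch.nn.Module (so state_dict() and load_state_dict() behave as they do for the CNN) and that, when training through a tsai/fastai Learner, the underlying module is passed in as learn.model. The make_stand_in_model function is a hypothetical placeholder used only to keep the example runnable without tsai. The sketch also swaps torch.Tensor(v) for torch.as_tensor(v), on the assumption that preserving each array's dtype and shape matters for integer buffers such as BatchNorm's num_batches_tracked.

```python
from collections import OrderedDict
from typing import List

import numpy as np
import torch
from torch import nn


def get_parameters(net: nn.Module) -> List[np.ndarray]:
    # Copy each entry so the returned arrays do not alias the live model weights.
    return [val.detach().cpu().numpy().copy() for val in net.state_dict().values()]


def set_parameters(net: nn.Module, parameters: List[np.ndarray]) -> None:
    keys = list(net.state_dict().keys())
    if len(keys) != len(parameters):
        raise ValueError(f"expected {len(keys)} arrays, got {len(parameters)}")
    # torch.as_tensor keeps the NumPy dtype and shape (including 0-d arrays).
    state_dict = OrderedDict(
        (k, torch.as_tensor(np.asarray(v))) for k, v in zip(keys, parameters)
    )
    net.load_state_dict(state_dict, strict=True)


def make_stand_in_model() -> nn.Module:
    # Hypothetical placeholder. With tsai installed, the assumption is that a
    # gMLP instance (or learn.model from a Learner) can be used here instead.
    return nn.Sequential(
        nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1)
    )


# Round trip: copy the weights of one model into another with the same layout.
source = make_stand_in_model()
target = make_stand_in_model()
set_parameters(target, get_parameters(source))
```

Both models must be built with the same constructor arguments, because the arrays are matched to keys purely by their position in state_dict().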