Miguel-Antonm / deploy_DL_space_weather_forecast

PFG: Deploy deep learning space weather forecasting
Apache License 2.0
0 stars 2 forks source link

Descargar el modelo entrenado #1

Open vrodriguezf opened 3 years ago

vrodriguezf commented 3 years ago

Hacer una función que descargue de todo lo necesario para cargar el modelo entrenado en el paper. La función recibirá como argumento un path a un directorio bajo el que se guardará todo. Suponiendo que ese directorio se llama "NBEATS", la estructura del directorio después de llamar a la función debería ser algo como:

- NBEATS
  - ensembleH3
    - sweep_run_1 # The name sweep_run_1 is in reality the name of the run given in wandb
      - best.pth
      - config.yaml
    - sweep_run_2
      - best.pth
      - config.yaml
    - ...
    - sweep_run_90
      - best.pth
      - config.yaml 
  - ensembleH5
    - sweep_run_1
    - ...
    - sweep_run_90
  - ensembleH7
  - ensembleH10
  - ensembleH14
  - ensembleH21
  - ensembleH27

La letra H hace referencia al "horizonte" de predicción del modelo. Si te fijas en el paper, todos los experimentos se hacen para modelos con diferentes horizontes. Los modelos con horizontes más pequeños (3 días) tienen más precisión y menos incertidumbre que los modelos que predicen 30 días (lógico).

Para cada horizonte hay un modelo entrenado, que es a su vez un ensemble de modelos NBEATS. Al contrario de lo que hablamos en nuestra última reunión, cada ensemble tiene 90 modelos en vez de 180. Los datos tienes que descargarlos usando la API de wandb, accediendo al proyecto swe_pytorch_ensemble de la entidad ecstevenson. En concreto, los sweeps que tienes que descargar dentro de ese proyecto son:

En cada sweep tendrás que ver cual es su ID para acceder a él mediante la API. Ten en cuenta que no hay un sweep para cada horizonte, en la mayoría de los casos hay un sweep para cada 2 valores de H, por lo que dentro de la función tendrás que ver qué horizonte tiene cada run. Lo puedes ver en la configuración del run (run.config)

Para cada run de cada ensemble, hay que descargar 2 ficheros: best.pth y config.yaml. El primero tiene los pesos de la red neuronal, y el segundo tiene la configuración de la arquitectura, que luego nos servirá para cargar el modelo de la forma adecuada.

Para mantener el repositorio organizado, para las funciones de carga/descarga de datos creáte otro modulo (es decir, otro notebook) que se llame "data" o algo así.

vrodriguezf commented 3 years ago

Fichero con modelo NBEATS

https://github.com/philipperemy/n-beats/blob/master/nbeats_pytorch/model.py

Descarga este fichero y colócalo en sfwd/models/nbeats.py. Luego en el código, lo importas con from sfwd.models.nbeats import * y ya deberías tener acceso a la red NBeatsNet

vrodriguezf commented 3 years ago

Ayer hablé con Emma y me dijo que el ensemble H3 tiene 180 runs porque fue el primero que hizo y todavía no tenía claro cuantos runs necesitaba por ensemble. Así que por ahora, descarga los 180 y ya está, consideramos ese caso como especial.

Otra cosa, también me dijo que tiene planeado reentrenar el modelo la semana que viene, con algunos hiiperparámetros ajustados, y un mayor número de valores de H. Así que, cuando esto ocurra, habrá que redescargarse el modelo.