USGS-R / river-dl

Deep learning model for predicting environmental variables on river systems
Creative Commons Zero v1.0 Universal

RGCN in pytorch #159

Closed: jdiaz4302 closed this 2 years ago

jdiaz4302 commented 2 years ago

I see that the current repo has the RGCN.py file from river-dl, with the model code written in TensorFlow. This PR contains code for two versions of the RGCN written in PyTorch (v0 = the river-dl implementation, v1 = the paper equations). It is mainly for sharing the model code, giving an overview of it, and possibly making further changes.
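For context, an RGCN-style cell can be pictured as an LSTM step applied to every river segment at once, with the adjacency matrix mixing neighboring segments' states into each segment's cell state. The NumPy sketch below is a hypothetical illustration of that idea only. It is not the v0 or v1 code from this PR, and the weight names (`W_x`, `W_h`, `b`, `W_q`, `b_q`), the gate ordering, and where `A` enters the update are all assumptions.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def graph_lstm_step(x_t, h_prev, c_prev, A, params):
    """One hypothetical RGCN-style step, computed for all segments at once.

    x_t:    [n_segments, input_dim]   inputs at time t
    h_prev: [n_segments, hidden_dim]  hidden states from t - 1
    c_prev: [n_segments, hidden_dim]  cell states from t - 1
    A:      [n_segments, n_segments]  adjacency matrix (row i gathers from j)
    """
    # Stacked projection for the four LSTM gates: [n_segments, 4 * hidden_dim]
    z = x_t @ params["W_x"] + h_prev @ params["W_h"] + params["b"]
    f, i, o, g = np.split(z, 4, axis=1)
    f, i, o, g = sigmoid(f), sigmoid(i), sigmoid(o), np.tanh(g)

    # Transform each segment's previous hidden state, then sum it over
    # neighbors using the adjacency weights
    q = np.tanh(h_prev @ params["W_q"] + params["b_q"])
    neighbor = A @ q

    # Neighbor information enters the cell state alongside the usual LSTM update
    c_t = f * (c_prev + neighbor) + i * g
    h_t = o * np.tanh(c_t)
    return h_t, c_t


# Usage with small random shapes
rng = np.random.default_rng(0)
n_seg, in_dim, hid = 5, 3, 4
params = {
    "W_x": rng.normal(size=(in_dim, 4 * hid)),
    "W_h": rng.normal(size=(hid, 4 * hid)),
    "b": np.zeros(4 * hid),
    "W_q": rng.normal(size=(hid, hid)),
    "b_q": np.zeros(hid),
}
A = rng.normal(size=(n_seg, n_seg))
x_t = rng.normal(size=(n_seg, in_dim))
h0 = rng.normal(size=(n_seg, hid))
c0 = np.zeros((n_seg, hid))
h1, c1 = graph_lstm_step(x_t, h0, c0, A, params)
```

Setting `A` to all zeros reduces this sketch to an ordinary LSTM step run independently per segment, which is one way to see what the graph term contributes.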

jdiaz4302 commented 2 years ago

Here is some minimal code for using the models:

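```python
import numpy as np
import torch

from river_dl.RGCN_v1 import RGCN_v1

# One batch: [n_segments, seq_len, n_features] = [455 segments, 365 time steps, 16 features]
data = torch.rand([455, 365, 16])
# Random stand-in for the [455, 455] adjacency matrix of the river network
A = np.random.normal(size=[455, 455])

model = RGCN_v1(input_dim=16,
                hidden_dim=20,
                adj_matrix=A,
                recur_dropout=0,
                dropout=0)

# Sequence of predictions plus the final hidden (h) and cell (c) states
out, (h, c) = model(data)


def count_parameters(model):
    """Count the trainable parameters of a PyTorch module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(count_parameters(model))
```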

RGCN_v1 and RGCN_v0 can be used interchangeably. When switching between the two, you will use fewer parameters with RGCN_v1 (the count is printed by the last line of the code above). For that example, RGCN_v0 has 5461 parameters while RGCN_v1 has 3401, a sizable reduction of roughly 38%.
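The gap between the two counts can be checked with simple arithmetic. Below is a hypothetical per-tensor breakdown that reproduces both reported totals for `input_dim=16`, `hidden_dim=20`, and a single output. The layouts are assumptions chosen for illustration, not read from the RGCN_v0 or RGCN_v1 code, so treat the split as a plausible reading rather than the actual module structure.

```python
from math import prod

IN, HID, OUT = 16, 20, 1  # input_dim, hidden_dim, one predicted variable


def n_params(shapes):
    """Total scalar parameters in a list of weight/bias shapes."""
    return sum(prod(shape) for shape in shapes)


# Shared by both versions (assumed): four LSTM gates and a dense output layer
gate = [(IN, HID), (HID, HID), (HID,)]   # input weights, recurrent weights, bias
lstm = gate * 4                          # forget, input, output, candidate
dense_out = [(HID, OUT), (OUT,)]

# Assumed v1 layout: one graph transform of the hidden state
graph_v1 = [(HID, HID), (HID,)]

# Assumed v0 layout: separate graph transforms for h and c, plus weights that
# combine current and previous states for h and for c
graph_v0 = [(HID, HID), (HID,)] * 2 + [(HID, HID), (HID, HID), (HID,)] * 2

v0 = n_params(lstm + graph_v0 + dense_out)
v1 = n_params(lstm + graph_v1 + dense_out)
reduction = 100 * (v0 - v1) / v0
print(v0, v1, f"{reduction:.1f}%")  # 5461 3401 37.7%
```

Under this reading, the whole difference of 2060 parameters comes from the graph-related weights, since the LSTM gates and the output layer are the same size in both versions.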

The example uses random data of size [455, 365, 16] to represent one batch of river-dl data: 455 river segments (matching the [455, 455] adjacency matrix), 365 time steps, and 16 input features.
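Because the first dimension of the batch matches the size of the adjacency matrix, the adjacency matrix mixes information along that segment axis. A minimal NumPy sketch, assuming the `[segment, time, feature]` layout above and the convention that `A[i, j]` weights segment j's contribution to segment i (an assumption; check how the models index `A`):

```python
import numpy as np

n_segments, seq_len, n_features = 455, 365, 16

# [segment, time, feature], as in the example batch
data = np.random.rand(n_segments, seq_len, n_features)
# Assumed convention: A[i, j] weights segment j's contribution to segment i
A = np.random.normal(size=(n_segments, n_segments))

# Neighbor aggregation for every time step at once:
# neighbor_sum[i, t, f] = sum_j A[i, j] * data[j, t, f]
neighbor_sum = np.tensordot(A, data, axes=1)

# Equivalent to an ordinary matrix product at any single time step
t = 100
assert np.allclose(neighbor_sum[:, t, :], A @ data[:, t, :])
print(neighbor_sum.shape)  # (455, 365, 16)
```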

I have coded these models to output the sequence of predictions along with the final hidden (h) and cell (c) states. This can easily be modified to output the full sequence of h and c states if that is preferred. Alternatively, we can output only the sequence of predictions if that is more compatible with Graph WaveNet (GWN) and you're not interested in the states or data assimilation (DA).
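The choice between returning only the final states and returning every state can be illustrated with a small, hypothetical unrolling helper. `run_sequence`, `toy_step`, and the `return_all_states` flag are illustrative names, not part of this PR, and the toy step is a plain recurrent update standing in for an RGCN cell (it ignores the adjacency matrix). NumPy is used so the sketch runs without PyTorch.

```python
import numpy as np


def run_sequence(x, step, hidden_dim, return_all_states=False):
    """Unroll a recurrent step over time (hypothetical helper, not from the PR).

    x:    [n_segments, seq_len, input_dim]
    step: callable (x_t, h, c) -> (h, c, y_t)
    Returns predictions [n_segments, seq_len, out_dim] and either the final
    (h, c) or the stacked states, each [n_segments, seq_len, hidden_dim].
    """
    n_segments, seq_len, _ = x.shape
    h = np.zeros((n_segments, hidden_dim))
    c = np.zeros((n_segments, hidden_dim))
    preds, hs, cs = [], [], []
    for t in range(seq_len):
        h, c, y_t = step(x[:, t, :], h, c)
        preds.append(y_t)
        if return_all_states:  # only keep per-step states when asked
            hs.append(h)
            cs.append(c)
    preds = np.stack(preds, axis=1)
    if return_all_states:
        return preds, (np.stack(hs, axis=1), np.stack(cs, axis=1))
    return preds, (h, c)


rng = np.random.default_rng(0)
W_in = rng.normal(size=(16, 20))
W_out = rng.normal(size=(20, 1))


def toy_step(x_t, h, c):
    # Stand-in recurrent update (ignores the incoming h and the adjacency matrix)
    c = 0.5 * c + np.tanh(x_t @ W_in)
    h = np.tanh(c)
    return h, c, h @ W_out


x = rng.random((455, 365, 16))
out, (h, c) = run_sequence(x, toy_step, hidden_dim=20)
out_all, (h_all, c_all) = run_sequence(x, toy_step, hidden_dim=20,
                                       return_all_states=True)
```

Returning only the final `(h, c)` keeps the state output at a fixed size regardless of sequence length, while the stacked states grow linearly with `seq_len` but expose every intermediate state, which is what step-by-step use of the states would need.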

SimonTopp commented 2 years ago

@jdiaz4302, I just saw this PR after submitting my updates, which include the RGCN and RGCN_v1 models. If you'd rather incorporate them here, let me know and I can remove them from my PR.

jdiaz4302 commented 2 years ago

My intention was to open a PR from SimonTopp:RGCN_torch to SimonTopp:main, but apparently I messed that up 😅. I'll close this out and copy my changes into that PR.