sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

Graph learning with embedding net #712

Closed: jeliason closed this issue 2 years ago

jeliason commented 2 years ago

Hi all -

I am new to sbi and am wondering what the best way to approach my problem is. Briefly, I am trying to infer the parameters of a spatial agent-based model. My idea is to simulate up to a certain time point, record the spatial positions of the agents at that time point as the observation, and then convert that observation into a graph using a graph-construction procedure such as k-nearest neighbors. I would then use something like a graph convolutional network (GCN) as the embedding net that maps these graph observations to a feature vector for the density estimator.
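For reference, "GCN" most often refers to the Kipf and Welling layer. Assuming that variant, one layer acting on the agents' k-nearest-neighbor adjacency matrix A and node-feature matrix H^(l) (row i holds the features of agent i), with learned weights W^(l) and a nonlinearity σ, is:

$$
H^{(l+1)} = \sigma\!\left(\tilde{D}^{-1/2}\,\tilde{A}\,\tilde{D}^{-1/2}\,H^{(l)}\,W^{(l)}\right),
\qquad \tilde{A} = A + I, \qquad \tilde{D}_{ii} = \sum_{j} \tilde{A}_{ij}
$$

This normalization assumes a symmetric A, whereas a k-nearest-neighbor graph is generally directed (agent j can be among agent i's nearest neighbors without the converse holding), so A is typically symmetrized first, for example A ← max(A, Aᵀ). A single graph-level vector for the density estimator can then be obtained by pooling the final node features, for instance by averaging over agents.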

However, I am getting tripped up by the input format that sbi expects. My simulator code looks something like this:

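import dgl
import numpy as np
import torch
from sklearn.neighbors import kneighbors_graph

# FEATURES (column-name pattern / node-data key), k (number of neighbors),
# mode ("connectivity" or "distance") and thresh (maximum edge length or None)
# are assumed to be defined elsewhere.

def simulator_model(parameters):

    # PLACEHOLDER: Here, I run simulation code for the agent-based model and
    # capture the spatial positions and features of the agents as a pandas
    # DataFrame cell_data
    features = cell_data.filter(like=FEATURES).to_numpy()
    centroids = cell_data.loc[:, ['X', 'Y']].to_numpy()
    num_nodes = features.shape[0]

    # k-nearest-neighbor adjacency matrix built from the agent positions
    adj = kneighbors_graph(
        centroids,
        k,
        mode=mode,
        include_self=False,
        metric="euclidean").toarray()
    if thresh is not None:
        # drop edges longer than thresh (only meaningful with mode="distance")
        adj[adj > thresh] = 0
    src, dst = np.nonzero(adj)

    # build the DGL graph from the edge list and attach the node features
    graph = dgl.graph((torch.from_numpy(src), torch.from_numpy(dst)),
                      num_nodes=num_nodes)
    graph.ndata[FEATURES] = torch.tensor(features, dtype=torch.float32)

    return graph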

As can be seen, my simulator returns a DGLGraph object rather than a torch tensor. What would be a good way to approach this so that I can batch the simulations (which are DGL graphs) and feed them to a density estimator that uses a GCN as its embedding net?
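One possible workaround is to have the simulator return a flat, fixed-length tensor instead of a graph, and to rebuild the k-nearest-neighbor graph later from the stored positions, so that batching reduces to stacking tensors. The sketch below is only an illustration under stated assumptions: positions_to_tensor is a hypothetical helper (not part of sbi or DGL), max_nodes is an assumed upper bound on the number of agents, and smaller simulations are zero-padded, with a mask column marking which rows are real agents. Edges are deliberately not stored.

```python
import torch


def positions_to_tensor(centroids, features, max_nodes):
    """Hypothetical helper: pack one simulation into a flat, fixed-length tensor.

    centroids: (n, 2) array-like of agent x/y positions
    features:  (n, f) array-like of per-agent features
    max_nodes: assumed upper bound on the number of agents
    Each of the max_nodes rows is [x, y, feature_1, ..., feature_f, mask];
    rows beyond n are zero padding with mask = 0.
    """
    pos = torch.as_tensor(centroids, dtype=torch.float32)
    feat = torch.as_tensor(features, dtype=torch.float32)
    n, f = feat.shape
    if n > max_nodes:
        raise ValueError(f"{n} agents exceed max_nodes={max_nodes}")
    rows = torch.zeros(max_nodes, 2 + f + 1)
    rows[:n, :2] = pos
    rows[:n, 2:2 + f] = feat
    rows[:n, -1] = 1.0  # mask column: 1 for real agents, 0 for padding
    return rows.flatten()
```

Stacking the per-simulation tensors with torch.stack then gives a (num_simulations, max_nodes * (f + 3)) matrix, which is the kind of x that sbi's append_simulations takes alongside theta. The cost of this design is a hard cap on the agent count and some wasted memory on padding.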

CameronFen commented 2 years ago

I have some pretty gnarly code that does this for my own research. You need to modify the sbi code itself to generate PyTorch Geometric batches of graphs, and then perform the graph embedding inside sbi (in particular, by modifying snpe_base). That way the data you save are plain tensors, which means a smaller memory footprint during training and fewer large-scale modifications to make sbi accept PyTorch Geometric Data structures. Feel free to email me at cameronfen@gmail.com; I'm happy to send you my code if you're interested.
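A related idea that keeps the stored data as tensors without modifying sbi's internals is to rebuild the graph inside the embedding net itself. The following is a rough sketch, not the code described above: KNNGraphEmbedding is a hypothetical module written in plain PyTorch with a dense adjacency matrix (no DGL or PyTorch Geometric), and it assumes each simulation is stored as a flat tensor of max_nodes zero-padded rows [x, y, features..., mask]. It applies two GCN layers with the Kipf and Welling normalization to a symmetrized k-nearest-neighbor graph, then mean-pools over real agents only.

```python
import torch
from torch import nn


class KNNGraphEmbedding(nn.Module):
    """Hypothetical embedding net: rebuilds a k-NN graph from each padded
    simulation and embeds it with two dense GCN layers and masked mean pooling.

    Assumed input shape: (batch, max_nodes * (3 + num_features)), where each
    of the max_nodes rows is [x, y, feature_1, ..., feature_F, mask].
    """

    def __init__(self, max_nodes, num_features, k=5, hidden=32, out_dim=16):
        super().__init__()
        self.max_nodes, self.num_features, self.k = max_nodes, num_features, k
        self.gcn1 = nn.Linear(num_features, hidden)
        self.gcn2 = nn.Linear(hidden, hidden)
        self.head = nn.Linear(hidden, out_dim)

    def normalized_adjacency(self, pos, mask):
        # Pairwise distances (batch, N, N); pairs involving padding and
        # self-pairs are set to +inf so they are never picked as neighbors.
        dist = torch.cdist(pos, pos)
        invalid = (mask.unsqueeze(1) * mask.unsqueeze(2)) == 0
        self_pairs = torch.eye(pos.shape[1], dtype=torch.bool, device=pos.device)
        dist = dist.masked_fill(invalid | self_pairs, float("inf"))
        # Directed k-NN edges; neighbors at +inf distance get weight 0.
        k = min(self.k, pos.shape[1] - 1)
        nn_dist, nn_idx = dist.topk(k, dim=-1, largest=False)
        adj = torch.zeros_like(dist)
        adj.scatter_(-1, nn_idx, torch.isfinite(nn_dist).to(dist.dtype))
        # Symmetrize, add self-loops for real agents, then D^-1/2 (A+I) D^-1/2.
        adj = ((adj + adj.transpose(1, 2)) > 0).to(dist.dtype)
        adj = adj + torch.diag_embed(mask)
        d_inv_sqrt = adj.sum(-1).clamp(min=1.0).pow(-0.5)
        return d_inv_sqrt.unsqueeze(-1) * adj * d_inv_sqrt.unsqueeze(-2)

    def forward(self, x):
        rows = x.reshape(x.shape[0], self.max_nodes, 3 + self.num_features)
        pos, feat, mask = rows[..., :2].contiguous(), rows[..., 2:-1], rows[..., -1]
        a_hat = self.normalized_adjacency(pos, mask)
        # Each GCN layer: relu(A_hat @ H @ W + b)
        h = torch.relu(self.gcn1(a_hat @ feat))
        h = torch.relu(self.gcn2(a_hat @ h))
        # Mean over real agents only, then project to the embedding size
        counts = mask.sum(1, keepdim=True).clamp(min=1.0)
        pooled = (h * mask.unsqueeze(-1)).sum(1) / counts
        return self.head(pooled)
```

Such a module could be passed as the embedding_net argument of sbi's posterior_nn builder (its import path has moved between sbi versions), so that sbi only ever stores and batches plain tensors. The dense N×N adjacency keeps the sketch short but scales quadratically with max_nodes; for large agent counts, a sparse graph library such as PyTorch Geometric, as suggested above, is likely the better fit.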

jeliason commented 2 years ago

@CameronFen thank you! I'll reach out via email.

michaeldeistler commented 2 years ago

I'm closing this for now, but feel free to reopen if new questions come up! Good luck with your project!