AI-IT-AVs / RobustMat

RobustMat and Street Landmark Patch Datasets
MIT License

cannot find the PDEs #1

Open LongLiveForFreedom opened 8 months ago

LongLiveForFreedom commented 8 months ago

Hello, I read in your paper that you use a PDE to implement graph diffusion, but I could not find the PDE part in your code. Has it not been uploaded yet?

Alan-She commented 8 months ago

Hi, the current version is a preliminary version, and the graph neural PDE module is given in `class Subg_GAT_net_ode(nn.Module)` in networks/model.py.

LongLiveForFreedom commented 8 months ago
[Screenshot 2023-11-16 at 19 18 48]

Sorry, I still cannot find the PDE module. Here is the whole `Subg_GAT_net_ode` class I got:

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint  # assumed source; the original only says "from the imported module"

# GAT_Measures_Net_ode_func and CNN_ode are defined elsewhere in the repository.


class Subg_GAT_net_ode(nn.Module):
    # process the input data and generate the graph embeddings
    def __init__(self, dimfeat_meas, dimhid_meas, dimreadout_meas, nheads_meas, dropout, alpha,
                 device, dim_hid_edge, nheads_edge, full_seg_adj):
        super(Subg_GAT_net_ode, self).__init__()
        self.full_seg_adj = full_seg_adj
        N = self.full_seg_adj.size(0)
        self.odeint = odeint  # from the imported module
        tol_scale = 1.0
        self.atol = tol_scale * 1e-3  # absolute tolerance
        self.rtol = tol_scale * 1e-3  # relative tolerance

        # right-hand side of the ODE: a GAT over the segment adjacency full_seg_adj
        self.odefunc = GAT_Measures_Net_ode_func(512, dimhid_meas, dimreadout_meas // nheads_meas,
                                                 dropout, alpha, nheads_meas, self.full_seg_adj)

        # note: method and step_size are stored here but are not passed to odeint in
        # forward(), so odeint runs with its default method rather than 'euler'
        self.method = 'euler'
        self.step_size = 1.0
        self.t = torch.tensor([0, 1], device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        self.ImageFeature = CNN_ode(downsampling_method='res')
        self.dropout = dropout

    def forward(self, x_meas_src):
        seg_from_image_src = self.ImageFeature(x_meas_src).view(x_meas_src.size(0), -1)
        seg_src = seg_from_image_src

        t = self.t.type_as(seg_src).cuda()  # note: .cuda() requires a GPU
        integrator = self.odeint  # integrate the output of GAT_Measures_Net_ode_func
        func = self.odefunc
        state = seg_src
        state_dt = integrator(
            func, state, t,
            atol=self.atol,
            rtol=self.rtol)
        z = state_dt[1]  # solution at the second time point, t = 1
        embedding_graph_src = z

        # return the graph embeddings (used to compute the MI loss and image prediction loss)
        return embedding_graph_src
```

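
To make the role of `t` and `state_dt[1]` concrete, here is a minimal NumPy sketch of a fixed-step explicit Euler integrator with the same `(func, state, t)` calling pattern. It assumes `odeint` follows the torchdiffeq convention of returning the solution at every time in `t`, stacked along the first axis. `euler_odeint` and `decay` are hypothetical names, and the toy right-hand side only stands in for the GAT-based `func`. With `t = [0, 1]`, index 0 is the initial state and index 1 is the state at t = 1, which is why the class returns `state_dt[1]` as the graph embedding.

```python
import numpy as np


def euler_odeint(func, y0, t, step_size=1.0):
    """Integrate dy/dt = func(t, y) with fixed-step explicit Euler.

    Hypothetical helper. Returns an array of shape (len(t), *y0.shape) with the
    solution at every requested time in t, so index 0 is always the initial state.
    """
    y = np.asarray(y0, dtype=float)
    outputs = [y.copy()]
    for t0, t1 in zip(t[:-1], t[1:]):
        s = float(t0)
        while s < t1 - 1e-12:
            h = min(step_size, t1 - s)  # never step past the next output time
            y = y + h * func(s, y)
            s += h
        outputs.append(y.copy())
    return np.stack(outputs)


def decay(t, y):
    # toy right-hand side dy/dt = -y, standing in for the GAT-based func
    return -y


t = np.array([0.0, 1.0])  # same two output times as self.t in the class
sol = euler_odeint(decay, np.array([1.0, 2.0]), t, step_size=0.5)
z = sol[1]  # state at t = 1, analogous to z = state_dt[1]
# sol.shape == (2, 2) and z == [0.25, 0.5]
```

Under the same assumption, an adaptive solver (the kind that uses `atol` and `rtol`) chooses its own internal steps but still returns results in this layout.
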
Alan-She commented 8 months ago

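The details are mainly given by `state_dt = integrator(func, state, t, atol=self.atol, rtol=self.rtol)`, where `func` (the `GAT_Measures_Net_ode_func` instance) is a graph learning module that defines the time derivative of the state. Integrating it from t = 0 to t = 1 can generally be regarded as a graph neural PDE module.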
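
One common way to connect "integrating a graph learning function" to a PDE is graph diffusion: for node features `X(t)` and a row-normalized adjacency `A_hat`, the graph analogue of the heat equation is `dX/dt = (A_hat - I) X`, where `I - A_hat` is the random-walk graph Laplacian. The sketch below illustrates only that general idea and is not the RobustMat implementation: it uses a fixed, hypothetical 3-node graph instead of learned GAT attention coefficients, and explicit Euler with a hypothetical step size of 0.1.

```python
import numpy as np


def row_normalize(adj):
    """Turn an adjacency matrix into a row-stochastic averaging operator A_hat."""
    adj = np.asarray(adj, dtype=float)
    return adj / adj.sum(axis=1, keepdims=True)


def diffusion_rhs(a_hat):
    """Right-hand side of the graph heat equation dX/dt = (A_hat - I) X (time-independent)."""
    eye = np.eye(a_hat.shape[0])
    return lambda t, x: (a_hat - eye) @ x


def integrate_euler(func, x0, t_end=1.0, step_size=0.1):
    """Fixed-step explicit Euler from t = 0 to t = t_end."""
    x = np.asarray(x0, dtype=float)
    n_steps = int(round(t_end / step_size))
    for k in range(n_steps):
        x = x + step_size * func(k * step_size, x)
    return x


# Hypothetical triangle graph; each row of x0 is one node's feature vector.
adj = np.array([[0, 1, 1],
                [1, 0, 1],
                [1, 1, 0]])
x0 = np.array([[3.0], [0.0], [0.0]])
x1 = integrate_euler(diffusion_rhs(row_normalize(adj)), x0, t_end=1.0, step_size=0.1)
print(x1.ravel())  # each entry is now closer to the mean 1.0 than in x0
```

After integrating to t = 1 the node features have moved toward their common mean while the mean itself is unchanged (the triangle graph gives a symmetric `A_hat`). In a GAT-based right-hand side, the fixed averaging weights would typically be replaced by learned attention scores.
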

LongLiveForFreedom commented 8 months ago

Thank you, sir, this helps me a lot!