ant-research / Spatio-Temporal-Hypergraph-Model

Apache License 2.0

Error during training #7

Open ouyangqun5525 opened 2 months ago

ouyangqun5525 commented 2 months ago

In `sthgcn.py` there is the following code:

```python
edge_attr_embed, edge_type_embed = None, None
if data['edge_type'][0] is not None:
    if self.generate_edge_attr:
        edge_attr_embed = self.edge_attr_embedding_layer(data['edge_type'][0])
        edge_type_embed = self.edge_type_embedding_layer(data['edge_type'][0])

x_for_time_filter = self.conv_for_time_filter(
    x,
    edge_index=data['edge_index'][0],
    edge_attr_embed=edge_attr_embed,
    edge_time_embed=edge_time_embed,
    edge_dist_embed=edge_distance_embed,
    edge_type_embed=edge_type_embed
)
```
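In the snippet, both embeddings start as `None` and are only assigned when the conditions hold. Whether `None` is actually a problem depends on how `conv_for_time_filter` consumes these arguments, which the issue does not show. A common pattern is for a layer to treat optional embeddings as absent and combine only the ones provided. The sketch below illustrates that pattern with plain Python lists; `combine_edge_embeddings` is a hypothetical helper, not a function from `layer/conv.py`.

```python
from typing import List, Optional

Vector = List[float]


def combine_edge_embeddings(*embeds: Optional[Vector]) -> Optional[Vector]:
    """Element-wise sum of whichever optional edge embeddings are present.

    Hypothetical helper: None inputs are skipped instead of being treated
    as errors, and None is returned only if every input is None.
    """
    present = [e for e in embeds if e is not None]
    if not present:
        return None
    dim = len(present[0])
    # Embeddings that are present must agree on their dimension.
    if any(len(e) != dim for e in present):
        raise ValueError("edge embeddings must share the same dimension")
    return [sum(vals) for vals in zip(*present)]
```

For example, `combine_edge_embeddings(None, [1.0, 2.0], None, [0.5, 0.5])` yields `[1.5, 2.5]`, so a missing attribute or type embedding simply contributes nothing.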

Because `edge_attr_embed` and `edge_type_embed` are never assigned a value, passing them into the subsequent hypergraph convolution raises an error caused by a dimension mismatch. I'm not sure whether this comes from my own configuration and would appreciate an explanation. The error is:

```
Traceback (most recent call last):
  File "D:\project\pycharm\Spatio-Temporal-Hypergraph-Model\run.py", line 220, in <module>
    out, loss = model(input_data, label=data.y[:, 0])
  File "D:\Program Files\anconda\envs\dp\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\project\pycharm\Spatio-Temporal-Hypergraph-Model\model\sthgcn.py", line 202, in forward
    x_for_time_filter = self.conv_for_time_filter(
  File "D:\Program Files\anconda\envs\dp\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\project\pycharm\Spatio-Temporal-Hypergraph-Model\layer\conv.py", line 148, in forward
    out = self.propagate(
  File "C:\Users\ouyq\AppData\Local\Temp\layer.conv_HypergraphTransformer_propagate_swkuera1.py", line 187, in propagate
    out = self.message(
  File "D:\project\pycharm\Spatio-Temporal-Hypergraph-Model\layer\conv.py", line 251, in message
    alpha = softmax(alpha, index, ptr, size_i)
  File "D:\Program Files\anconda\envs\dp\lib\site-packages\torch_geometric\utils\_softmax.py", line 68, in softmax
    out = (src - src_max).exp()
RuntimeError: The size of tensor a (1111) must match the size of tensor b (0) at non-singleton dimension 0

Process finished with exit code 1
```
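The last frame is a grouped softmax over attention scores. As a rough illustration only (plain Python, not the `torch_geometric.utils.softmax` implementation), such a softmax normalises each score against the maximum and sum of its own target group, which requires exactly one group index per score. In the traceback, the score tensor has 1111 entries while the tensor subtracted from it has 0, which is consistent with (though not proof of) the index or size information reaching `softmax` being empty or inconsistent. The sketch makes that invariant explicit; `segment_softmax` is a hypothetical name.

```python
import math
from typing import List


def segment_softmax(src: List[float], index: List[int], num_groups: int) -> List[float]:
    """Softmax of `src` computed separately within each group given by `index`."""
    # One group index per score; a mismatch here is the kind of error in the trace.
    if len(src) != len(index):
        raise ValueError(f"src has {len(src)} entries but index has {len(index)}")
    # Per-group maximum, subtracted for numerical stability.
    group_max = [-math.inf] * num_groups
    for v, g in zip(src, index):
        group_max[g] = max(group_max[g], v)
    exps = [math.exp(v - group_max[g]) for v, g in zip(src, index)]
    # Per-group normaliser.
    group_sum = [0.0] * num_groups
    for e, g in zip(exps, index):
        group_sum[g] += e
    return [e / group_sum[g] for e, g in zip(exps, index)]
```

For example, `segment_softmax([1.0, 1.0, 2.0], [0, 0, 1], 2)` gives `[0.5, 0.5, 1.0]`: the two scores in group 0 share the probability mass, and the single score in group 1 gets all of it.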

ouyangqun5525 commented 2 months ago

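Solved: all package versions need to match the setup used in the original work; otherwise the hypergraph convolution module is prone to errors.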
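Since the fix was aligning package versions, a quick way to spot drift is to compare installed versions against a list of pins. This is a minimal sketch using only the standard library's `importlib.metadata`. The helper name `find_version_mismatches` is hypothetical, and the `"X.Y.Z"` pins are placeholders, not the versions this repository actually requires; fill them in from the project's own setup.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Tuple


def find_version_mismatches(pinned: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {package: (expected, installed)} for every package that differs.

    `installed` is None when the package is not installed. The comparison is an
    exact string match, so local build tags such as "+cu118" count as a mismatch.
    """
    mismatches: Dict[str, Tuple[str, Optional[str]]] = {}
    for name, expected in pinned.items():
        try:
            installed: Optional[str] = version(name)
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    # Placeholder pins (hypothetical): replace with the repository's listed versions.
    pinned = {"torch": "X.Y.Z", "torch_geometric": "X.Y.Z"}
    for pkg, (want, have) in find_version_mismatches(pinned).items():
        print(f"{pkg}: expected {want}, found {have}")
```

An empty result means every pinned package matches exactly.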