0%| | 0/10 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-18-4b79db204498> in <cell line: 6>()
4 all_viz_set = get_viz_idx(test_set, dataset_name, num_viz_samples)
5
----> 6 visualize_results(gsat, all_viz_set, test_set, num_viz_samples, dataset_name, model_config['use_edge_attr'])
1 frames
/usr/local/lib/python3.10/dist-packages/torch_geometric/utils/subgraph.py in subgraph(subset, edge_index, edge_attr, relabel_nodes, num_nodes, return_edge_mask)
97 edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
98 edge_index = edge_index[:, edge_mask].to('cpu')
---> 99 edge_attr = edge_attr[edge_mask] if edge_attr is not None else None
100
101 if relabel_nodes:
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
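The message says the indexed tensor (`edge_attr`) is on the CPU while the indices (`edge_mask`) are not. The mask was therefore most likely built on the GPU from a CUDA `subset` or `edge_index`. In the frame shown, line 98 moves `edge_index` to the CPU but leaves `edge_mask` on its original device. Below is a minimal sketch of the same masking step with the devices aligned. It assumes plain PyTorch tensors, and `mask_edges` is a hypothetical helper for illustration, not part of `torch_geometric`:

```python
import torch


def mask_edges(subset, edge_index, edge_attr=None, num_nodes=None):
    """Keep only edges whose two endpoints are both in `subset`.

    Hypothetical helper mirroring the masking lines in the traceback, except
    that every tensor is first moved to edge_index's device so boolean
    indexing never mixes CPU and GPU tensors.
    """
    device = edge_index.device
    subset = subset.to(device)

    if subset.dtype == torch.bool:
        # `subset` is already a per-node boolean mask.
        node_mask = subset
    else:
        # `subset` holds node indices: turn it into a boolean node mask.
        if num_nodes is None:
            candidates = [t.max().item() for t in (edge_index, subset) if t.numel() > 0]
            num_nodes = int(max(candidates)) + 1 if candidates else 0
        node_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        node_mask[subset] = True

    # An edge survives only if both of its endpoints are kept.
    edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
    edge_index = edge_index[:, edge_mask]

    if edge_attr is not None:
        # The fix for the RuntimeError: put edge_attr on the mask's device first.
        edge_attr = edge_attr.to(edge_mask.device)[edge_mask]
    return edge_index, edge_attr


# Small CPU example: a 4-cycle 0->1->2->3->0, keeping nodes 0, 1, 2.
ei = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
ea = torch.tensor([10.0, 20.0, 30.0, 40.0])
new_ei, new_ea = mask_edges(torch.tensor([0, 1, 2]), ei, ea)
print(new_ei.tolist(), new_ea.tolist())
```

In the notebook itself, the analogous change would be to make sure every tensor that reaches `subgraph` is on one device, for example by moving the data (or just `edge_attr`) with `.to(...)` before it is passed on. Where exactly to do that depends on the body of `visualize_results`, which is not shown here.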
When I run the last