JinmiaoChenLab / SEDR

MIT License

How to save graph_dict and load it again next time #6

Closed: yeswzc closed this issue 10 months ago

yeswzc commented 10 months ago

Hi, is there a way to save the graph_dict in Python so that it can be loaded back into Python next time? I tried pickle with pickle.HIGHEST_PROTOCOL, but the saved file can't be read.

Thank you!

Xuhang01 commented 10 months ago

Hi, I have a solution, but maybe not the best one.

```python
import copy

import torchsnapshot
from tensordict import TensorDict

# save
tmp = copy.deepcopy(graph_dict)
# convert the sparse adjacency matrices to dense tensors before saving
tmp['adj_norm'] = tmp['adj_norm'].to_dense()
tmp['adj_label'] = tmp['adj_label'].to_dense()
d = TensorDict(tmp, [])
state = {'state': d}
snapshot = torchsnapshot.Snapshot.take(app_state=state, path="snapshot")

# restore
snapshot = torchsnapshot.Snapshot(path="snapshot")
graph_dict_r = TensorDict({}, [])
state_target = {"state": graph_dict_r}
snapshot.restore(app_state=state_target)
# sanity check; `d` only exists if save and restore run in the same session
assert (graph_dict_r == d).all()

# convert to SEDR input: plain dict with sparse adjacency matrices again
graph_dict_r = graph_dict_r.to_dict()
graph_dict_r['adj_norm'] = graph_dict_r['adj_norm'].to_sparse()
graph_dict_r['adj_label'] = graph_dict_r['adj_label'].to_sparse()
```
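A lighter alternative, not the approach from this thread and untested against a real SEDR graph_dict, is to round-trip the dict with plain `torch.save` / `torch.load`, which can serialize sparse COO tensors directly, so no densify/re-sparsify step is needed. It may or may not avoid whatever made plain pickle fail in the original question. The toy `graph_dict` below is hypothetical: only the key names `adj_norm` and `adj_label` come from this thread, and the 3-node graph and its values are made up for illustration.

```python
import os
import tempfile

import torch

# Hypothetical toy stand-in for SEDR's graph_dict: a 3-node graph whose
# adjacency matrices are stored as sparse COO tensors.
indices = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
values = torch.tensor([0.5, 0.5, 0.5, 0.5])
adj = torch.sparse_coo_tensor(indices, values, (3, 3)).coalesce()
graph_dict = {"adj_norm": adj, "adj_label": adj.clone()}

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "graph_dict.pt")
    # save: the sparse tensors are written as-is, no to_dense() needed
    torch.save(graph_dict, path)
    # restore: the adjacency matrices come back as sparse tensors
    restored = torch.load(path)

print(restored["adj_norm"].is_sparse)
print(torch.equal(restored["adj_norm"].to_dense(), graph_dict["adj_norm"].to_dense()))
```

Keeping the matrices sparse on disk also avoids materializing a full n × n dense adjacency matrix during saving. Before relying on this for real data, compare the reloaded tensors with the originals as in the last line above.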

yeswzc commented 10 months ago

Thank you! Works well!