pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Overwrite of loaded weights on the HeteroDictLinear layer #9472

Open ver228 opened 3 days ago

ver228 commented 3 days ago

🐛 Describe the bug

Hey, there is a bug where the weights of a lazily initialised HeteroDictLinear layer get overwritten on the first forward pass after a state dictionary has been loaded. Here is a code snippet that reproduces the error.

import torch
import torch.nn.functional as F
from torch_geometric.nn import HeteroDictLinear, HeteroConv, GraphSAGE
from typing import Any

def _check_parameters_match(state_dict_v1: dict[str, Any], state_dict_v2: dict[str, Any]) -> list[str]:
    # Return the names of parameters whose values differ between the two state dicts
    wrong_matches = []
    for name, param in state_dict_v1.items():
        # Compare each entry of the first state dict against the second
        if not torch.isclose(state_dict_v2[name], param).all().item():
            wrong_matches.append(name)

    if wrong_matches:
        print(f"The following parameters do not match: {wrong_matches}")
    else:
        print("All weights match!")
    return wrong_matches

class SimpleGCN(torch.nn.Module):
    def __init__(self, nodes, edges, hidden_channels=20):
        super().__init__()
        self.fc = HeteroDictLinear({k: -1 for k in nodes}, hidden_channels)
        self.conv = HeteroConv(
            {k: GraphSAGE((-1, -1), hidden_channels, num_layers=1) for k in edges}
        )

    def forward(self, x, edge_index):
        x = self.fc(x)
        x = self.conv(x, edge_index)
        return x

nodes = ["A", "B"]
edges = [("A", "to", "B"), ("B", "rev_to", "A")]
# Example data (usually you would use your dataset here)
edge_index = {k: torch.tensor([[0, 1, 2, 3],
                               [1, 0, 3, 2]], dtype=torch.long) for k in edges}
x = {n: torch.randn((4, 10)) for n in nodes}  # 4 nodes with 10 features each, per node type

# Create the model, then pass example data to initialize the lazy layers
model_v1 = SimpleGCN(nodes, edges)
model_v1(x, edge_index)

# deep copy of pretrained weights
presaved_weights = {k: v.clone() for k, v in model_v1.state_dict().items()}

# load then eval
model_v2 = SimpleGCN(nodes, edges)
model_v2.load_state_dict(presaved_weights)
_check_parameters_match(model_v2.state_dict(), presaved_weights) # Weights are loaded correctly!

model_v2(x, edge_index)  # Pass example data to initialize lazy layers
_check_parameters_match(model_v2.state_dict(), presaved_weights) # The weights of the HeteroDictLinear layer are overwritten!

The problem is in this block. reset_parameters should be called individually for each uninitialized layer, or be guarded by a sentinel flag that records whether any layer is still uninitialized. I am happy to submit a PR if you agree.
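
To make the proposal concrete, here is a minimal pure-Python sketch of the pattern. It does not use PyG or PyTorch: ToyLazyLayer, ToyDictLayer and the reset_only_uninitialized flag are hypothetical names made up for illustration, and the actual HeteroDictLinear code may be structured differently. It only shows why resetting every layer on the first forward pass discards loaded weights, and how resetting only the uninitialized layers avoids that.

```python
import random


class ToyLazyLayer:
    """Hypothetical stand-in for a lazily initialized linear layer (not PyG)."""

    def __init__(self):
        self.weight = None  # None means "not initialized yet"

    def is_uninitialized(self):
        return self.weight is None

    def reset_parameters(self, in_features, rng):
        # Random values in [0, 1), one per input feature
        self.weight = [rng.random() for _ in range(in_features)]


class ToyDictLayer:
    """Hypothetical container: one lazy layer per key, set up on first forward."""

    def __init__(self, keys, reset_only_uninitialized):
        self.layers = {k: ToyLazyLayer() for k in keys}
        self.reset_only_uninitialized = reset_only_uninitialized
        self._pending_init = True
        self._rng = random.Random(0)

    def load_state(self, state):
        # Loading materializes the weights of the layers present in `state`
        for key, weight in state.items():
            self.layers[key].weight = list(weight)

    def state(self):
        return {k: list(layer.weight) for k, layer in self.layers.items()}

    def forward(self, x):
        if self._pending_init:
            for key, feats in x.items():
                layer = self.layers[key]
                if self.reset_only_uninitialized and not layer.is_uninitialized():
                    continue  # proposed behaviour: keep weights that were loaded
                layer.reset_parameters(len(feats), self._rng)
            self._pending_init = False
        return {
            k: sum(w * v for w, v in zip(self.layers[k].weight, feats))
            for k, feats in x.items()
        }


loaded = {"A": [1.0, 2.0], "B": [3.0, 4.0]}
x = {"A": [0.5, 0.5], "B": [1.0, 1.0]}

# Reported behaviour: every layer is reset on the first forward pass
buggy = ToyDictLayer(["A", "B"], reset_only_uninitialized=False)
buggy.load_state(loaded)
buggy.forward(x)
buggy_kept = buggy.state() == loaded

# Proposed behaviour: only layers that are still uninitialized are reset
fixed = ToyDictLayer(["A", "B"], reset_only_uninitialized=True)
fixed.load_state(loaded)
fixed.forward(x)
fixed_kept = fixed.state() == loaded

# Partially loaded case: "A" keeps its weights, "B" is initialized lazily
partial = ToyDictLayer(["A", "B"], reset_only_uninitialized=True)
partial.load_state({"A": [1.0, 2.0]})
partial.forward(x)

print("weights kept with reset-all:", buggy_kept)
print("weights kept with per-layer reset:", fixed_kept)
```

The per-layer check also covers the partially loaded case, which a single module-wide sentinel flag would have to handle separately.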

Versions

PyTorch version: 2.3.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.11.4 (main, Jun 24 2024, 14:34:57) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-152-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 525.147.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Vendor ID: AuthenticAMD
Model name: AMD Ryzen Threadripper PRO 5975WX 32-Cores
CPU family: 25
Model: 8
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 1
Stepping: 2
Frequency boost: enabled
CPU max MHz: 3600.0000
CPU min MHz: 1800.0000
BogoMIPS: 7186.27
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca
Virtualization: AMD-V
L1d cache: 1 MiB (32 instances)
L1i cache: 1 MiB (32 instances)
L2 cache: 16 MiB (32 instances)
L3 cache: 128 MiB (4 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-63
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.2.5
[pip3] torch==2.3.0+cu118
[pip3] torch_geometric==2.5.3
[pip3] torchmetrics==1.4.0.post0
[pip3] triton==2.3.0
[conda] Could not collect