kmeng01 / rome

Locating and editing factual associations in GPT (NeurIPS 2022)
https://rome.baulab.info
MIT License

Fixed bug causing disabling edits #44

Open keltin13 opened 6 months ago

keltin13 commented 6 months ago

ROME Bugfix Analysis

Fix Summary

Looking at the definition of $\Lambda$, the existing code calculates the green $k*$ in the `compute_u` function using the subject representation averaged over a set of random prefixes. The red $k*$'s are calculated in `compute_v` and are not averaged over that set of random prefixes.

[Figure: rome-equation-1, the definition of $\Lambda$ with its $k*$ terms colored green and red]

The fix uses the green $k*$ in both places in the denominator, which is also what the MEMIT code does. 'Rebuilding-ROME' (https://arxiv.org/abs/2403.07175), which re-implemented ROME on top of the MEMIT code, does this as well, although its authors were unable to pinpoint the exact issue.
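For readers without the color-coded image, the change can be written out as follows. The first line is the rank-one update from the ROME paper. The symbols $\bar{k}$ (the prefix-averaged "green" key) and $\tilde{k}$ (the non-averaged "red" key) are notation introduced here for illustration, and the assignment of each key to each position is inferred from the description above, so treat this as a sketch rather than a transcription of the figure.

```latex
% ROME update as defined in the paper (a single key k_* throughout):
\hat{W} = W + \Lambda \left(C^{-1} k_*\right)^{T},
\qquad
\Lambda = \frac{v_* - W k_*}{\left(C^{-1} k_*\right)^{T} k_*}

% Before the fix: the two keys in the denominator are mismatched.
\Lambda_{\text{before}} = \frac{v_* - W \tilde{k}}{\left(C^{-1} \bar{k}\right)^{T} \tilde{k}}

% After the fix: the averaged key is used in both denominator positions.
\Lambda_{\text{after}} = \frac{v_* - W \tilde{k}}{\left(C^{-1} \bar{k}\right)^{T} \bar{k}}
```

If $C^{-1}$ is positive definite, the fixed denominator $\bar{k}^{T} C^{-1} \bar{k}$ is a quadratic form and is strictly positive for any nonzero $\bar{k}$, whereas the mismatched denominator has no such lower bound.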

When the $k*$'s in the denominator are mismatched, on some edits the denominator (the 'Division Factor') becomes very small, causing the norm of the update to be huge and resulting in model collapse. I observed these 'disabling edits' independently, and they have also been reported multiple times in the literature.
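The effect of a small division factor can be seen in a toy example. This is a minimal sketch in plain Python, not the repository's code: the 2x2 matrix `c_inv`, the keys `k_avg` and `k_single`, the `residual`, and all helper names are hypothetical values chosen so that the non-averaged key is nearly orthogonal to $C^{-1}$ times the averaged key.

```python
import math

def matvec(m, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]

def dot(x, y):
    return sum(a * b for a, b in zip(x, y))

def norm(x):
    return math.sqrt(dot(x, x))

def division_factor_and_update_norm(c_inv, k_avg, k_denom, residual):
    """Return (division factor, Frobenius norm of the rank-one update).

    u is C^{-1} k_avg, the division factor is u . k_denom, and the update
    is (residual / division factor) u^T, whose Frobenius norm is the
    product of the two vector norms.
    """
    u = matvec(c_inv, k_avg)
    factor = dot(u, k_denom)
    lam = [r / factor for r in residual]
    return factor, norm(lam) * norm(u)

# Hypothetical toy values (assumptions, not taken from the model).
c_inv = [[2.0, 0.0], [0.0, 1.0]]   # positive-definite stand-in for C^{-1}
k_avg = [1.0, 0.0]                 # prefix-averaged key
k_single = [0.01, 1.0]             # non-averaged key, nearly orthogonal to u
residual = [1.0, 1.0]              # stand-in for v_* - W k

# Mismatched keys in the denominator (behavior before the fix).
mismatched_factor, mismatched_norm = division_factor_and_update_norm(
    c_inv, k_avg, k_single, residual)

# The same averaged key in both denominator positions (after the fix).
matched_factor, matched_norm = division_factor_and_update_norm(
    c_inv, k_avg, k_avg, residual)

print("mismatched:", mismatched_factor, mismatched_norm)
print("matched:   ", matched_factor, matched_norm)
```

The residual is held fixed in both calls to isolate the denominator; in this toy setup the mismatched division factor is 100 times smaller, so the update norm is 100 times larger.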

It is possible that using the green context-averaged $k*$ in all three locations would be optimal, but it appears that was not your intention since it is not done in the MEMIT code.

Update Norms

Here are the norms of the weight updates before and after the fix on the first 2000 samples of the CounterFact dataset:

[Figure: update_norm_comparison]

As well as the division factors:

[Figure: division_factor_comparison]

CounterFact Benchmark Results

"From Paper" are the results from the ROME paper, "Base Impl." are the results from my tests on the first 2000 samples before the fix, and "+Bugfix" are the results from my tests on the first 2000 samples after the fix. Performance is very slightly down across the board; I suspect this could be improved with another round of hyperparameter tuning.

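| | S | ES | EM | PS | PM | NS | NM | GE | RS |
|---|---|---|---|---|---|---|---|---|---|
| Base Impl. | 88.98 | 100.0 | 97.85 | 96.4 | 61.84 | 74.5 | 4.22 | 626.72 | 42.04 |
| +Bugfix | 88.57 | 98.8 | 93.39 | 95.45 | 59.19 | 75.33 | 4.54 | 626.27 | 41.73 |
| From Paper | 89.2 | 100.0 | 97.9 | 96.4 | 62.7 | 75.4 | 4.2 | 621.9 | 41.9 |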

If we focus only on the 'disabling edits' (the red points above), we see large improvements in all categories except Efficacy Score/Magnitude:

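| | S | ES | EM | PS | PM | NS | NM | GE | RS |
|---|---|---|---|---|---|---|---|---|---|
| Base Impl. | 10.41 | 100.0 | 30.71 | 87.5 | 27.42 | 3.75 | -28.63 | 571.08 | 19.3 |
| +Bugfix | 61.17 | 50.0 | 0.16 | 100.0 | 53.68 | 52.5 | 1.98 | 631.42 | 51.55 |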
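The S column appears to follow the ROME paper's Score, the harmonic mean of ES, PS, and NS. As a quick check under that assumption, the sketch below recomputes S for the two 'disabling edits' rows from their ES, PS, and NS values; the function name `score` is introduced here for illustration.

```python
from statistics import harmonic_mean

def score(es, ps, ns):
    """Harmonic mean of Efficacy, Paraphrase, and Neighborhood Scores."""
    return harmonic_mean([es, ps, ns])

# ES, PS, NS values copied from the 'disabling edits' rows.
base_score = score(100.0, 87.5, 3.75)
bugfix_score = score(50.0, 100.0, 52.5)

print(f"Base Impl. S = {base_score:.2f}")
print(f"+Bugfix    S = {bugfix_score:.2f}")
```

Because the harmonic mean is dominated by its smallest term, the very low NS before the fix pulls S down sharply, and the recovery in NS after the fix accounts for most of the gain in S despite the drop in ES.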

Selected Examples

CounterFact case_173: "Chicago is a twin city of (Warsaw -> Istanbul)"

Before fix:

After fix:

CounterFact case_946: "Bay, which is located in (Philippines -> Italy)"

Before fix:

After fix: