ecboghiu / inflation

Implementations of the Inflation Technique for Causal Inference.
GNU General Public License v3.0

KeyError on scenarios with more than two outcomes per party #59

Closed: apozas closed this issue 2 years ago

apozas commented 2 years ago

Running the following code:

from causalinflation import InflationProblem, InflationSDP
sdp = InflationSDP(InflationProblem({"Lambda": ["A", "B"]},
                                 outcomes_per_party=[3, 2],
                                 settings_per_party=[2, 2],
                                 inflation_level_per_source=[2]))
sdp.generate_relaxation('npa2')

produces the following error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Input In [22], in <cell line: 1>()
----> 1 sdp.generate_relaxation('npa2')

File ~/Documents/GitHub/inflation/causalinflation/quantum/InflationSDP.py:448, in InflationSDP.generate_relaxation(self, column_specification)
    443 self.momentmatrix, self.orbits, self.symidx_to_sym_monarray_dict \
    444     = self._apply_inflation_symmetries(self.unsymmetrized_mm_idxs,
    445                                        self.unsymidx_to_unsym_monarray_dict,
    446                                        self.inflation_symmetries)
    447 for (k, v) in _unsymidx_from_hash_dict.items():
--> 448     self.canonsym_ndarray_from_hash_cache[k] = self.symidx_to_sym_monarray_dict[self.orbits[v]]
    449 del _unsymidx_from_hash_dict
    451 self.largest_moment_index = max(self.symidx_to_sym_monarray_dict.keys())

KeyError: 0

The error does not occur if I change the number of outcomes to [2, 2], the inflation level to [1], or the NPA level to npa1. It does occur, however, when the number of outcomes is [3, 3].

eliewolfe commented 2 years ago

Upon inspection, _apply_inflation_symmetries was mapping certain moments with a nonzero index to the zero index under the permutation, which should never happen. It turns out that _apply_inflation_symmetries was fine: the real bug was in calculate_momentmatrix, where we were evaluating mon_is_zero BEFORE converting the monomial to canonical form. I fixed it, and also adjusted to_canonical to map all zero-equivalent operators to a canonical zero monomial.
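To illustrate why the order of these two steps matters, here is a minimal sketch in plain Python. It does not use causalinflation's internals: the (party, setting, outcome) operator encoding, the `ZERO` sentinel, and the functions `is_zero` and `to_canonical` are hypothetical stand-ins. It assumes only that operators of different parties commute and that operators of the same party and setting with different outcomes are orthogonal projectors; inflation copies and any other rules the library applies are ignored.

```python
from typing import List, Tuple

# Hypothetical operator encoding for illustration only: (party, setting, outcome).
Op = Tuple[int, int, int]
Monomial = Tuple[Op, ...]

# Hypothetical sentinel standing in for "the canonical zero monomial".
ZERO: Monomial = ((-1, -1, -1),)


def is_zero(mon: Monomial) -> bool:
    """Return True if two orthogonal projectors sit next to each other.

    Only adjacent pairs are inspected, so the result depends on the
    order in which the operators appear.
    """
    for a, b in zip(mon, mon[1:]):
        same_measurement = a[0] == b[0] and a[1] == b[1]
        if same_measurement and a[2] != b[2]:
            return True
    return False


def to_canonical(mon: Monomial) -> Monomial:
    """Bring a monomial to a canonical form, then map zeros to ZERO."""
    # Different parties commute, so a stable sort by party fixes their
    # relative order while preserving the order within each party.
    ordered = sorted(mon, key=lambda op: op[0])
    # Projectors are idempotent: collapse repeated adjacent operators.
    reduced: List[Op] = []
    for op in ordered:
        if not reduced or reduced[-1] != op:
            reduced.append(op)
    result = tuple(reduced)
    # Every zero-equivalent product maps onto one canonical zero monomial.
    return ZERO if is_zero(result) else result


# A_{0|0} B_{0|0} A_{1|0}: Alice's two orthogonal projectors are separated by Bob's.
mon = ((0, 0, 0), (1, 0, 0), (0, 0, 1))
print(is_zero(mon))       # zero check on the raw monomial misses it
print(to_canonical(mon))  # canonicalizing first exposes the zero
```

On the raw monomial, `is_zero(mon)` returns `False` because Bob's operator separates the two orthogonal Alice projectors, while `to_canonical(mon)` moves them next to each other and returns `ZERO`. Checking for zero before canonicalizing would therefore let such a moment receive its own nonzero index, which is consistent with the symptom described above.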