facebookresearch / diplomacy_cicero

Code for Cicero, an AI agent that plays the game of Diplomacy with open-domain natural language negotiation.

Compiling protos problems #16

Open hirokinko opened 1 year ago

hirokinko commented 1 year ago

I ran the unit tests and found two problems with the code generated from the protos:

  1. google.protobuf.message is not imported
  2. KeyError: 'fairdiplomacy.ParlaiFlags' in unit_tests/test_bqre1p_lambdas.py

I was able to solve the first problem by applying a patch to the generation code, but the second problem seems to come from the code added by "heyhi/bin/patch_protos.py", and I haven't solved it yet. Has anyone been able to solve this problem?
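Before patching anything for the first problem, it can help to confirm which generated modules actually reference the `_message` alias without importing it. The sketch below is a hypothetical, standard-library-only check: the helper names `missing_message_import` and `scan_generated` are made up for illustration, and it assumes the conventional protoc import line `from google.protobuf import message as _message` and that the generated files sit under conf/ as `*_pb2.py`.

```python
import re
from pathlib import Path
from typing import List

# Conventional alias used by protoc-generated Python modules (assumed here).
_IMPORT_RE = re.compile(
    r"^\s*from\s+google\.protobuf\s+import\s+message\s+as\s+_message\b",
    re.MULTILINE,
)
# A use of the alias, e.g. "isinstance(msg, _message.Message)".
# The leading \b keeps names like "some_message.x" from matching.
_USE_RE = re.compile(r"\b_message\.")


def missing_message_import(source: str) -> bool:
    """True if `source` uses `_message.` but never imports the `_message` alias."""
    return bool(_USE_RE.search(source)) and not _IMPORT_RE.search(source)


def scan_generated(conf_dir: str = "conf") -> List[str]:
    """Return the generated *_pb2.py files under conf_dir that show problem 1."""
    return sorted(
        str(path)
        for path in Path(conf_dir).glob("*_pb2.py")
        if missing_message_import(path.read_text())
    )


if __name__ == "__main__":
    # Hypothetical usage: run from the repository root after `make`.
    print(scan_generated("conf"))
```

An empty list means every generated module that uses `_message` also imports it, which would point away from the import and toward something else in the generation step.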

The following is the output of running the unit tests:

❯ python -m pytest unit_tests/
========================================================== test session starts ===========================================================
platform linux -- Python 3.7.16, pytest-6.2.1, py-1.11.0, pluggy-0.13.1
rootdir: /home/hirokinko/Workspaces/diplomacy_cicero
plugins: hydra-core-1.1.2, requests-mock-1.10.0, regressions-2.4.1, datadir-1.4.1, anyio-3.6.2
collected 237 items                                                                                                                      

unit_tests/test_ampersand_uncorruption.py .                                                                                        [  0%]
unit_tests/test_bart_classifier.py .                                                                                               [  0%]
unit_tests/test_bqre1p_lambdas.py F                                                                                                [  1%]
unit_tests/test_carriage_returns.py .                                                                                              [  1%]
unit_tests/test_cfrstats_pickle.py ...                                                                                             [  2%]
unit_tests/test_check_permute_powers.py .                                                                                          [  3%]
unit_tests/test_discriminator_teacher_functions.py ................                                                                [ 10%]
unit_tests/test_draw_state_utils.py ..                                                                                             [ 10%]
unit_tests/test_message_editing.py ................                                                                                [ 17%]
unit_tests/test_order_idxs.py ...........                                                                                          [ 22%]
unit_tests/test_parlai_all_holds_filtering.py ....                                                                                 [ 24%]
unit_tests/test_parlai_flattener_to_unflattener.py ..............                                                                  [ 29%]
unit_tests/test_parlai_formatting.py .............................................                                                 [ 48%]
unit_tests/test_parlai_input_validation.py .                                                                                       [ 49%]
unit_tests/test_parlai_input_validation_parts.py ...............................                                                   [ 62%]
unit_tests/test_parlai_message_history_building.py ...                                                                             [ 63%]
unit_tests/test_parlai_message_history_truncation.py .                                                                             [ 64%]
unit_tests/test_parlai_order_filtering.py .                                                                                        [ 64%]
unit_tests/test_parlai_special_tokens.py .                                                                                         [ 64%]
unit_tests/test_pseudo_cache.py .                                                                                                  [ 65%]
unit_tests/test_pydipcc.py ...........................................................                                             [ 90%]
unit_tests/test_resample_duplicate_disbands.py .                                                                                   [ 90%]
unit_tests/test_rollout_spring_ending.py ............                                                                              [ 95%]
unit_tests/test_search_helpers.py .                                                                                                [ 96%]
unit_tests/test_sleep.py ......                                                                                                    [ 98%]
unit_tests/test_timestamp.py .                                                                                                     [ 99%]
unit_tests/test_utils_game.py .                                                                                                    [ 99%]
unit_tests/test_utils_orders.py .                                                                                                  [100%]

================================================================ FAILURES ================================================================
_________________________________________________________ TestBQRE1PLambdas.test _________________________________________________________

self = <unit_tests.test_bqre1p_lambdas.TestBQRE1PLambdas testMethod=test>

    def test(self):
        cfg = heyhi.conf.load_config(
            heyhi.conf.CONF_ROOT / "common/agents/for_tests/bqre1p_20210821_rol0.prototxt",
>           overrides=["bqre1p.base_searchbot_cfg.model_path=MOCKV2",],
        )

unit_tests/test_bqre1p_lambdas.py:19: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
heyhi/conf.py:397: in load_config
    return cfg.to_frozen()
conf/conf_pb2.py:64: in to_frozen
    value = maybe_to_dict(value)
conf/conf_pb2.py:47: in maybe_to_dict
    return msg.to_frozen()
conf/conf_pb2.py:64: in to_frozen
    value = maybe_to_dict(value)
conf/conf_pb2.py:47: in maybe_to_dict
    return msg.to_frozen()
conf/conf_pb2.py:64: in to_frozen
    value = maybe_to_dict(value)
conf/conf_pb2.py:47: in maybe_to_dict
    return msg.to_frozen()
conf/conf_pb2.py:64: in to_frozen
    value = maybe_to_dict(value)
conf/conf_pb2.py:47: in maybe_to_dict
    return msg.to_frozen()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = 

    def to_frozen(self):
        set_fields = frozenset(f[0].name for f in self.ListFields())

        # Mapping from fields within oneof to the name of the oneof.
        field_to_oneof_name = {}
        for oneof in descriptor.oneofs:
            for field in oneof.fields:
                field_to_oneof_name[field.name] = oneof.name

        def maybe_to_dict(msg):
            if isinstance(msg, _message.Message):
                return msg.to_frozen()
            return msg

        ret = {}
        for name, field in name2field.items():
            if name in field_to_oneof_name:
                chosen_oneof = self.WhichOneof(field_to_oneof_name[name])
                if chosen_oneof is None:
                    value = None
                elif name != chosen_oneof and name != field_to_oneof_name[name]:
                    value = None
                else:
                    value = getattr(self, chosen_oneof)
            else:
                value = getattr(self, name)

            if isinstance(value, _message.Message):
                value = maybe_to_dict(value)
            elif field.label == field.LABEL_REPEATED:
                if type(value).__name__.split(".")[-1] == "ScalarMapContainer":
                    value = {x: maybe_to_dict(value[x]) for x in value}
                else:
                    value = tuple(maybe_to_dict(x) for x in value)
            else:
                # A scalar.
                assert (
                    isinstance(value, (float, str, int, bool)) or value is None
                ), f"Excepted a value for {name} to be a scalar. Got {value}"
                if field.enum_type is not None and value is not None:
                    enum_values = field.enum_type.values_by_number
                    if value not in enum_values:
                        raise RuntimeError(
                            f"{name}: {value} not in {[v.name for v in enum_values.values()]}"
                        )
                    value = enum_values[value].name
                if name not in set_fields and not field.has_default_value:
                    value = False if field.type == field.TYPE_BOOL else None
            ret[name] = value
        for oneof in descriptor.oneofs:
            chosen = self.WhichOneof(oneof.name)
            ret["which_" + oneof.name] = chosen
            ret[oneof.name] = ret[chosen] if chosen is not None else None
>       return FROZEN_SYM_BD[descriptor.full_name](ret, self)
E       KeyError: 'fairdiplomacy.ParlaiFlags'

conf/conf_pb2.py:89: KeyError
============================================================ warnings summary ============================================================
fairdiplomacy/models/state_space.py:183
  /home/hirokinko/Workspaces/diplomacy_cicero/fairdiplomacy/models/state_space.py:183: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
  Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
    adjacencies = np.zeros((len(locs), len(locs)), dtype=np.bool)  # type: ignore

unit_tests/test_parlai_formatting.py::TestOrderFormatting::test_task_token
  /home/hirokinko/Workspaces/diplomacy_cicero/unit_tests/test_parlai_formatting.py:295: DeprecationWarning: Please use assertEqual instead.
    "units: Austria: A BUD, A MUN, A RUM, F TRI; England: A BEL, F BRE, F ENG, F IRI, F NWY; France: A BUR, A GAS, A PAR, F POR; Germany: *A MUN, A BER, A HOL, F DEN, F SWE; Italy: A TYR, A VEN, F ADR, F TUN; Russia: A UKR, A WAR, F BAL, F SEV; Turkey: A ARM, A GRE, A SER, F AEG, F BLA [EO_STATE] F1902R England 1 1 order:",

unit_tests/test_parlai_special_tokens.py::TestSpecialTokens::test_special_tokens
  /home/hirokinko/Workspaces/diplomacy_cicero/unit_tests/test_parlai_special_tokens.py:21: DeprecationWarning: Please use assertEqual instead.
    self.assertEquals(v1_toks, st.SPECIAL_TOKENS_V1)

unit_tests/test_parlai_special_tokens.py::TestSpecialTokens::test_special_tokens
  /home/hirokinko/Workspaces/diplomacy_cicero/unit_tests/test_parlai_special_tokens.py:23: DeprecationWarning: Please use assertEqual instead.
    self.assertEquals(v1_toks, sorted(v1_toks, key=len, reverse=True))

unit_tests/test_parlai_special_tokens.py::TestSpecialTokens::test_special_tokens
  /home/hirokinko/Workspaces/diplomacy_cicero/unit_tests/test_parlai_special_tokens.py:138: DeprecationWarning: Please use assertEqual instead.
    "->",

unit_tests/test_parlai_special_tokens.py::TestSpecialTokens::test_special_tokens
  /home/hirokinko/Workspaces/diplomacy_cicero/unit_tests/test_parlai_special_tokens.py:142: DeprecationWarning: Please use assertEqual instead.
    self.assertEquals(v2_toks, sorted(v2_toks, key=len, reverse=True))

-- Docs: https://docs.pytest.org/en/stable/warnings.html
======================================================== short test summary info =========================================================
FAILED unit_tests/test_bqre1p_lambdas.py::TestBQRE1PLambdas::test - KeyError: 'fairdiplomacy.ParlaiFlags'
=============================================== 1 failed, 236 passed, 6 warnings in 8.47s ================================================
c-flaherty commented 1 year ago

Did you run make first? See this comment for info on FROZEN_SYM_BD: https://github.com/facebookresearch/diplomacy_cicero/issues/20. It's defined by autogenerated code that make produces.

You shouldn't have to add a google.protobuf.message import anywhere. I tested it, and all the unit tests pass. Could you confirm that make runs successfully? If this is still an issue after that, we can revisit it.
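To see why a missing generated file or a skipped make step shows up as exactly `KeyError: 'fairdiplomacy.ParlaiFlags'`: the generated `to_frozen()` in the traceback ends with `FROZEN_SYM_BD[descriptor.full_name](ret, self)`, so every message type it recurses into must already have an entry in that dict. The sketch below is a simplified, hypothetical stand-in for that registry. It is not the repository's `create_frozen_class` (the generated code passes `_sym_db.GetSymbol(...)` and a name; this version takes a list of field names instead), and the `model_path` field is purely illustrative.

```python
from collections import namedtuple
from typing import Any, Callable, Dict, List

# Registry keyed by protobuf full name, e.g. "fairdiplomacy.ParlaiFlags".
FROZEN_SYM_BD: Dict[str, Callable[[dict, Any], Any]] = {}


def create_frozen_class(short_name: str, field_names: List[str]) -> Callable[[dict, Any], Any]:
    """Hypothetical stand-in: returns a factory for an immutable record type."""
    # namedtuple type names cannot contain dots (e.g. "BQRE1PAgent.PlayerTypes").
    record = namedtuple(short_name.replace(".", "_"), field_names)

    def build(values: dict, msg: Any) -> Any:
        # The generated code also passes the source message; this sketch ignores it.
        return record(**values)

    return build


def to_frozen(full_name: str, values: dict) -> Any:
    # Same shape as the last line of the generated to_frozen(): a plain dict
    # lookup, which raises KeyError if no registration line ran for this type.
    return FROZEN_SYM_BD[full_name](values, None)


# The generated file is expected to contain one registration per message type.
FROZEN_SYM_BD["fairdiplomacy.ParlaiModel"] = create_frozen_class("ParlaiModel", ["model_path"])

print(to_frozen("fairdiplomacy.ParlaiModel", {"model_path": "MOCKV2"}))
# → ParlaiModel(model_path='MOCKV2')
try:
    to_frozen("fairdiplomacy.ParlaiFlags", {})  # never registered in this sketch
except KeyError as err:
    print("KeyError:", err)  # → KeyError: 'fairdiplomacy.ParlaiFlags'
```

The design point is that the lookup is by string name at runtime, so nothing fails at import time; the error only appears once a config that contains the unregistered message is frozen, which matches the single failing test.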

lightvector commented 1 year ago

Here is where the protobuf message for ParlaiFlags is defined: https://github.com/facebookresearch/diplomacy_cicero/blob/main/conf/agents.proto#L692

In theory, if everything is working, make should compile this into various definitions in the autogenerated file conf/agents_pb2.py, such as this one, which should be near the bottom of that file:

FROZEN_SYM_BD['fairdiplomacy.ParlaiFlags'] = create_frozen_class(_sym_db.GetSymbol('fairdiplomacy.ParlaiFlags'), 'ParlaiFlags')
FrozenParlaiFlags = FROZEN_SYM_BD['fairdiplomacy.ParlaiFlags']

among a lot of other lines adding entries to FROZEN_SYM_BD in that same file:


...
FROZEN_SYM_BD['fairdiplomacy.ParlaiNoPressAgent'] = create_frozen_class(_sym_db.GetSymbol('fairdiplomacy.ParlaiNoPressAgent'), 'ParlaiNoPressAgent')
FrozenParlaiNoPressAgent = FROZEN_SYM_BD['fairdiplomacy.ParlaiNoPressAgent']
FROZEN_SYM_BD['fairdiplomacy.ParlaiNonsenseDetectionEnsemble'] = create_frozen_class(_sym_db.GetSymbol('fairdiplomacy.ParlaiNonsenseDetectionEnsemble'), 'ParlaiNonsenseDetectionEnsemble')
FrozenParlaiNonsenseDetectionEnsemble = FROZEN_SYM_BD['fairdiplomacy.ParlaiNonsenseDetectionEnsemble']
FROZEN_SYM_BD['fairdiplomacy.NonsenseClassifier'] = create_frozen_class(_sym_db.GetSymbol('fairdiplomacy.NonsenseClassifier'), 'NonsenseClassifier')
FrozenNonsenseClassifier = FROZEN_SYM_BD['fairdiplomacy.NonsenseClassifier']
FROZEN_SYM_BD['fairdiplomacy.ParlaiDiscriminativeNucleusModel'] = create_frozen_class(_sym_db.GetSymbol('fairdiplomacy.ParlaiDiscriminativeNucleusModel'), 'ParlaiDiscriminativeNucleusModel')
FrozenParlaiDiscriminativeNucleusModel = FROZEN_SYM_BD['fairdiplomacy.ParlaiDiscriminativeNucleusModel']
FROZEN_SYM_BD['fairdiplomacy.ParlaiModel'] = create_frozen_class(_sym_db.GetSymbol('fairdiplomacy.ParlaiModel'), 'ParlaiModel')
FrozenParlaiModel = FROZEN_SYM_BD['fairdiplomacy.ParlaiModel']
FROZEN_SYM_BD['fairdiplomacy.ParlaiFlags'] = create_frozen_class(_sym_db.GetSymbol('fairdiplomacy.ParlaiFlags'), 'ParlaiFlags')
FrozenParlaiFlags = FROZEN_SYM_BD['fairdiplomacy.ParlaiFlags']
FROZEN_SYM_BD['fairdiplomacy.ReproAgent'] = create_frozen_class(_sym_db.GetSymbol('fairdiplomacy.ReproAgent'), 'ReproAgent')
FrozenReproAgent = FROZEN_SYM_BD['fairdiplomacy.ReproAgent']
FROZEN_SYM_BD['fairdiplomacy.BRSearchAgent'] = create_frozen_class(_sym_db.GetSymbol('fairdiplomacy.BRSearchAgent'), 'BRSearchAgent')
FrozenBRSearchAgent = FROZEN_SYM_BD['fairdiplomacy.BRSearchAgent']
FROZEN_SYM_BD['fairdiplomacy.TheBestAgent'] = create_frozen_class(_sym_db.GetSymbol('fairdiplomacy.TheBestAgent'), 'TheBestAgent')
FrozenTheBestAgent = FROZEN_SYM_BD['fairdiplomacy.TheBestAgent']
FROZEN_SYM_BD['fairdiplomacy.BQRE1PAgent'] = create_frozen_class(_sym_db.GetSymbol('fairdiplomacy.BQRE1PAgent'), 'BQRE1PAgent')
FrozenBQRE1PAgent = FROZEN_SYM_BD['fairdiplomacy.BQRE1PAgent']
FROZEN_SYM_BD['fairdiplomacy.BQRE1PAgent.PlayerTypes'] = create_frozen_class(_sym_db.GetSymbol('fairdiplomacy.BQRE1PAgent.PlayerTypes'), 'BQRE1PAgent.PlayerTypes')
FrozenBQRE1PAgent.PlayerTypes = FROZEN_SYM_BD['fairdiplomacy.BQRE1PAgent.PlayerTypes']
...
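One way to check whether make actually produced these registrations is to scan the generated file for `FROZEN_SYM_BD['<full name>'] = create_frozen_class(` lines like the ones above. The helper below is a hypothetical, standard-library-only sketch: it only pattern-matches the text format shown in this excerpt and never imports the module, so it still works when importing the generated code fails.

```python
import re
from pathlib import Path
from typing import Set

# Matches lines such as:
# FROZEN_SYM_BD['fairdiplomacy.ParlaiFlags'] = create_frozen_class(...)
_REGISTRATION = re.compile(
    r"^FROZEN_SYM_BD\['([^']+)'\]\s*=\s*create_frozen_class\(", re.MULTILINE
)


def registered_names(source: str) -> Set[str]:
    """Return every full message name registered in a generated *_pb2.py source."""
    return set(_REGISTRATION.findall(source))


def check_registered(pb2_path: str, full_name: str) -> bool:
    """True if the generated file at pb2_path registers full_name."""
    path = Path(pb2_path)
    if not path.exists():
        # make has not generated this file at all.
        return False
    return full_name in registered_names(path.read_text())


if __name__ == "__main__":
    # Hypothetical usage: run from the repository root after `make`.
    ok = check_registered("conf/agents_pb2.py", "fairdiplomacy.ParlaiFlags")
    print("ParlaiFlags registered:", ok)
```

If this prints False, the problem is upstream of the test (make did not run, or failed partway); if it prints True but the test still fails, the generated module that holds the registration is probably not being imported before `to_frozen()` runs.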