stanford-futuredata / Baleen

Baleen: Robust Multi-Hop Reasoning at Scale via Condensed Retrieval (NeurIPS'21)
MIT License

TypeError after inference #2

Closed MFajcik closed 1 year ago

MFajcik commented 2 years ago

Hi, when saving the inference results as a JSON file via hover_inference.py, the output dictionary contains a set. Sets are not JSON-serializable, so the save fails.

python -m hover_inference --root ./experiments/ --datadir . --index wiki17.hover.2bit
Traceback (most recent call last):                                                                                                                            
  File "xxx/conda/envs/colbert-v0.4/lib/python3.7/runpy.py", line 193, in _run_module_as_main                                                      
    "__main__", mod_spec)                                                                                                                                     
  File "xxx/.conda/envs/colbert-v0.4/lib/python3.7/runpy.py", line 85, in _run_code                                                                 
    exec(code, run_globals)                                                                                                                                   
  File "yyy/baleen/Baleen/hover_inference.py", line 53, in <module>                                                                             
    main(args)                                                                                                                                                
  File "yyy/baleen/Baleen/hover_inference.py", line 43, in main                                                                                 
    f.write(ujson.dumps(outputs) + '\n')                                                                                                                      
TypeError: {3910663, 1373715, 833561, 2479648, 3921953, 3408419, 3188274, 1399859, 372789, 1117238, 3283510, 3342401, 2585678, 1428049, 4948563, 1399892, 4449365, 4216407, 4502103, 819287, 3598429, 5187684, 625781, 3042432, 1485442, 3487369, 4166284, 148110, 3713169, 1338005, 1951900, 936613, 437414, 556716, 2666162, 573620, 4666549, 638144, 4154562, 4315335, 4230859, 4788429, 2613967, 174801, 4054227, 3768532, 5224152, 4914913, 2469090, 460517, 4820205, 1360625, 4264185, 3064580, 424200, 4601613, 4707087, 2140434, 3422995, 3878677, 3583776, 2412329, 5212973, 3787053, 4286261, 2512694, 821559, 4174137, 3351359, 349002, 3896143, 3414369, 875881, 1557358, 3957103, 4061041, 3913073, 2986353, 959347, 803705, 4757370, 1752441, 2359693, 4729260, 1178030, 1897903, 5206962, 564149, 4238275, 4074960, 1900502, 4158425, 4635100, 4552679, 1106923, 3795442, 3049975, 2750972, 4602365, 1399295} is not JSON serializable

Every item in the dictionary to be saved looks like this:

0: ([(424200, 2), (4635100, 1), (4635100, 0)],
{3910663, 1373715, 833561, 2479648, 3921953, 3408419, 3188274, 1399859, 372789, 1117238, 3283510, 3342401, 2585678, 1428049, 4948563, 1399892, 4449365, 4216407, 4502103, 819287, 3598429, 5187684, 625781, 3042432, 1485442, 3487369, 4166284, 148110, 3713169, 1338005, 1951900, 936613, 437414, 556716, 2666162, 573620, 4666549, 638144, 4154562, 4315335, 4230859, 4788429, 2613967, 174801, 4054227, 3768532, 5224152, 4914913, 2469090, 460517, 4820205, 1360625, 4264185, 3064580, 424200, 4601613, 4707087, 2140434, 3422995, 3878677, 3583776, 2412329, 5212973, 3787053, 4286261, 2512694, 821559, 4174137, 3351359, 349002, 3896143, 3414369, 875881, 1557358, 3957103, 4061041, 3913073, 2986353, 959347, 803705, 4757370, 1752441, 2359693, 4729260, 1178030, 1897903, 5206962, 564149, 4238275, 4074960, 1900502, 4158425, 4635100, 4552679, 1106923, 3795442, 3049975, 2750972, 4602365, 1399295})

This is quite annoying after spending a few hours running the actual retrieval :). Cheers, Martin
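For reference, here is a minimal reproduction of the failure and the obvious workaround. This sketch uses the standard-library `json` module (the script itself uses ujson), and `dump_outputs` is just an illustrative name, not part of the repo:

```python
import json

# A minimal stand-in for one entry of the outputs dict:
# a (facts, pid-set) pair like the one shown in the traceback above.
outputs = {0: ([(424200, 2), (4635100, 1)], {3910663, 1373715, 833561})}

# json.dumps(outputs) raises TypeError, because sets are not JSON-serializable.

def dump_outputs(outputs):
    # default=list tells the encoder to render any otherwise
    # non-serializable object (here: the set) as a list.
    return json.dumps(outputs, default=list)
```

Note that JSON has no set type, so the round trip is lossy: the set comes back as a list (and dict keys come back as strings).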

environment

name: colbert-v0.4
channels:
  - pytorch
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_kmp_llvm
  - blas=2.114=mkl
  - blas-devel=3.9.0=14_linux64_mkl
  - bzip2=1.0.8=h7f98852_4
  - ca-certificates=2021.10.8=ha878542_0
  - cudatoolkit=11.1.1=h6406543_10
  - cupy=10.4.0=py37h52a254a_0
  - faiss=1.7.0=py37cuda111hcc9d9d6_8_cuda
  - faiss-gpu=1.7.0=h788eb59_8
  - ffmpeg=4.3=hf484d3e_0
  - freetype=2.10.4=h0708190_1
  - gmp=6.2.1=h58526e2_0
  - gnutls=3.6.13=h85f3911_1
  - jpeg=9b=h024ee3a_2
  - lame=3.100=h7f98852_1001
  - ld_impl_linux-64=2.36.1=hea4e1c9_2
  - libblas=3.9.0=14_linux64_mkl
  - libcblas=3.9.0=14_linux64_mkl
  - libfaiss=1.7.0=cuda111hf54f04a_8_cuda
  - libfaiss-avx2=1.7.0=cuda111h1234567_8_cuda
  - libffi=3.4.2=h7f98852_5
  - libgcc-ng=11.2.0=h1d223b6_16
  - libgfortran-ng=11.2.0=h69a702a_16
  - libgfortran5=11.2.0=h5c6108e_16
  - libiconv=1.16=h516909a_0
  - liblapack=3.9.0=14_linux64_mkl
  - liblapacke=3.9.0=14_linux64_mkl
  - libnsl=2.0.0=h7f98852_0
  - libpng=1.6.37=h21135ba_2
  - libstdcxx-ng=11.2.0=he4da1e4_16
  - libtiff=4.0.9=he6b73bb_1
  - libuv=1.43.0=h7f98852_0
  - libzlib=1.2.11=h166bdaf_1014
  - llvm-openmp=13.0.1=he0ac6c6_1
  - mkl=2022.0.1=h8d4b97c_803
  - mkl-devel=2022.0.1=ha770c72_804
  - mkl-include=2022.0.1=h8d4b97c_803
  - ncurses=6.3=h27087fc_1
  - nettle=3.6=he412f7d_0
  - ninja=1.10.2=h4bd325d_1
  - numpy=1.21.6=py37h976b520_0
  - olefile=0.46=pyh9f0ad1d_1
  - openh264=2.1.1=h780b84a_0
  - openssl=3.0.3=h166bdaf_0
  - pillow=5.4.1=py37h34e0f95_0
  - pip=21.0.1=pyhd8ed1ab_0
  - python=3.7.12=hf930737_100_cpython
  - python_abi=3.7=2_cp37m
  - pytorch=1.9.0=py3.7_cuda11.1_cudnn8.0.5_0
  - readline=8.1=h46c0cb4_0
  - setuptools=62.1.0=py37h89c1867_0
  - sqlite=3.38.2=h4ff8645_0
  - tbb=2021.5.0=h924138e_1
  - tk=8.6.12=h27826a3_0
  - torchaudio=0.9.0=py37
  - torchvision=0.10.0=py37_cu111
  - wheel=0.37.1=pyhd8ed1ab_0
  - xz=5.2.5=h516909a_1
  - zlib=1.2.11=h166bdaf_1014
  - pip:
    - anyio==3.5.0
    - argon2-cffi==21.3.0
    - argon2-cffi-bindings==21.2.0
    - attrs==21.4.0
    - babel==2.10.1
    - backcall==0.2.0
    - beautifulsoup4==4.11.1
    - bitarray==2.4.1
    - bleach==5.0.0
    - blis==0.7.7
    - catalogue==2.0.7
    - certifi==2021.10.8
    - cffi==1.15.0
    - charset-normalizer==2.0.12
    - click==8.0.4
    - cymem==2.0.6
    - debugpy==1.6.0
    - decorator==5.1.1
    - defusedxml==0.7.1
    - entrypoints==0.4
    - fastjsonschema==2.15.3
    - fastrlock==0.8
    - filelock==3.6.0
    - gitdb==4.0.9
    - gitpython==3.1.27
    - huggingface-hub==0.5.1
    - idna==3.3
    - importlib-metadata==4.11.3
    - importlib-resources==5.7.1
    - ipykernel==6.13.0
    - ipython==7.32.0
    - ipython-genutils==0.2.0
    - ipywidgets==7.7.0
    - jedi==0.18.1
    - jinja2==3.1.1
    - joblib==1.1.0
    - json5==0.9.6
    - jsonschema==4.4.0
    - jupyter==1.0.0
    - jupyter-client==7.3.0
    - jupyter-console==6.4.3
    - jupyter-core==4.10.0
    - jupyter-server==1.16.0
    - jupyterlab==3.3.4
    - jupyterlab-pygments==0.2.2
    - jupyterlab-server==2.13.0
    - jupyterlab-widgets==1.1.0
    - langcodes==3.3.0
    - markupsafe==2.1.1
    - matplotlib-inline==0.1.3
    - mistune==0.8.4
    - murmurhash==1.0.7
    - nbclassic==0.3.7
    - nbclient==0.6.0
    - nbconvert==6.5.0
    - nbformat==5.3.0
    - nest-asyncio==1.5.5
    - notebook==6.4.11
    - notebook-shim==0.1.0
    - packaging==21.3
    - pandocfilters==1.5.0
    - parso==0.8.3
    - pathy==0.6.1
    - pexpect==4.8.0
    - pickleshare==0.7.5
    - preshed==3.0.6
    - prometheus-client==0.14.1
    - prompt-toolkit==3.0.29
    - psutil==5.9.0
    - ptyprocess==0.7.0
    - pycparser==2.21
    - pydantic==1.8.2
    - pygments==2.12.0
    - pyparsing==3.0.8
    - pyrsistent==0.18.1
    - python-dateutil==2.8.2
    - pytz==2022.1
    - pyyaml==6.0
    - pyzmq==22.3.0
    - qtconsole==5.3.0
    - qtpy==2.0.1
    - regex==2022.4.24
    - requests==2.27.1
    - sacremoses==0.0.49
    - scipy==1.7.3
    - send2trash==1.8.0
    - six==1.16.0
    - smart-open==5.2.1
    - smmap==5.0.0
    - sniffio==1.2.0
    - soupsieve==2.3.2.post1
    - spacy==3.2.4
    - spacy-legacy==3.0.9
    - spacy-loggers==1.0.2
    - srsly==2.4.3
    - terminado==0.13.3
    - thinc==8.0.15
    - tinycss2==1.1.1
    - tokenizers==0.10.3
    - tornado==6.1
    - tqdm==4.64.0
    - traitlets==5.1.1
    - transformers==4.10.0
    - typer==0.4.1
    - typing-extensions==3.10.0.2
    - ujson==5.2.0
    - urllib3==1.26.9
    - wasabi==0.9.1
    - wcwidth==0.2.5
    - webencodings==0.5.1
    - websocket-client==1.3.2
    - widgetsnbextension==3.6.0
    - zipp==3.8.0
prefix: xxx/.conda/envs/colbert-v0.4
okhat commented 2 years ago

Good catch! Can you cast the set to a list? That sounds like it'll fix this. If you make a pull request, I'll merge it.
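Casting at the write site could look something like the following sketch; `to_jsonable` is a hypothetical helper, not an existing function in the repo:

```python
def to_jsonable(obj):
    """Recursively convert sets (and tuples) into lists so the
    result can be dumped with ujson/json without a TypeError."""
    if isinstance(obj, (set, frozenset, list, tuple)):
        return [to_jsonable(x) for x in obj]
    if isinstance(obj, dict):
        return {k: to_jsonable(v) for k, v in obj.items()}
    return obj

# In hover_inference.py, the failing write would then become:
# f.write(ujson.dumps(to_jsonable(outputs)) + '\n')
```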

MFajcik commented 2 years ago

@okhat Wouldn't it be better to return N deduplicated lists (where N is the number of hops) from the ColBERT engine, so that the individual retrieval results preserve their order?

I would submit the pull request for ColBERT, but I am not sure whether this would cause problems with some of the scripts you have.

edit: code-wise, something like this:

from baleen.utils.loaders import *
from baleen.condenser.condense import Condenser

class Baleen:
    def __init__(self, collectionX_path: str, searcher, condenser: Condenser):
        self.collectionX = load_collectionX(collectionX_path)
        self.searcher = searcher
        self.condenser = condenser

    def search(self, query, num_hops, depth=100, verbose=False):
        assert depth % num_hops == 0, f"depth={depth} must be divisible by num_hops={num_hops}."
        k = depth // num_hops

        searcher = self.searcher
        condenser = self.condenser
        collectionX = self.collectionX

        facts = []
        stage1_preds = None
        context = None

        pids_bag = [[] for _ in range(num_hops)]

        for hop_idx in range(0, num_hops):
            ranking = list(zip(*searcher.search(query, context=context, k=depth)))
            ranking_ = []

        facts_pids = {pid for pid, _ in facts}

            for pid, rank, score in ranking:
                # print(f'[{score}] \t\t {searcher.collection[pid]}')
                if len(ranking_) < k and pid not in facts_pids:
                    ranking_.append(pid)

                if len(pids_bag[hop_idx]) < k:
                    if all(pid not in pids_bag[hi] for hi in range(num_hops)):
                        pids_bag[hop_idx].append(pid)

            stage1_preds, facts, stage2_L3x = condenser.condense(query, backs=facts, ranking=ranking_)
            context = ' [SEP] '.join([collectionX.get((pid, sid), '') for pid, sid in facts])

        assert sum(len(pids_per_hop) for pids_per_hop in pids_bag) == depth #//edit fixed assert

        return stage2_L3x, pids_bag, stage1_preds
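The `pids_bag` bookkeeping above can be exercised in isolation. A minimal sketch with synthetic per-hop rankings (no searcher or condenser involved; `fill_pids_bag` is an illustrative name, assuming the same dedup rule as the proposal):

```python
def fill_pids_bag(rankings_per_hop, depth):
    """Distribute ranked pids into per-hop lists of size k = depth // num_hops,
    keeping only pids not already placed in any hop; rank order is preserved."""
    num_hops = len(rankings_per_hop)
    assert depth % num_hops == 0, f"depth={depth} must be divisible by num_hops={num_hops}."
    k = depth // num_hops
    pids_bag = [[] for _ in range(num_hops)]
    for hop_idx, ranking in enumerate(rankings_per_hop):
        for pid in ranking:
            # Take the next pid only if this hop still has room and the pid
            # has not been assigned to any hop yet.
            if len(pids_bag[hop_idx]) < k and all(pid not in pids_bag[hi] for hi in range(num_hops)):
                pids_bag[hop_idx].append(pid)
    return pids_bag

# With depth=4 over 2 hops (k=2), duplicates across hops are skipped:
# fill_pids_bag([[1, 2, 3, 2], [2, 3, 4, 5]], 4) → [[1, 2], [3, 4]]
```

Note that the `sum(...) == depth` assertion in the proposal only holds when each hop's ranking contains at least k pids not already taken by other hops, which is typically the case when the per-hop ranking depth is much larger than k.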