wiseodd / llm-bayesopt-exps

Official experiment code for the "Sober Look at LLMs for Material Discovery" paper.

run_finetuning.py raises: KeyError: 'inputs_id' when it starts "Fitting Laplace" #2

Closed fablos closed 2 months ago

fablos commented 2 months ago

`python run_finetuning.py --foundation_model t5-base --problem redox-mer` completes training, but shortly after "Fitting Laplace" begins, it raises the following traceback:

Test Function: redox-mer; Foundation Model: t5-base; Prompt Type: just-smiles; Randseed: 1
---------------------------------------------------------------------------------------------------------------

Fitting Laplace...
Traceback (most recent call last):
  File "/home/user/wok/llm-bayesopt-exps/run_finetuning.py", line 274, in <module>
    model = LoRALLMBayesOpt(
            ^^^^^^^^^^^^^^^^
  File "/home/user/wok/llm-bayesopt-exps/llm_bayesopt/lora.py", line 54, in __init__
    super().__init__(
  File "/home/user/wok/llm-bayesopt-exps/llm_bayesopt/base.py", line 59, in __init__
    self.train_model()
  File "/home/user/wok/llm-bayesopt-exps/llm_bayesopt/lora.py", line 71, in train_model
    self._posthoc_laplace(train_loader)
  File "/home/user/wok/llm-bayesopt-exps/llm_bayesopt/lora.py", line 245, in _posthoc_laplace
    self.bnn.fit(train_loader)
  File "/home/user/miniconda3/envs/sober/lib/python3.11/site-packages/laplace/baselaplace.py", line 1592, in fit
    super().fit(train_loader, override=override, progress_bar=progress_bar)
  File "/home/user/miniconda3/envs/sober/lib/python3.11/site-packages/laplace/baselaplace.py", line 844, in fit
    loss_batch, H_batch = self._curv_closure(X, y, N=N)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/sober/lib/python3.11/site-packages/laplace/baselaplace.py", line 1565, in _curv_closure
    return self.backend.kron(X, y, N=N, **self._asdl_fisher_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/sober/lib/python3.11/site-packages/laplace/curvature/curvlinops.py", line 87, in kron
    linop = KFACLinearOperator(
            ^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/sober/lib/python3.11/site-packages/curvlinops/kfac.py", line 264, in __init__
    super().__init__(
  File "/home/user/miniconda3/envs/sober/lib/python3.11/site-packages/curvlinops/_base.py", line 126, in __init__
    sum(
  File "/home/user/miniconda3/envs/sober/lib/python3.11/site-packages/curvlinops/_base.py", line 127, in <genexpr>
    self._batch_size_fn(X)
  File "/home/user/miniconda3/envs/sober/lib/python3.11/site-packages/laplace/curvature/curvlinops.py", line 85, in <lambda>
    kwargs["batch_size_fn"] = lambda x: x[self.dict_key_x].shape[0]
                                        ~^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/envs/sober/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 257, in __getitem__
    return self.data[item]
           ~~~~~~~~~^^^^^^
KeyError: 'inputs_id'

Conda environment on Ubuntu 24.04, created by following the Setup instructions in README.md, plus the fix `pip install git+https://github.com/aleximmer/laplace@update-deps`.
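For context, here is a minimal sketch of why the lookup fails (plain Python, illustrative batch contents): the tokenizer's output behaves like a dict keyed by `"input_ids"`, while the `batch_size_fn` lambda in the traceback indexes the misspelled default key `"inputs_id"`.

```python
# Illustrative stand-in for a tokenizer's BatchEncoding output:
# the real key produced by Hugging Face tokenizers is "input_ids".
batch = {"input_ids": [[101, 7592, 102]]}  # one sequence of token ids

default_key = "inputs_id"  # the misspelled default key from the traceback
fixed_key = "input_ids"    # the key the tokenizer actually produces

# Mirrors laplace's batch_size_fn: look up the inputs, take the batch dim.
batch_size_fn = lambda x: len(x[default_key])

try:
    batch_size_fn(batch)
except KeyError as err:
    print("KeyError:", err)  # same KeyError as in the traceback

# With the corrected key, the batch size is recovered as expected:
print(len(batch[fixed_key]))  # 1
```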

fablos commented 2 months ago

Quick fix: passing `dict_key_x="input_ids"` when creating the Laplace instance in `lora.py` seems to resolve the issue.

  if cfg.subset_of_weights == "last_layer":
      self.bnn = Laplace(
          model,
          likelihood="regression",
          subset_of_weights=cfg.subset_of_weights,
          hessian_structure=cfg.hess_factorization,
          sigma_noise=1 if cfg.noise_var is None else math.sqrt(cfg.noise_var),
          last_layer_name=cfg.last_layer_name,
          dict_key_x="input_ids",
      )
  else:
      self.bnn = Laplace(
          model,
          likelihood="regression",
          subset_of_weights=cfg.subset_of_weights,
          hessian_structure=cfg.hess_factorization,
          sigma_noise=1 if cfg.noise_var is None else math.sqrt(cfg.noise_var),
          dict_key_x="input_ids",
      )

wiseodd commented 2 months ago

Yup, thanks a lot for testing this! The code for the paper uses a pre-release version of laplace-torch. Since then, laplace-torch has been updated to v0.2, so there are some breaking changes here and there.

wiseodd commented 2 months ago

Fixed by #3 and also upstream in laplace-torch. Thanks @fablos!