Closed RonanKMcGovern closed 4 months ago
@RonanKMcGovern thanks for the question.
Could you print the shape of `inputs["subspaces"]` before passing it into the call in line 82 above? It should be a 3D tensor.
Before permutation, the first dimension should be the batch size; after permutation, the first dimension should be 1
if you are doing a single intervention.
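To make the expected shapes concrete, here is a small sketch of that check (numpy is used as a stand-in for the actual torch tensor, and the sizes — batch of 4, subspace dim of 2 — are made up for illustration):

```python
import numpy as np

# a 3d "subspaces" tensor: (batch_size, num_interventions, subspace_dim)
subspaces = np.zeros((4, 1, 2))

print(subspaces.ndim)   # 3 -- a 3d tensor, as expected
print(subspaces.shape)  # (4, 1, 2): batch size comes first before permutation

# permuting the first two axes puts the intervention axis first
permuted = subspaces.transpose(1, 0, 2)
print(permuted.shape)   # (1, 4, 2): first dim is 1 for a single intervention
```

If the printed tensor is 2D instead, that is usually the sign that the per-example subspace entry was dropped somewhere upstream.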
i will update the tutorial - it is currently incompatible with the new infra code.
@RonanKMcGovern i checked in the fix: https://github.com/stanfordnlp/pyreft/pull/75
(i haven't rerun the notebook to verify tho, but please let me know if it works for your examples.) If it works, feel free to close this issue! ty!
p.s. please install from the top of the tree after a `git pull`: `pip install .`
marking this issue as closed for now --- if the problem still exists, feel free to reopen or open a new issue.
yes I think I'm still running into this
```
File ~/micromamba/envs/trtf/lib/python3.9/site-packages/pyreft/dataset.py:161, in ReftDataset.__init__(self, task, data_path, tokenizer, data_split, dataset, seed, max_n_example, **kwargs)
    159 for i, data_item in enumerate(tqdm(self.task_dataset)):
    160     tokenized, last_position = self.tokenize(data_item)
--> 161     tokenized = self.compute_intervention_and_subspaces(i, data_item, tokenized, last_position, **kwargs)
    162     self.result.append(tokenized)

File ~/micromamba/envs/trtf/lib/python3.9/site-packages/pyreft/dataset.py:251, in ReftDataset.compute_intervention_and_subspaces(self, id, data_item, result, last_position, **kwargs)
    249 # we now assume each task has a constant subspaces
    250 _subspaces = [data_item["subspaces"]] * num_interventions
--> 251 result["subspaces"] = _subspaces
    253 return result

KeyError: 'subspaces'
```
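The KeyError can be reproduced in isolation: the data item simply carries no `"subspaces"` field, so indexing it fails. A minimal sketch (all names and values here are illustrative, not pyreft's actual data format):

```python
num_interventions = 1
data_item = {"instruction": "hi", "output": "hello"}  # no "subspaces" key

# indexing a missing key raises KeyError: 'subspaces'
try:
    _subspaces = [data_item["subspaces"]] * num_interventions
except KeyError as exc:
    print(exc)  # 'subspaces'

# a defensive variant only attaches subspaces when the field exists
result = {}
if "subspaces" in data_item:
    result["subspaces"] = [data_item["subspaces"]] * num_interventions
print(result)  # {}
```

Whether pyreft should guard the access like this, or the dataset should always provide the field, is a design question for the maintainers.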
EDIT: oh sorry, wrong issue?
EDIT2: nevermind, restarting the notebook somehow fixed this
When running the script, it appears that `result['subspaces']` has not been initialized, meaning that it cannot be appended to in:
On manually fixing that, it appears there's an issue permuting the subspaces:
because running the training results in:
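Regarding the first error above (appending to an uninitialized `result['subspaces']`), the initialize-before-append pattern can be sketched as follows, assuming `result` is a plain dict (names hypothetical):

```python
result = {}

# appending to a missing key would raise:
#   result["subspaces"].append([0, 1])  -> KeyError: 'subspaces'

# setdefault creates the empty list on first use, so the append succeeds
result.setdefault("subspaces", []).append([0, 1])
print(result)  # {'subspaces': [[0, 1]]}
```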