AlignmentResearch / tuned-lens

Tools for understanding how transformer predictions are built layer-by-layer
https://tuned-lens.readthedocs.io/en/latest/
MIT License

Has anyone trained a tuned lens on Gemma-2b or other Gemma models? #129

Open jbloomAus opened 8 months ago

levmckinney commented 7 months ago

Not that I'm aware of, but there should be support for the Gemma architecture in the next release; see https://github.com/AlignmentResearch/tuned-lens/pull/125.

If anyone does train one, I'm accepting PRs to https://huggingface.co/spaces/AlignmentResearch/tuned-lens/discussions. An example of a good PR adding a model is https://huggingface.co/spaces/AlignmentResearch/tuned-lens/discussions/45.

Does anyone have any idea what Gemma's pre-training set consisted of? When training lenses in the past, we've tried to keep the lens training set as close as possible to the pretraining data distribution. If no one knows, we can always fall back to using the RedPJ sample again, which is what we did for Llama 2.

prof-schacht commented 4 weeks ago

Has anybody trained a Gemma-2-9b-it tuned lens since that commit? I tried, but it failed with the following error:

```
Traceback (most recent call last):
  File "/root/miniconda/envs/tunedlens/bin/tuned-lens", line 8, in <module>
    sys.exit(main())
  File "/root/miniconda/envs/tunedlens/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File "/proj/tuned_lens/tuned-lens/tuned_lens/__main__.py", line 43, in main
    prog.execute()
  File "/proj/tuned_lens/tuned-lens/tuned_lens/__main__.py", line 33, in execute
    self.command.execute()
  File "/proj/tuned_lens/tuned-lens/tuned_lens/scripts/train_loop.py", line 376, in execute
    state, model, grad_acc_steps = self.setup()
  File "/proj/tuned_lens/tuned-lens/tuned_lens/scripts/train_loop.py", line 365, in setup
    grad_acc_steps = self.calculate_gradient_accumulation_steps(
  File "/proj/tuned_lens/tuned-lens/tuned_lens/scripts/train_loop.py", line 280, in calculate_gradient_accumulation_steps
    raise ValueError(
ValueError: Can only take 0.56 steps on dataset with --tokens_per_step=262144. Requested 250 steps.
```
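For anyone hitting the same error: the message suggests the tokenized dataset is far too small for the requested schedule. A minimal sketch of the arithmetic (`max_possible_steps` is a hypothetical helper for illustration, not the library's internal function):

```python
# The trainer divides the dataset's total token count by
# --tokens_per_step to find how many optimizer steps it can take;
# the ValueError fires when that is fewer than the requested steps.
def max_possible_steps(total_tokens: int, tokens_per_step: int) -> float:
    return total_tokens / tokens_per_step

tokens_per_step = 262_144
# "0.56 steps" implies the tokenized dataset held only about
# 0.56 * 262_144, roughly 147k tokens.
dataset_tokens = int(0.56 * tokens_per_step)

print(round(max_possible_steps(dataset_tokens, tokens_per_step), 2))  # 0.56
print(250 * tokens_per_step)  # 65536000 tokens needed for 250 steps
```

So the likely fixes are pointing training at a much larger text dataset, or reducing `--tokens_per_step` / the requested step count so the schedule fits the data you have.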