logix-project / logix

AI Logging for Interpretability and Explainability🔬

A few questions on how to use this library + hyperparameter choices #109

Closed hamishivi closed 1 month ago

hamishivi commented 1 month ago

Hi, thanks for this library and awesome paper! I think the logra idea is really cool, and I've been trying to use it in my own research. I've been running into some issues/poor results, though, and it would be useful to get your feedback if possible. For some context, I'm trying to apply this library to instruction tuning, and ideally trying to reduce the storage/compute costs somewhat (small logra ranks, etc). Unfortunately, I've been getting results on par with random selection right now, although I'm not entirely sure I'm using the library correctly, so any advice would be really useful.

  1. Is using run.add_lora() or setting LogIXScheduler(..., lora="random") preferable? I see both are done in different examples. I've been trying to follow the language modelling example, which uses add_lora but then sets lora="None" in the scheduler. This seems fine looking at the logs/code, but I just wanted to check I wasn't missing something.
  2. How important have you found the lo(g)ra size to be? I am using a rank of 6 right now to keep storage costs low, but I wonder if this is so small it leads to poor performance. If you have ablation results on the rank, this would be useful!
  3. Similarly, how important is the hessian? I know of other methods, e.g. LESS, that just use cosine similarity over LoRAs instead of any second-order features, and in my experience this has worked well for some of my own experiments. I would also be curious about ablations here.
  4. Any particular reason your examples use a sum loss instead of a mean loss in the language modelling example? I guess EK-FAC used a sum loss to make sure they have a matching loss function, but I was curious if there were any other reasons you chose this over the common default of a mean loss.

I'm happy to share code/email etc! Currently my code looks like the following (simplified somewhat):

# imports used by the snippet below (model, tokenizer, data_loader, accelerator, args,
# and merge_logs come from the omitted setup code)
import copy
import logix
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# some model and data setup, including model and data_loader
# logix config:
logix_config = {
    "root_dir": ".",
    "logging": {
        "flush_threshold": 50000*8064, # I found that I had to reduce flush threshold to save all my datapoints.
        "num_workers": 1,
        "cpu_offload": True,
        "log_dtype": "float16",
    },
    "lora": {
        "init": "random",
        "rank": 6,
    }
}

run = logix.init(project=args.grad_save_path, config='tmp_logix/logix_config.yaml')
run.watch(model, name_filter=["att", "mlp"], type_filter=[nn.Linear])
run.add_lora()
scheduler = logix.LogIXScheduler(
        run, lora="none", hessian="none", save="grad"
)

model, data_loader = accelerator.prepare(model, data_loader)
model.eval()
for _ in scheduler:
    for i, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
        # add dataset index to data_id to avoid collisions
        data_id = [tokenizer.decode(batch["input_ids"][0]) + f'_{i}']
        targets = batch.pop("labels")
        with run(data_id=data_id, mask=batch["attention_mask"]):
            model.zero_grad()
            lm_logits = model(**batch).logits
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = targets[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                reduction="mean",  # logix example code had sum?
                ignore_index=-100,
            )
            accelerator.backward(loss)
    run.finalize()

# setup log dataloader (dataloader over train grads)
log_loader = run.build_log_dataloader(batch_size=64)

# setup test_dataset and test_data_loader
run.setup({"grad": ["log"]})
run.eval()
merged_test_logs = []
for idx, batch in enumerate(tqdm(test_data_loader)):
    data_id = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)
    targets = batch.pop("labels")
    with run(data_id=data_id, mask=batch["attention_mask"]):
        model.zero_grad()
        logits = model(**batch).logits
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = targets[..., 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction="mean",
            ignore_index=-100,
        )
        accelerator.backward(loss)

    test_log = run.get_log()
    merged_test_logs.append(copy.deepcopy(test_log))

merged_test_log = merge_logs(merged_test_logs)
results = run.influence.compute_influence_all(merged_test_log, log_loader, mode="cosine")

influence_scores = results["influence"]
sangkeun00 commented 1 month ago

Hello @hamishivi

Thanks for your interest in our work! I'll try my best to address your questions below.

  1. add_lora() vs LogIXScheduler: Firstly, you are correct that we used both options in different examples, mostly due to some complications when distributed training is enabled. With (data-parallel) distributed training, you want to ensure that the model on all workers has the same weights, and this is typically handled when you wrap the model (e.g. accelerator.prepare()). However, if you run add_lora() after you wrap the model, it's possible that the LoRA weights on different ranks end up different. If you rely on LogIXScheduler to add LoRA, you are more susceptible to this issue. So, tl;dr, if you use distributed training, manually adding LoRA before wrapping the model is the safest option (see the sketch after this list)!
  2. In our experiments, we noticed that increasing the LoGra rank generally led to better results. That being said, I haven't tried ranks as small as 6 in large-scale experiments (I've only tried them in small-scale experiments like MNIST). I also want to point out that LoGra rank 6 may be quite different from LoRA rank 6, as the projection happens in both directions in LoGra. I would suggest using 32 or 64 if your storage permits.
  3. In our experiments, the (accurate) Hessian was quite important for achieving good results. Looking at your code, it seems you set hessian="none". Computing the Hessian with LoGra should be very cheap, so I would suggest setting it to "raw" as in my language modeling experiments (also shown in the sketch below). If it's set to "none", it just computes the gradient dot product as in LESS (another point is that LESS rescales gradients with Adam states, so it would likely work better than a naive gradient dot product).
  4. We mostly followed the EK-FAC paper for this. Since language model pretraining uses a fixed sequence length, mean vs. sum didn't really matter, but it would be interesting to experiment with different options!
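To make points 1 and 3 concrete, here is a rough sketch of the ordering I mean, reusing the names from your own snippet (not tested against your actual script, so treat it as illustrative):

run = logix.init(project=args.grad_save_path, config='tmp_logix/logix_config.yaml')
run.watch(model, name_filter=["att", "mlp"], type_filter=[nn.Linear])

# add LoRA *before* wrapping the model, so every data-parallel rank
# starts from identical LoRA weights
run.add_lora()

# "raw" Hessian is cheap with LoGra and usually much better than hessian="none"
scheduler = logix.LogIXScheduler(run, lora="none", hessian="raw", save="grad")

# wrap the model only after add_lora()
model, data_loader = accelerator.prepare(model, data_loader)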

Hope my answers were helpful! If you have any further concerns or questions, feel free to let me know!

sangkeun00 commented 1 month ago

@hamishivi If you have any further questions, I am happy to address them! Otherwise, I will close this issue soon!

hamishivi commented 1 month ago

Hi @sangkeun00! Sorry for not responding, I've been travelling recently. I'm trying out using Hessians and will try using larger LoGra ranks soon. One thing I found is that there were a few rough edges out of the box.

I made some little fixes in a fork, but I'm still in the process of working these out. I'm actually travelling this week and next, so progress is a little slow. Thanks for checking in, feel free to close if you want! If I have particularly burning issues I'm happy to re-open this issue or open new ones.

hamishivi commented 1 month ago

A small update: I still seem to get the inversion issue, even using float32 consistently, which blocks me from setting hessian=raw. It happens inconsistently, so unfortunately I can't provide an easy reproducible example right now. The error is like:

Traceback (most recent call last):
  File "/opt/conda/envs/venv/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/envs/venv/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/gantry-runtime/minimal_multitask/compute_influence_logix.py", line 180, in <module>
    results = run.influence.compute_influence_all(merged_test_log, log_loader, mode="cosine")
  File "/opt/conda/envs/venv/lib/python3.10/site-packages/logix/analysis/influence_function.py", line 268, in compute_influence_all
    src_log = self.precondition(src_log, hessian=hessian, damping=damping)
  File "/opt/conda/envs/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/venv/lib/python3.10/site-packages/logix/analysis/influence_function.py", line 74, in precondition
    preconditioned_grad = precondition_fn(
  File "/opt/conda/envs/venv/lib/python3.10/site-packages/logix/analysis/influence_function_utils.py", line 77, in precondition_raw
    cov_inverse = state.get_covariance_inverse_state(damping=damping)
  File "/opt/conda/envs/venv/lib/python3.10/site-packages/logix/state.py", line 148, in get_covariance_inverse_state
    self.covariance_inverse(damping=damping)
  File "/opt/conda/envs/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/venv/lib/python3.10/site-packages/logix/state.py", line 130, in covariance_inverse
    self.covariance_inverse_state[module_name][mode] = torch.inverse(
torch._C._LinAlgError: linalg.inv: The diagonal element 1 is zero, the inversion could not be completed because the input matrix is singular.

sangkeun00 commented 1 month ago

Sorry for the delayed response! I've been busy with personal stuff recently.

I haven't personally encountered this issue. Looking at your error message, the problem seems to be that the Hessian is singular. I am a bit surprised, as a damping term is added automatically when you set damping="none" (the default), and this most likely ensures invertibility. The only potential issue I see is log_dtype being set to float16. If you compute gradients after fine-tuning your model, it is possible that your gradient norms get very small, beyond the float16 limit. This may lead to a large number of zero components in the Hessian, which then prevents matrix inversion.
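As a quick illustration of what I mean by the float16 limit (plain PyTorch, nothing LogIX-specific):

import torch

# float16 cannot represent magnitudes much below ~6e-8, so very small gradient
# entries silently become exactly zero when stored in float16
g = torch.tensor([1e-4, 1e-6, 1e-8], dtype=torch.float32)
print(g.to(torch.float16))  # the last entry underflows to 0.0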

To debug, I suggest going to your log directory, opening state/covariance_state.pt, and manually looking into the matrices, especially the one that causes this inversion error. Another thing you can do is set log_dtype to float32 (this doesn't necessarily mean you should also use fp32 for your model and training code, as log_dtype is decoupled from the dtype you use for training). If you want to set up a meeting, I am also open to it. Let me know if you have any other questions.
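In case it helps, something along these lines is the kind of manual inspection I have in mind. It treats covariance_state.pt as a nested dict of {module_name: {mode: covariance tensor}}, which matches how the state is indexed in your traceback; adjust the path and keys to whatever you actually see on disk:

import torch

cov_state = torch.load("state/covariance_state.pt", map_location="cpu")
for module_name, modes in cov_state.items():
    for mode, cov in modes.items():
        cov = cov.float()
        # zero rows/columns are a telltale sign of samples contributing no gradient
        zero_rows = int((cov.abs().sum(dim=1) == 0).sum())
        rank = int(torch.linalg.matrix_rank(cov))
        print(f"{module_name}/{mode}: shape={tuple(cov.shape)} rank={rank} zero_rows={zero_rows}")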

hamishivi commented 1 month ago

Thanks for the response! I tried setting log_dtype to float32 without luck (still got the same error), but I'll look into the covariance state to see what's going on, thanks. I'll close the issue for now.

I'm travelling a bit for personal stuff right now so a bit otherwise occupied, but if I'm still having issues next week I might reach out. Thanks so so so much for all your help! I think this is a really cool project :)

hamishivi commented 1 month ago

fyi, I at least worked out why I was getting the inversion issues - I had a few samples I hadn't filtered out that had no labels due to length issues, so of course their covariances were 0!
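For anyone else hitting this, the fix on my side is just to drop those examples before logging gradients. A minimal sketch, assuming a Hugging Face datasets-style dataset and the usual -100 ignore index (the names here are from my setup, not from LogIX):

def has_valid_labels(example):
    # examples whose labels are all -100 produce zero gradients,
    # which in turn makes their covariance contribution exactly zero
    return any(label != -100 for label in example["labels"])

train_dataset = train_dataset.filter(has_valid_labels)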

sangkeun00 commented 1 month ago

Glad to hear that you worked out the issue. If you want to use cosine similarity, I also want to warn you that it may sometimes give NaN due to division by 0. You may want to add a small value in this line (https://github.com/logix-project/logix/blob/main/logix/analysis/influence_function.py#L151) if you encounter this issue. Let me know if you face other issues anytime!
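The change I mean is just the usual epsilon-in-the-denominator pattern, e.g. (illustrative variable names, not the actual ones in influence_function.py):

# guard the cosine-similarity normalization against zero norms
eps = 1e-8
cosine_score = dot_product / (src_norm * tgt_norm + eps)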