timaeus-research / devinterp

Tools for studying developmental interpretability in neural networks.

Add support for non-regression/classification data. #79

Status: Closed (jqhoogland closed this issue 5 hours ago)

jqhoogland commented 3 months ago

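Currently, the library expects the dataloader to provide, and the model to consume, `(x, y)` pairs. This isn't appropriate for, e.g., autoregressive tasks like language modeling, where each batch is just a sequence of tokens and the targets are derived from the inputs themselves.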
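To make the mismatch concrete, here is a minimal sketch of an `evaluate` function for next-token prediction. The batch is a single tensor of token ids, and the targets are the same ids shifted by one position, so there is no separate `y` to pass along. `TinyLM` and `evaluate_lm` are hypothetical names for illustration only, not part of devinterp.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLM(nn.Module):
    """Hypothetical toy language model: an embedding followed by a linear head."""

    def __init__(self, vocab_size: int = 16, dim: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # (batch, seq) -> (batch, seq, vocab)
        return self.head(self.embed(input_ids))


def evaluate_lm(model: nn.Module, batch: torch.Tensor) -> dict:
    """Next-token loss: predict token t+1 from tokens up to t."""
    logits = model(batch[:, :-1])  # inputs: all but the last token
    targets = batch[:, 1:]         # targets: all but the first token
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    # Return extra outputs alongside the loss for later use
    return {"loss": loss, "logits": logits}


# The dataloader would yield plain token-id tensors, not (x, y) pairs.
torch.manual_seed(0)
batch = torch.randint(0, 16, (4, 10))
out = evaluate_lm(TinyLM(), batch)
print(out["loss"].item(), out["logits"].shape)
```

Returning a dict here is one design choice; it lets the caller pick out `"loss"` while keeping the logits available.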

See, for example, HuggingFace's `Trainer._prepare_input` and the snippet below for how to handle this. We probably want to allow the user to return other intermediate results from `evaluate`, which they may then want to use for estimation:

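```python
import torch

# Move (possibly nested) tensors in `data` to `device`,
# in the spirit of HuggingFace's Trainer._prepare_input
data = _prepare_input(data, device)
results = evaluate(model, data)

# Split the loss from any extra outputs the user returned
if isinstance(results, dict):
    loss = results.pop("loss")
elif isinstance(results, tuple):
    loss = results[0]
    results = results[1:] if len(results) > 1 else None
elif isinstance(results, torch.Tensor):
    loss = results
    results = None
elif hasattr(results, "loss"):
    loss = results.loss
else:
    raise ValueError(
        "evaluate must return a dict, tuple, torch.Tensor, or an object with a `loss` attribute"
    )
```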
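The dispatch above can also be packaged as a standalone helper that returns the loss together with any extra outputs. This is a minimal sketch; `split_loss` is a hypothetical name, not a devinterp or HuggingFace function, and unlike the inline version it copies a returned dict instead of popping from it in place, so the caller's object is left untouched.

```python
from typing import Any, Optional, Tuple

import torch


def split_loss(results: Any) -> Tuple[torch.Tensor, Optional[Any]]:
    """Separate the loss from any extra outputs returned by an `evaluate` function."""
    if isinstance(results, dict):
        extras = dict(results)  # shallow copy so the caller's dict is not mutated
        loss = extras.pop("loss")
        return loss, extras
    if isinstance(results, tuple):
        # First element is the loss; the rest (if any) are extras
        return results[0], (results[1:] or None)
    if isinstance(results, torch.Tensor):
        return results, None
    if hasattr(results, "loss"):
        # Any object exposing a `.loss` attribute; keep the object as the extras
        return results.loss, results
    raise ValueError(
        "evaluate must return a dict, tuple, torch.Tensor, or an object with a `loss` attribute"
    )
```

For example, `split_loss({"loss": l, "logits": z})` returns `(l, {"logits": z})`, while a bare tensor comes back as `(tensor, None)`.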

I'll file a PR, but probably not until after NeurIPS.

jqhoogland commented 3 months ago

This is also going to require changes to how the initial_loss is computed.
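As an illustration of what that change might look like, here is a hypothetical sketch in which the initial loss is computed through the same user-supplied `evaluate` callback rather than from `(x, y)` pairs. It assumes the initial loss is the mean loss of the unperturbed model over the first few batches, and, for brevity, that `evaluate` returns either a scalar tensor or a dict with a `"loss"` key. `compute_initial_loss` and `num_batches` are illustrative names, not devinterp's API.

```python
import torch
import torch.nn as nn


def compute_initial_loss(model, loader, evaluate, num_batches=1):
    """Hypothetical sketch: mean loss of the unperturbed model over the first
    `num_batches` batches, computed via a user-supplied `evaluate(model, batch)`.

    Assumes `evaluate` returns either a scalar tensor or a dict with a "loss" key.
    """
    losses = []
    with torch.no_grad():  # no gradients are needed for a reference loss
        for i, batch in enumerate(loader):
            if i >= num_batches:
                break
            results = evaluate(model, batch)
            loss = results["loss"] if isinstance(results, dict) else results
            losses.append(loss.item())
    if not losses:
        raise ValueError("loader yielded no batches")
    return sum(losses) / len(losses)


# Usage: the "loader" yields bare input tensors; there are no (x, y) pairs.
torch.manual_seed(0)
model = nn.Linear(3, 1)
loader = [torch.zeros(2, 3), torch.ones(2, 3)]


def evaluate(model, batch):
    # Toy loss for illustration: mean squared output
    return {"loss": model(batch).pow(2).mean()}


init_loss = compute_initial_loss(model, loader, evaluate, num_batches=2)
```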

svwingerden commented 5 hours ago

Resolved by https://github.com/timaeus-research/devinterp/pull/80