ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

`load_sequential_transformer` device argument doesn't work #220

Closed nix-apollo closed 8 months ago

nix-apollo commented 9 months ago

The following fails with a device error:

```python
from rib.loader import load_sequential_transformer
import torch

model, _ = load_sequential_transformer(
    node_layers=['mlp_in.5', 'mlp_out.5'],
    last_pos_module_type=None,
    tlens_pretrained="pythia-14m",
    tlens_model_path=None,
    eps=None,
    fold_bias=True,
    device="cuda",
)

x = torch.tensor([[100, 200]], device="cuda")
model(x)
```

Even though `device="cuda"` was passed to the loader, the forward pass fails with a device error. Running `model.to("cuda")` first fixes it.
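For context on this class of error (generic PyTorch behavior, not specific to rib): a forward pass raises a device-mismatch `RuntimeError` whenever a module's parameters and its inputs live on different devices, and `Module.to(device)` moves the parameters in place. A minimal sketch with a toy `nn.Linear` standing in for the loaded transformer (the layer and shapes here are illustrative, not from rib):

```python
import torch

# A freshly constructed module keeps its parameters on the CPU by default,
# so feeding it CUDA inputs would raise a device-mismatch RuntimeError.
model = torch.nn.Linear(4, 4)

# Pick whatever accelerator is available so the sketch also runs on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Moving the module (and the input) to the same device resolves the mismatch.
model.to(device)
x = torch.zeros(1, 4, device=device)
out = model(x)
print(out.shape)  # torch.Size([1, 4])
```

The workaround in the report is the same move: calling `model.to("cuda")` after loading aligns the parameters with the CUDA input tensor.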

danbraunai-apollo commented 8 months ago

@nix-apollo I can't replicate this bug. Note that I had to remove the `eps` argument. Can you try again and close the issue if it no longer occurs?

nix-apollo commented 8 months ago

Yep, works fine for me now. Unsure what changed.