kyegomez / BitNet

Implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models" in pytorch
https://discord.gg/qUtxnK2NMf
MIT License
1.69k stars 155 forks

About 'replace_hf.py' #21

Closed chyoob closed 4 months ago

chyoob commented 8 months ago

Hello @kyegomez

In the inference code of huggingface_example.py, it appears that replace_hf is executed, followed immediately by inference. However, upon examining replace_hf.py, I noticed it converts linear layers to bitlinear layers and initializes fresh weights for them. I'm curious whether additional code is needed to transfer the original weights into the bitlinear layers.

Maybe something like this?

import torch
import torch.nn as nn

from bitnet import BitLinear


def replace_linears_in_hf(
    model,
):
    """
    Replaces all instances of nn.Linear in the given model with BitLinear,
    copying the original weights into the new layers.

    Args:
        model (nn.Module): The model to modify in place.

    Returns:
        None
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            # Replace the nn.Linear with a BitLinear of matching in_features and out_features
            new_module = BitLinear(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None)

            # Copy the pretrained weights into the new layer
            with torch.no_grad():
                new_module.weight.copy_(module.weight)
                if module.bias is not None:
                    new_module.bias.copy_(module.bias)
            setattr(model, name, new_module)
        else:
            # Recursively apply to child modules
            replace_linears_in_hf(module)
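As a quick sanity check of the idea, here is a minimal, self-contained sketch. It uses a stand-in `BitLinear` (just a subclass of `nn.Linear`, so it runs without the bitnet package installed; with the real `from bitnet import BitLinear` the outputs would differ once quantization kicks in, but the weight copy works the same way):

```python
import torch
import torch.nn as nn


# Stand-in for bitnet's BitLinear so this check runs standalone;
# swap in `from bitnet import BitLinear` for the real layer.
class BitLinear(nn.Linear):
    pass


def replace_linears_in_hf(model):
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            new_module = BitLinear(
                module.in_features,
                module.out_features,
                bias=module.bias is not None,
            )
            # Transfer the pretrained parameters in place
            with torch.no_grad():
                new_module.weight.copy_(module.weight)
                if module.bias is not None:
                    new_module.bias.copy_(module.bias)
            setattr(model, name, new_module)
        else:
            replace_linears_in_hf(module)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(1, 4)
y_before = model(x)

replace_linears_in_hf(model)

# Every linear layer was swapped, and (with the identity stand-in)
# the copied weights reproduce the original outputs exactly.
assert all(isinstance(m, BitLinear) for m in model if isinstance(m, nn.Linear))
y_after = model(x)
print("weights transferred:", torch.allclose(y_before, y_after))
```

With the stand-in class this prints `weights transferred: True`, confirming the copy logic; the setattr-by-name pattern also works for `nn.Sequential` children, whose names are the string indices "0", "1", ....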

Thanks.


github-actions[bot] commented 6 months ago

Stale issue message

kyegomez commented 6 months ago

@chyoob great idea, submit a pull request pls

github-actions[bot] commented 4 months ago

Stale issue message