stanfordnlp / pyreft

ReFT: Representation Finetuning for Language Models
https://arxiv.org/abs/2404.03592
Apache License 2.0

[P0] Multigpu and model sharding #25

Open frankaging opened 6 months ago

frankaging commented 6 months ago

Descriptions:

The pyvene library was designed for model interpretability, not for production use cases that demand training and inference efficiency. pyreft is different: it targets practical use cases and therefore needs production-ready training and inference efficiency.

This ticket may require multiple PRs, including changes in pyvene.

frankaging commented 6 months ago

Currently, pyvene supports only a single GPU.

If you don't guard to a single GPU (e.g., by running `export CUDA_VISIBLE_DEVICES=0`), training will throw an error; see https://github.com/stanfordnlp/pyreft/issues/31.
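The same guard can be applied from Python instead of the shell. A minimal sketch, assuming you want to pin training to GPU 0 (the environment variable must be set before torch initializes CUDA):

    import os

    # Set before `import torch`, since CUDA_VISIBLE_DEVICES is read when
    # torch initializes CUDA.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

    import torch
    print(torch.cuda.device_count())  # -> 1 with the guard in effect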

danikhan632 commented 5 months ago

    def compute_loss(self, intervenable, inputs, return_outputs=False):
        # Directly use tensors; avoid premature conversions
        subspaces = inputs["subspaces"].permute(1, 0, 2) if "subspaces" in inputs else None

        # Prepare unit locations
        unit_locations = {
            "sources->base": (None, inputs["intervention_locations"].permute(1, 0, 2))
        }

        # Log the tensor shape and device for debugging multi-GPU placement
        print("Debug Info: ", unit_locations["sources->base"][1].shape, inputs["input_ids"].device)

        # Forward pass
        _, cf_outputs = intervenable(
            {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]},
            unit_locations=unit_locations,
            labels=inputs["labels"],
            subspaces=subspaces
        )

        # Return outputs
        return (cf_outputs.loss, cf_outputs) if return_outputs else cf_outputs.loss

I fixed the previous issue in pyvene, but now I'm hitting an odd issue that depends on the number of GPUs. compute_loss isn't called by any pyreft/pyvene code; it's invoked by the Hugging Face Trainer.
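For context, a minimal sketch of how that hook gets picked up, assuming the standard transformers.Trainer API (the subclass name here is hypothetical, not pyreft's actual class):

    from transformers import Trainer

    class InterventionTrainer(Trainer):  # hypothetical name, for illustration
        def compute_loss(self, model, inputs, return_outputs=False):
            # transformers' training loop (Trainer.training_step) calls this
            # hook itself; pyreft/pyvene never invoke it directly.
            outputs = model(**inputs)
            return (outputs.loss, outputs) if return_outputs else outputs.loss

The debug lines below are printed from inside this hook.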

Output using a single GPU: Debug Info: torch.Size([4, 4, 1]) cuda:0 Intervening...

Output using two GPUs: Debug Info: torch.Size([4, 8, 1]) cuda:0 Intervening...

Output using three GPUs: Debug Info: torch.Size([4, 12, 1]) cuda:0 Intervening...
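One pattern worth flagging: the second dimension of the permuted intervention locations scales with the GPU count (4, 8, 12). A sketch of the suspected arithmetic, assuming a per-device batch size of 4 (inferred from the single-GPU log) and nn.DataParallel-style batching, where the Trainer collates one global batch before scattering:

    # Assumed per-device batch size of 4, inferred from the single-GPU log.
    per_device_batch_size = 4

    for n_gpus in (1, 2, 3):
        # Under nn.DataParallel, the Trainer collates one *global* batch of
        # per_device_batch_size * n_gpus; compute_loss (and its print) runs
        # before the scatter, so it sees the multiplied batch dimension.
        global_batch = per_device_batch_size * n_gpus
        print(f"{n_gpus} GPU(s): torch.Size([4, {global_batch}, 1])")
    # -> matches the Debug Info shapes above: [4, 4, 1], [4, 8, 1], [4, 12, 1]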