luyug / GradCache

Run Effective Large Batch Contrastive Learning Beyond GPU/TPU Memory Constraint
Apache License 2.0

Training speed is very slow #30

Open liuweie opened 1 week ago

liuweie commented 1 week ago

Hi, I use grad_cache to train my model, but training seems very slow. I want to know: is this normal? Does using grad cache generally affect training speed?

luyug commented 1 week ago

It'd be hard to diagnose this based on qualitative descriptions. Maybe you can share some of your setup, observed and reference throughput/latency, etc.

liuweie commented 1 week ago


Thanks. I am using the Hugging Face Trainer to train a Qwen-7B model; here are my setup and the corresponding code:

① The compute_loss function, which overrides the Trainer's original compute_loss:

def compute_loss(
    self,
    model: nn.Module,
    inputs: Dict[str, Union[torch.Tensor, Any]],
    return_outputs: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:

    # inputs is a (features, labels) pair; features holds the tokenized
    # query batch and the positive-document batch.
    features, labels = inputs
    query_input = features[0]
    pos_doc_input = features[1]

    loss_fn = DistributedContrastiveLoss(temperature=20, n_hard_negatives=0)
    gc = GradCache(
        models=[self.model, self.model],
        chunk_sizes=1,
        loss_fn=loss_fn,
        get_rep_fn=None,
        fp16=False,
    )

    # gc(...) returns a detached loss; requires_grad_() keeps the Trainer's
    # own backward call from raising on a tensor without grad.
    detached_loss = gc(query_input, pos_doc_input).requires_grad_()
    return detached_loss

As you can see, I set chunk_sizes=1, and I also tried chunk sizes of 4, 16, and 64; the batch size in the Trainer settings is 256, and the hardware is 2x A800 (80 GB) GPUs.
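For reference, here is a minimal sketch (my own rearrangement under the assumptions above, not a fix confirmed in this thread) of the same compute_loss with the GradCache object and the loss built once and cached on the Trainer, instead of being re-created on every step; the chunk size of 64 is only a placeholder, since chunk_sizes=1 means one forward pass per example:

def compute_loss(self, model, inputs, return_outputs=False):
    # Build GradCache lazily and reuse it across steps (the _grad_cache
    # attribute name is my own; constructor arguments mirror the snippet above).
    if not hasattr(self, "_grad_cache"):
        self._grad_cache = GradCache(
            models=[self.model, self.model],
            chunk_sizes=64,  # placeholder; larger chunks mean fewer forward passes
            loss_fn=DistributedContrastiveLoss(temperature=20, n_hard_negatives=0),
            get_rep_fn=None,
            fp16=False,
        )
    features, labels = inputs
    query_input, pos_doc_input = features
    # gc(...) runs the chunked forward/backward itself and returns a detached
    # loss; requires_grad_() is kept so the Trainer's own backward call is a no-op.
    return self._grad_cache(query_input, pos_doc_input).requires_grad_()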

② The DistributedContrastiveLoss, similar to the DistributedContrastiveLoss in loss.py of the grad_cache package, with only a temperature parameter added to scale the scores:

import torch
import torch.distributed as dist
from torch import Tensor


class DistributedContrastiveLoss(SimpleContrastiveLoss):
    def __init__(self, temperature, n_hard_negatives: int = 0):
        assert dist.is_initialized(), "Distributed training has not been properly initialized."

        super().__init__(temperature=temperature, n_hard_negatives=n_hard_negatives)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        # Gather query and document representations from all ranks so every
        # query sees the full cross-device batch of in-batch negatives.
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)

        return super().__call__(dist_x, dist_y, **kwargs)

    def gather_tensor(self, t):
        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(gathered, t)
        # all_gather outputs carry no gradients, so put the local tensor back
        # in its own slot to keep the autograd path for this rank.
        gathered[self.rank] = t

        return torch.cat(gathered, dim=0)
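As a side note on gather_tensor, here is a single-process sketch I added to illustrate the design (the gloo backend, address, and port are placeholders): the copies written by dist.all_gather do not carry gradients, which is why the local rank's slot is overwritten with the original tensor before concatenation.

import torch
import torch.distributed as dist

# Single-process demo (world_size=1): all_gather outputs are plain copies
# without autograd history, so gather_tensor re-inserts the local tensor
# at its own rank to keep the gradient path for this rank's examples.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)

t = torch.randn(4, 8, requires_grad=True)
gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, t)
print(gathered[0].requires_grad)                 # False: gathered copy is detached
gathered[dist.get_rank()] = t                    # restore the autograd-connected slice
print(torch.cat(gathered, dim=0).requires_grad)  # True
dist.destroy_process_group()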

where SimpleContrastiveLoss looks like this:

import torch
import torch.nn.functional as F
from torch import Tensor


class SimpleContrastiveLoss:
    def __init__(self, temperature, n_hard_negatives: int = 0):
        # Each query is paired with 1 positive plus n_hard_negatives documents.
        self.target_per_qry = n_hard_negatives + 1
        self.temperature = temperature

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'):
        if target is None:
            assert x.size(0) * self.target_per_qry == y.size(0)
            # The positive for query i sits at index i * target_per_qry in y.
            target = torch.arange(0, x.size(0) * self.target_per_qry, self.target_per_qry, device=x.device)

        logits = torch.matmul(x, y.transpose(0, 1))
        logits = logits * self.temperature
        return F.cross_entropy(logits, target, reduction=reduction)
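A tiny single-process usage example of the class above (the shapes and the normalize call are illustrative): with 4 queries, 4 positive documents, and no hard negatives, the target is [0, 1, 2, 3] and each query is scored against all 4 documents as in-batch negatives.

import torch
import torch.nn.functional as F

# Illustrative usage: 4 queries, 4 positive documents, no hard negatives,
# so target = [0, 1, 2, 3] and each query is scored against all 4 documents.
loss_fn = SimpleContrastiveLoss(temperature=20, n_hard_negatives=0)
q = F.normalize(torch.randn(4, 16), dim=-1)  # query embeddings
d = F.normalize(torch.randn(4, 16), dim=-1)  # document embeddings
print(loss_fn(q, d))                         # scalar cross-entropy loss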

Finally, the code runs successfully, but the gradient updates are extremely slow.

luyug commented 6 days ago

Can you share the observed runtime and a reference runtime?

Meanwhile, one thing to note is that the Hugging Face Trainer can trigger features like DeepSpeed ZeRO, which came after the GradCache release and therefore may not be smoothly supported.

liuweie commented 2 days ago

@luyug I think I have figured this problem out, thanks. But during my experiments I found that the loss is very difficult to converge; here is my log:

{'loss': 4.8621, 'grad_norm': 3.703125, 'learning_rate': 0.0001978902953586498, 'epoch': 0.08}
{'loss': 3.3299, 'grad_norm': 1.140625, 'learning_rate': 0.00019261603375527427, 'epoch': 0.16}
{'loss': 3.0912, 'grad_norm': 3.578125, 'learning_rate': 0.00018734177215189873, 'epoch': 0.23}
{'loss': 2.6461, 'grad_norm': 1.0234375, 'learning_rate': 0.00018206751054852322, 'epoch': 0.31}
{'loss': 1.8239, 'grad_norm': 0.7578125, 'learning_rate': 0.00017679324894514769, 'epoch': 0.39}
{'loss': 2.4623, 'grad_norm': 4.4375, 'learning_rate': 0.00017151898734177218, 'epoch': 0.47}
{'loss': 2.1719, 'grad_norm': 1.1640625, 'learning_rate': 0.00016624472573839661, 'epoch': 0.55}
{'loss': 2.6063, 'grad_norm': 16.125, 'learning_rate': 0.0001609704641350211, 'epoch': 0.62}
{'loss': 2.2289, 'grad_norm': 6.40625, 'learning_rate': 0.0001556962025316456, 'epoch': 0.7}
{'loss': 2.1505, 'grad_norm': 1.5078125, 'learning_rate': 0.00015042194092827003, 'epoch': 0.78}
{'loss': 2.2342, 'grad_norm': 3.375, 'learning_rate': 0.00014514767932489453, 'epoch': 0.86}
{'loss': 1.7903, 'grad_norm': 2.125, 'learning_rate': 0.000139873417721519, 'epoch': 0.93}
{'loss': 2.5553, 'grad_norm': 2.21875, 'learning_rate': 0.00013459915611814345, 'epoch': 1.01}
{'loss': 1.8375, 'grad_norm': 0.95703125, 'learning_rate': 0.00012932489451476795, 'epoch': 1.09}
{'loss': 2.379, 'grad_norm': 0.71875, 'learning_rate': 0.0001240506329113924, 'epoch': 1.17}
{'loss': 2.5603, 'grad_norm': 3.203125, 'learning_rate': 0.00011877637130801689, 'epoch': 1.25}
....
....
...
{'loss': 1.9158, 'grad_norm': 2.84375, 'learning_rate': 3.966244725738397e-05, 'epoch': 2.41}
{'loss': 2.2063, 'grad_norm': 3.625, 'learning_rate': 3.438818565400844e-05, 'epoch': 2.49}
{'loss': 2.1187, 'grad_norm': 1.46875, 'learning_rate': 2.9113924050632914e-05, 'epoch': 2.57}
{'loss': 2.055, 'grad_norm': 0.94140625, 'learning_rate': 2.3839662447257385e-05, 'epoch': 2.65}

As you can see, the loss always hovers around 2, whereas if I don't use grad cache the loss converges to 0.2.
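One back-of-the-envelope way to read the absolute values in this log (my own note, and the candidate counts below are illustrative rather than taken from the run): for in-batch contrastive cross-entropy over N candidate documents per query, a chance-level prediction gives a loss of ln(N), so the relevant question is how far below that ceiling the loss falls.

import math

# Chance-level cross-entropy for a softmax over N in-batch candidates is ln(N).
# N values here are illustrative, not measured from the run above.
for n in (256, 512):
    print(f"N={n}: chance-level loss = {math.log(n):.2f}")  # 5.55 and 6.24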