amzn / zero-shot-rlhr

Apache License 2.0

Intermittent CUDA - OOM issue. #3

Open rpetchiappan opened 2 years ago

rpetchiappan commented 2 years ago

Hi team, thanks for all your great work on this paper and code implementation. We are trying to use bert-rl for a classification task, training on a "p3.2xlarge" machine (1 GPU, 16 GiB). GPU memory is not being cleared; it accumulates over iterations and eventually results in an OOM error. We have tried reducing the training dataset (0.6 million, 0.1 million, and finally 49K examples), yet the issue recurs. Could you let us know whether this configuration is sufficient for the experiment, and whether there is any workaround for the issue?
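For reference, memory that grows step after step in a PyTorch training loop is often caused by holding references to tensors that are still attached to the autograd graph (for example, summing the raw `loss` tensor as a running metric). The snippet below is a minimal toy sketch of that pattern and its fix, not this repo's code; the linear model, optimizer, and random data are placeholders:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy setup -- illustrates the leak pattern only, not this repo's code.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
loader = DataLoader(TensorDataset(torch.randn(256, 16), torch.randint(0, 2, (256,))), batch_size=32)

running_loss = 0.0
for step, (inputs, labels) in enumerate(loader):
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    optimizer.step()

    # Accumulating `loss` itself would keep every step's autograd graph alive;
    # `.item()` converts it to a Python float so the graph can be freed.
    running_loss += loss.item()

    # Optional: watch allocated memory for monotonic growth across iterations.
    if device == "cuda" and step % 4 == 0:
        print(f"step {step}: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated")
```

If memory stays flat in a loop like this but grows in the real training script, the leak usually comes from whatever extra tensors the script keeps around between iterations (logged outputs, cached predictions, and so on).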

sanjeebtiwary commented 9 months ago

If you encounter an OOM issue while training BERT, consider mixed-precision training or gradient checkpointing. Mixed precision runs most of the forward and backward pass in FP16, which cuts activation memory with little or no loss of training speed; gradient checkpointing trades extra compute in the backward pass for a much smaller activation footprint. Both help most with models that have large memory footprints. The example below combines mixed precision (via NVIDIA APEX) with gradient accumulation, which lets you keep the per-step batch small while preserving a larger effective batch size:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertForSequenceClassification, BertTokenizer
from apex import amp  # NVIDIA APEX mixed precision

# Load your BERT model and tokenizer (APEX requires the model on GPU before amp.initialize)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased').cuda()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Sample data (replace with your own dataset loading logic)
texts = ["Sample text 1", "Sample text 2", "Sample text 3"]
labels = [0, 1, 0]

# Tokenize and convert to PyTorch tensors
tokenized_texts = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
labels = torch.tensor(labels)

# Create a DataLoader for handling batches
dataset = TensorDataset(tokenized_texts['input_ids'], tokenized_texts['attention_mask'], labels)
batch_size = 4
loader = DataLoader(dataset, batch_size=batch_size)

# The optimizer must exist before amp.initialize
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Mixed-precision training with APEX: opt_level "O1" casts forward ops to FP16 automatically
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Gradient accumulation: step the optimizer only every `accumulation_steps` batches
accumulation_steps = 2
loss_fn = torch.nn.CrossEntropyLoss()

for epoch in range(5):
    total_loss = 0.0
    for i, (input_ids, attention_mask, label) in enumerate(loader):
        input_ids = input_ids.cuda()
        attention_mask = attention_mask.cuda()
        label = label.cuda()

        outputs = model(input_ids, attention_mask=attention_mask)
        loss = loss_fn(outputs.logits, label)
        total_loss += loss.item()

        # Scale the loss down so the accumulated gradients match a full batch,
        # and let APEX apply loss scaling for FP16 before the backward pass
        loss = loss / accumulation_steps
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        if (i + 1) % accumulation_steps == 0 or (i + 1) == len(loader):
            optimizer.step()
            optimizer.zero_grad()

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(loader)}")
```