Closed li1117heex closed 3 years ago
Could you provide more details? What's the code you ran?
from datasets import load_dataset
from transformers import (
    DataCollatorForLanguageModeling,
    FunnelConfig,
    FunnelForMaskedLM,
    FunnelTokenizer,
    Trainer,
    TrainingArguments,
)

tokenizer = FunnelTokenizer.from_pretrained('funnel-transformer/small')

def tokenize(batch):
    return tokenizer(batch['text'], padding='max_length', truncation=True, max_length=512)

# Load only 1000 rows of bookcorpus and tokenize them with map
dataset = load_dataset("bookcorpus", split='train[:1000]').shuffle()
dataset = dataset.map(tokenize, batched=True, batch_size=512)

# dataset = LineByLineTextDataset(
#     tokenizer=tokenizer,
#     file_path="./wiki1000.txt",
#     block_size=128,
# )

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

config = FunnelConfig(
    return_dict=True
)

model = FunnelForMaskedLM(config=config)

training_args = TrainingArguments(
    output_dir="./checkpoints",
    overwrite_output_dir=True,
    do_train=True,
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    save_steps=10000,
    logging_dir='./ptlogs',
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)
trainer.train()
RuntimeError: CUDA out of memory. Tried to allocate 954.00 MiB (GPU 0; 15.90 GiB total capacity; 14.35 GiB already allocated; 753.75 MiB free; 14.39 GiB reserved in total by PyTorch) Exception raised from malloc at /pytorch/c10/cuda/CUDACachingAllocator.cpp:272 (most recent call first):
(part of the error output)
Switching from the Funnel model to a BERT model: the error still happened.
Switching from your dataset to LineByLineTextDataset: the error disappeared.
Notice that I only loaded 1000 rows of data.
The error happens when executing loss.backward().
Since you're using a data collator, you don't need to tokenize the dataset using map. Could you try not to use map and only use the data collator instead? The data collator is supposed to pad to the longest sequence in each batch afaik, instead of padding to 512 (see the sketch below).
Also cc @sgugger
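A minimal sketch of one way to read that suggestion, assuming the same tokenizer and dataset as in the original script: tokenize without padding='max_length' (truncation only) and drop the raw text column, so that the collator pads each batch to its longest sequence at collation time.

from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling, FunnelTokenizer

tokenizer = FunnelTokenizer.from_pretrained('funnel-transformer/small')

# Tokenize without padding; truncation still caps sequences at 512 tokens.
def tokenize(batch):
    return tokenizer(batch['text'], truncation=True, max_length=512)

dataset = load_dataset("bookcorpus", split='train[:1000]').shuffle()
dataset = dataset.map(tokenize, batched=True, batch_size=512,
                      remove_columns=["text"])

# The collator pads each batch only to its longest sequence and applies MLM masking.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

This way most batches contain sequences far shorter than 512 tokens, which should reduce peak memory during the forward and backward passes.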
Closing this one. Feel free to re-open if you have other questions about this issue.
With your dataset, CUDA runs out of memory as soon as the trainer begins; however, without changing any other element or parameter, just switching the dataset to LineByLineTextDataset makes everything OK.
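For comparison, this is a sketch of the LineByLineTextDataset setup that the commented-out block in the original script describes (wiki1000.txt is the file path from that script). With block_size=128, each example is at most 128 tokens and is left unpadded until the collator builds a batch, which would be consistent with the much lower memory usage compared to the 512-token padded examples produced by map.

from transformers import FunnelTokenizer, LineByLineTextDataset

tokenizer = FunnelTokenizer.from_pretrained('funnel-transformer/small')

# Each line of the text file becomes one example, truncated to at most
# block_size tokens; padding is left to the data collator at batch time.
dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="./wiki1000.txt",
    block_size=128,
)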