Closed ybracke closed 8 months ago
Update generate.py
generate.py
Documentation
Sketch:
```python
gen_cfg = transformers.GenerationConfig()
gen_cfg.max_new_tokens = 2048  # random
gen_cfg.early_stopping = CONFIGS["beam_search_decoding"]["early_stopping"]
gen_cfg.length_penalty = CONFIGS["beam_search_decoding"]["length_penalty"]
gen_cfg.num_beams = CONFIGS["beam_search_decoding"]["num_beams"]

# Generate
def generate_normalization(batch):
    inputs = tokenizer_input(
        batch["orig"],
        padding="longest",
        # truncation=True,
        # max_length=CONFIGS["tokenizer"]["max_length_input"],
        return_tensors="pt",
    )
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        generation_config=gen_cfg,
    )
```
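The sketch reads its beam-search settings from a `CONFIGS` dict. A minimal, self-contained illustration of the assumed config layout and how those entries map onto generation settings (the exact keys and values here are assumptions inferred from the sketch, not the project's actual config):

```python
# Hypothetical config layout inferred from the sketch above.
CONFIGS = {
    "beam_search_decoding": {
        "early_stopping": True,
        "length_penalty": 1.0,
        "num_beams": 4,
    },
    "tokenizer": {"max_length_input": 512},
}

# Collect all generation settings into one dict; this could then be
# passed as transformers.GenerationConfig(**gen_kwargs) instead of
# setting each attribute one by one.
gen_kwargs = {
    "max_new_tokens": 2048,
    **CONFIGS["beam_search_decoding"],
}
```

Building the kwargs dict first keeps the beam-search block in `CONFIGS` as the single source of truth, so adding a new decoding parameter only requires touching the config.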