ybracke / transnormer

A lexical normalizer for historical spelling variants using a transformer architecture.
GNU General Public License v3.0

Use a `GenerationConfig` for generation #74

Closed: ybracke closed this issue 8 months ago

ybracke commented 8 months ago

- Update `generate.py`: bundle all generation parameters in a `transformers.GenerationConfig` and pass it to `model.generate()`
- Update the documentation

Sketch:

import transformers

# Bundle all decoding parameters in a single GenerationConfig object
# instead of passing them to `generate()` individually
gen_cfg = transformers.GenerationConfig()
gen_cfg.max_new_tokens = 2048  # arbitrary upper bound, still to be tuned
gen_cfg.early_stopping = CONFIGS["beam_search_decoding"]["early_stopping"]
gen_cfg.length_penalty = CONFIGS["beam_search_decoding"]["length_penalty"]
gen_cfg.num_beams = CONFIGS["beam_search_decoding"]["num_beams"]
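
Equivalently, the parameters could be passed to the constructor in a single call. A minimal sketch, assuming the `CONFIGS["beam_search_decoding"]` section contains only keys that are valid `GenerationConfig` parameters (as in the lines above):

gen_cfg = transformers.GenerationConfig(
    max_new_tokens=2048,  # arbitrary upper bound, still to be tuned
    **CONFIGS["beam_search_decoding"],  # early_stopping, length_penalty, num_beams
)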

# Generate normalizations for a batch of original spellings
def generate_normalization(batch):
    inputs = tokenizer_input(
        batch["orig"],
        padding="longest",
        # truncation=True,
        # max_length=CONFIGS["tokenizer"]["max_length_input"],
        return_tensors="pt",
    )
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    # All decoding parameters come from the single GenerationConfig object
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        generation_config=gen_cfg,
    )
    return outputs
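
A usage sketch for the function above; the input sentences are made-up examples, and `tokenizer_output` is a hypothetical name for the tokenizer on the decoder side (it is not defined in the sketch):

# Made-up historical spellings, for illustration only
batch = {"orig": ["Vnd sie giengen hinauß.", "Jch weiß es nicht."]}
output_ids = generate_normalization(batch)
# Decode the generated token ids back into strings
normalized = tokenizer_output.batch_decode(output_ids, skip_special_tokens=True)
print(normalized)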