emanjavacas / pie

A fully-fledged PyTorch package for Morphological Analysis, tailored to morphologically rich and historical languages.
MIT License

Disable word encoding if not using wemb? #72

Closed PonteIneptique closed 4 years ago

PonteIneptique commented 4 years ago

For the same reason as #70, I had a look at the cost of encoding words even when wemb is not used.

The gain is minimal, but so is the amount of code involved... What do you think?

Set-up

import time
import glob
import os.path

import tqdm
import pie
from pie.tagger import simple_tokenizer, pack_batch

# file = "PathToAModel"
# base_encoder = pie.model.BaseModel.load(file).label_encoder

example = list(simple_tokenizer("""Lorem ipsum dolor sit amet, consectetur adipiscing elit. 
Phasellus dolor sapien, laoreet non turpis eget, tincidunt commodo magna. Duis at dapibus ipsum. 
Etiam fringilla et magna sed vehicula. 
Nunc tristique eros non faucibus viverra. 
Sed dictum scelerisque tortor, eu ullamcorper odio. 
Aenean fermentum a urna quis tempus. 
Maecenas imperdiet est a nisi pellentesque dictum. 
Maecenas ac hendrerit ante. Vestibulum eleifend nulla at vulputate sagittis. 
Maecenas sed magna diam. Sed facilisis tempus ipsum, nec mattis elit tincidunt lobortis. 
Phasellus vel ex lorem. Nulla nunc odio, tempor non consequat in, luctus elementum dolor. 
Nullam tincidunt purus vel lorem placerat, ac pulvinar turpis sodales. 
Sed eget urna ac quam cursus porta. 
Pellentesque luctus aliquet sem, a egestas purus finibus ac. 
Mauris nec mauris non metus tempor faucibus non in est. 
Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. 
Proin tristique nulla nec purus iaculis, eu scelerisque mi egestas. 
In hac habitasse platea dictumst.
Ut placerat a neque eget aliquet. """))

print(f"Sentences : {len(example)}")
print(f"Words : {len([x for s in example for x in s])}")
runs = 1000

Sentences : 22
Words : 194

from collections import defaultdict

from pie import utils, torch_utils, constants

def mod_pack_batch(label_encoder, batch, device=None, use_wemb=True):
    """
    Transform batch data to tensors
    """
    (word, char), tasks = mod_transform(label_encoder, batch, use_wemb=use_wemb)

    char = torch_utils.pad_batch(char, label_encoder.char.get_pad(), device=device)
    if use_wemb:
        word = torch_utils.pad_batch(word, label_encoder.word.get_pad(), device=device)
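    # when use_wemb is False, `word` stays the empty list returned by mod_transform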

    output_tasks = {}
    for task, data in tasks.items():
        output_tasks[task] = torch_utils.pad_batch(
            data, label_encoder.tasks[task].get_pad(), device=device)

    return (word, char), output_tasks

def mod_transform(label_encoder, sents, use_wemb=True):
    """
    Parameters
    ===========
    sents : list of Examples

    Returns
    ===========
    tuple of (word, char), task_dict

        - word: list of lists of integers (empty when use_wemb is False)
        - char: list of lists of integers, where each inner list represents
            a word at the character level
        - task_dict: dict mapping each task to its corresponding integer output
    """
    word, char, tasks_dict = [], [], defaultdict(list)

    for inp in sents:
        tasks = None

        # task might not be passed
        if isinstance(inp, tuple):
            inp, tasks = inp

        # input data
        if use_wemb:
            word.append(label_encoder.word.transform(inp))
        char.extend(label_encoder.char.transform(inp))

        # task data
        if tasks is None:
            # during inference there is no task data (pass None)
            continue

        for le in label_encoder.tasks.values():
            task_data = le.transform(tasks[le.target], inp)
            # add data
            if le.level == 'char':
                tasks_dict[le.name].extend(task_data)
            else:
                tasks_dict[le.name].append(task_data)

    return (word, char), tasks_dict
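
For reference, a minimal usage sketch (assuming `base_encoder` has been loaded from a trained model, as in the commented-out lines of the set-up): with `use_wemb=False`, the `word` component comes back as an empty, un-padded list, so callers that do not feed a word embedding can simply ignore it.

(word, char), tasks = mod_pack_batch(base_encoder, example, device="cpu", use_wemb=False)
assert word == []  # no word-level encoding is produced when wemb is disabled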

Current situation

records = []
for run in tqdm.tqdm(range(runs)):
    start = time.time()
    mod_pack_batch(base_encoder, example, "cpu", use_wemb=True)
    records.append(time.time() - start)

print(f"Average encoding time with a single encoders {sum(records) / len(records)} (Total : {sum(records)})")

100%|██████████| 1000/1000 [00:02<00:00, 360.18it/s] Average encoding time with a single encoders 0.002758941411972046 (Total : 2.758941411972046)

Proposed removal of wemb when not used

records = []
for run in tqdm.tqdm(range(runs)):
    start = time.time()
    x = mod_pack_batch(base_encoder, example, "cpu", use_wemb=False)
    records.append(time.time() - start)

print(f"Average encoding time with a single encoders {sum(records) / len(records)} (Total : {sum(records)})")

100%|██████████| 1000/1000 [00:02<00:00, 439.83it/s] Average encoding time with a single encoders 0.0022597649097442626 (Total : 2.2597649097442627)

PonteIneptique commented 4 years ago

The gain is very low (a fraction of a second over 194k words...), but does it have a big implementation cost? I don't think so. What do you think, @emanjavacas?

emanjavacas commented 4 years ago

Hey. Thanks for looking into this. However, I'd be very reluctant to change stuff that works and is not problematic. As you can see, the time penalty is negligible, and we risk breaking things, making debugging and maintenance harder, etc...

PonteIneptique commented 4 years ago

Yeah. Unlike #70, this might not be necessary :)