FGA-DIKU / EHR


Optimization: Tokenization #61

Open kirilklein opened 1 week ago

kirilklein commented 1 week ago

Most of the computation time in create_data is now spent on tokenization. There are obvious inefficiencies that need to be addressed.

For example, the _add_token function, which is applied repeatedly, sorts by PID and abspos twice. We should sort once outside the function if possible. Calling set_index("PID") beforehand, so that PIDs are correctly partitioned, will probably also help, especially in functions like _get_first_event and get_segment_change.

Furthermore, the tokenization itself can be improved. For instance, we can get rid of the lambda function in tokenize_frozen:

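```python
import dask.dataframe as dd

def tokenize_frozen(concepts: dd.Series, vocabulary: dict) -> dd.Series:
    """Tokenize the concepts in the features DataFrame with a frozen vocabulary."""
    # Unknown concepts map to NaN, which is then replaced by the "[UNK]" token id.
    return concepts.map(vocabulary).fillna(vocabulary["[UNK]"]).astype(int)
```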
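To check that dropping the lambda does not change the tokens, the two variants can be compared on a plain pandas Series, which is what Dask's Series.map operates on for each partition. tokenize_with_lambda is a hypothetical reconstruction of a per-element version, not the code from the repository. With a dict argument, pandas' Series.map performs the lookup through an index built from the dict rather than calling a Python function once per row, which is where the speed-up is expected to come from.

```python
import pandas as pd


def tokenize_with_lambda(concepts: pd.Series, vocabulary: dict) -> pd.Series:
    """Hypothetical per-element version: one Python function call per row."""
    unk = vocabulary["[UNK]"]
    return concepts.map(lambda concept: vocabulary.get(concept, unk)).astype(int)


def tokenize_vectorized(concepts: pd.Series, vocabulary: dict) -> pd.Series:
    """Dict lookup done inside pandas; misses become NaN and then [UNK]."""
    return concepts.map(vocabulary).fillna(vocabulary["[UNK]"]).astype(int)


vocab = {"[UNK]": 0, "A": 1, "B": 2}
sample = pd.Series(["A", "X", "B", "A"])
print(tokenize_vectorized(sample, vocab).tolist())  # [1, 0, 2, 1]
```

Timing both functions (for example with timeit) on a few million rows of real concepts would show whether the gain is significant for this data.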

Similarly, in the dynamic case we can precompute the mapping and call .map only once:

```python
from typing import Tuple

import dask.dataframe as dd


def tokenize_with_update(concepts: dd.Series, vocabulary: dict) -> Tuple[dd.Series, dict]:
    """Tokenize the concepts in the features DataFrame and update the vocabulary."""
    # Make a copy of the vocabulary to avoid mutating the original
    vocabulary = vocabulary.copy()

    # Compute only the unique concepts instead of materializing the whole Series
    unique_concepts = concepts.drop_duplicates().compute()

    # Identify new concepts not already in the vocabulary
    new_concepts = [concept for concept in unique_concepts if concept not in vocabulary]

    # Assign new indices to the new concepts
    start_index = len(vocabulary)
    new_indices = range(start_index, start_index + len(new_concepts))
    new_vocab_entries = dict(zip(new_concepts, new_indices))

    # Update the vocabulary with new entries
    vocabulary.update(new_vocab_entries)

    # Map the concepts to their token indices using the updated vocabulary;
    # every concept is now in the vocabulary, so no [UNK] fallback is needed
    concepts = concepts.map(vocabulary).astype(int)

    return concepts, vocabulary
```

kirilklein commented 3 days ago

Tokenization currently takes around 3 h on Azure (commit 1614ba1136ab2ebc002d8308176bd05b4cb7784f), which is still slow compared to the roughly 10 min the old version took (features held in memory, looping over patients). The data has grown since then, but not by that much. The current data (medications and diagnoses only) takes 44 GiB of RAM, so keeping everything in memory would risk OOM errors; we should optimize the Dask code instead.

kirilklein commented 3 days ago

Biggest issues: vocabulary creation and writing are very slow, although part of the time attributed to writing is probably earlier lazy operations that are only executed when the result is written. We should see whether we can avoid recomputation where possible.