sacdallago / biotrainer

Biological prediction models made simple.
https://biocentral.cloud/app
Academic Free License v3.0

Improve embeddings saving strategy #98

Closed SebieF closed 3 months ago

SebieF commented 4 months ago

#96 introduced cached saving of embeddings to the hard disk during embeddings calculation to reduce RAM usage. At the moment, the save interval is a fixed constant: embeddings are written to disk after every 100 have been calculated. This creates some overhead and might even fail for very long sequences in the first batches, because 100 long embeddings may not fit into memory. Ideally, the point at which to save would be set dynamically based on available RAM and sequence lengths.

SebieF commented 4 months ago

Some input from Claude Sonnet 3.5 on the topic:

Your approach of saving intermediate results to disk is a good way to manage memory usage when dealing with large datasets or long sequences. As for dynamically setting the SAVE_AFTER value, it's a reasonable idea, but the optimal approach depends on various factors. Here are a few considerations and a possible heuristic:

Memory usage: The primary goal is to prevent running out of memory. A dynamic SAVE_AFTER value could be based on monitoring the current memory usage.

Disk I/O overhead: While saving might not add significant overhead, frequent small writes could potentially impact performance, especially on slower storage systems.

Sequence length variability: If your sequences vary greatly in length, a fixed SAVE_AFTER value might not be optimal.
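To make the sequence-length point concrete, here is a minimal sketch that plans saves by an estimated byte budget instead of a fixed count. It assumes per-residue embeddings of shape (sequence length, embedding dimension); the function names, the 1024-dimensional float32 embeddings, and the 4 MiB budget are all hypothetical and not part of biotrainer.

```python
from typing import List, Sequence


def estimate_embedding_bytes(sequence_length: int, embedding_dim: int,
                             bytes_per_value: int) -> int:
    """Approximate size of one per-residue embedding of shape (length, dim)."""
    return sequence_length * embedding_dim * bytes_per_value


def plan_flush_points(sequence_lengths: Sequence[int], embedding_dim: int,
                      bytes_per_value: int, budget_bytes: int) -> List[int]:
    """Return the indices after which the buffered embeddings should be saved.

    A save is planned as soon as the estimated buffer size reaches the budget,
    so a single very long sequence triggers a save on its own.
    """
    flush_points: List[int] = []
    buffered_bytes = 0
    for idx, length in enumerate(sequence_lengths):
        buffered_bytes += estimate_embedding_bytes(length, embedding_dim, bytes_per_value)
        if buffered_bytes >= budget_bytes:
            flush_points.append(idx)
            buffered_bytes = 0  # buffer is empty again after the save
    return flush_points


# Hypothetical numbers: 1024-dimensional float32 embeddings, 4 MiB budget.
lengths = [100, 200, 5000, 50, 300, 700]
points = plan_flush_points(lengths, embedding_dim=1024, bytes_per_value=4,
                           budget_bytes=4 * 1024 * 1024)
print(points)  # [2, 5]
```

Anything still buffered after the last planned point needs one final save. In practice the budget could be derived from a fraction of psutil.virtual_memory().available, and bytes_per_value would drop from 4 to 2 with half precision; both are left as tuning choices here.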

Here's a possible heuristic approach:


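from pathlib import Path
from typing import Dict, Optional

import psutil
from numpy import ndarray
from tqdm import tqdm

# EmbedderInterface and Protocol are project types and are assumed to be in scope.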

class EmbeddingService:
    def __init__(self, embedder: EmbedderInterface = None, use_half_precision: bool = False,
                 max_memory_usage: float = 0.8, initial_save_after: int = 50):
        self._embedder = embedder
        self._use_half_precision = use_half_precision
        self._max_memory_usage = max_memory_usage
        self._initial_save_after = initial_save_after

    def compute_embeddings(self, sequence_file: str, output_dir: Path, protocol: Protocol,
                           force_output_dir: Optional[bool] = False,
                           force_recomputing: Optional[bool] = False) -> str:
        # ... (previous code)

        embeddings = {}
        idx: int = 0
        last_save_id: int = 0
        save_after: int = self._initial_save_after

        for embedding in tqdm(self._embedder.embed_many(protein_sequences.values()),
                              total=len(protein_sequences.values())):
            embeddings[sequence_ids[idx]] = embedding
            idx += 1

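            # Save when the count threshold is reached or system memory usage is too high.
            if len(embeddings) >= save_after or self._should_save(embeddings):
                embeddings = self._reduce_embeddings(embeddings, self._embedder)
                last_save_id = self._save_embeddings(save_id=last_save_id, embeddings=embeddings,
                                                     embeddings_file_path=embeddings_file_path)

                # Memory is sampled while the batch is still loaded, then the threshold is adapted.
                save_after = self._adjust_save_after(save_after, len(embeddings))

                # Rebinding drops the saved batch so it can be garbage-collected.
                embeddings = {}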

        # ... (rest of the method)

    def _should_save(self, embeddings: Dict[str, ndarray]) -> bool:
        # psutil reports system-wide RAM usage in percent; `embeddings` itself is not inspected.
        current_memory_usage = psutil.virtual_memory().percent / 100
        return current_memory_usage > self._max_memory_usage

    def _adjust_save_after(self, current_save_after: int, last_batch_size: int) -> int:
        # `last_batch_size` is currently unused; only system memory usage drives the adjustment.
        current_memory_usage = psutil.virtual_memory().percent / 100
        if current_memory_usage < self._max_memory_usage * 0.8:
            return min(current_save_after * 2, 1000)  # Increase, but cap at 1000
        elif current_memory_usage > self._max_memory_usage * 0.9:
            return max(current_save_after // 2, 10)  # Decrease, but keep at least 10
        return current_save_after

This approach does the following:

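It introduces a max_memory_usage parameter (default 0.8, i.e. 80% of system RAM in use as reported by psutil.virtual_memory()) to determine when to trigger a save based on memory usage. Note that this is system-wide usage, not the memory of the Python process alone.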

The _should_save method checks if the current memory usage exceeds the specified threshold.

The _adjust_save_after method dynamically adjusts the save_after value based on the current memory usage:
    If memory usage is comfortably low (below 80% of the threshold, i.e. 64% of RAM with the defaults), it doubles save_after.
    If memory usage is getting close to the threshold (above 90% of it, i.e. 72% of RAM with the defaults), it halves save_after.
    It also clamps save_after to the range 10 to 1000 to prevent extreme values.

In the main loop, it checks both the save_after count and the _should_save condition.

This approach provides a balance between a fixed interval and dynamic adjustment based on memory usage. It starts with your initial value and then adjusts based on the observed memory usage.
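The adjustment rule can be tried out without psutil or real memory pressure. The sketch below restates it as a pure function that takes the memory usage as an argument; the function name and the lower/upper parameters are hypothetical, while the thresholds and limits are copied from the code above.

```python
def adjust_save_after(current_save_after: int, memory_usage: float,
                      max_memory_usage: float = 0.8,
                      lower: int = 10, upper: int = 1000) -> int:
    """Return the next save_after value; memory_usage is a fraction between 0 and 1."""
    if memory_usage < max_memory_usage * 0.8:
        # Plenty of headroom: buffer more embeddings before the next save.
        return min(current_save_after * 2, upper)
    if memory_usage > max_memory_usage * 0.9:
        # Close to the limit: save more often.
        return max(current_save_after // 2, lower)
    return current_save_after


for usage in (0.50, 0.70, 0.75):
    print(usage, adjust_save_after(50, usage))
# 0.5 100
# 0.7 50
# 0.75 25
```

Passing the usage in as a parameter keeps the rule deterministic, which makes it easier to check the cap and floor (for example, 800 stays capped at 1000 and 15 is floored at 10) before wiring it to a live memory reading.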

Remember to install the psutil library (pip install psutil) if you decide to use this approach.

Keep in mind that this is just one possible heuristic, and you may need to fine-tune it based on your specific use case and the characteristics of your data and system.