SebieF closed this issue 3 months ago
Some input from Claude Sonnet 3.5 on the topic:
Your approach of saving intermediate results to disk is a good way to manage memory usage when dealing with large datasets or long sequences. As for dynamically setting the SAVE_AFTER value, it's a reasonable idea, but the optimal approach depends on various factors. Here are a few considerations and a possible heuristic:
Memory usage: The primary goal is to prevent running out of memory. A dynamic SAVE_AFTER value could be based on monitoring the current memory usage.
Disk I/O overhead: While saving might not add significant overhead, frequent small writes could potentially impact performance, especially on slower storage systems.
Sequence length variability: If your sequences vary greatly in length, a fixed SAVE_AFTER value might not be optimal.
Here's a possible heuristic approach:
import psutil
from pathlib import Path
from typing import Dict, Optional

from numpy import ndarray
from tqdm import tqdm

# EmbedderInterface and Protocol are assumed to be importable from the surrounding project.


class EmbeddingService:
    def __init__(self, embedder: EmbedderInterface = None, use_half_precision: bool = False,
                 max_memory_usage: float = 0.8, initial_save_after: int = 50):
        self._embedder = embedder
        self._use_half_precision = use_half_precision
        self._max_memory_usage = max_memory_usage
        self._initial_save_after = initial_save_after

    def compute_embeddings(self, sequence_file: str, output_dir: Path, protocol: Protocol,
                           force_output_dir: Optional[bool] = False,
                           force_recomputing: Optional[bool] = False) -> str:
        # ... (previous code)
        embeddings = {}
        idx: int = 0
        last_save_id: int = 0
        save_after: int = self._initial_save_after
        for embedding in tqdm(self._embedder.embed_many(protein_sequences.values()),
                              total=len(protein_sequences.values())):
            embeddings[sequence_ids[idx]] = embedding
            idx += 1
            # Flush to disk once the batch is full or memory pressure gets too high
            if len(embeddings) >= save_after or self._should_save(embeddings):
                embeddings = self._reduce_embeddings(embeddings, self._embedder)
                last_save_id = self._save_embeddings(save_id=last_save_id, embeddings=embeddings,
                                                     embeddings_file_path=embeddings_file_path)
                save_after = self._adjust_save_after(save_after, len(embeddings))
                del embeddings
                embeddings = {}
        # ... (rest of the method)

    def _should_save(self, embeddings: Dict[str, ndarray]) -> bool:
        # Trigger a save as soon as system memory usage exceeds the configured threshold
        current_memory_usage = psutil.virtual_memory().percent / 100
        return current_memory_usage > self._max_memory_usage

    def _adjust_save_after(self, current_save_after: int, last_batch_size: int) -> int:
        # Grow or shrink the save interval depending on how close we are to the memory limit
        current_memory_usage = psutil.virtual_memory().percent / 100
        if current_memory_usage < self._max_memory_usage * 0.8:
            return min(current_save_after * 2, 1000)  # Increase, but cap at 1000
        elif current_memory_usage > self._max_memory_usage * 0.9:
            return max(current_save_after // 2, 10)  # Decrease, but keep at least 10
        return current_save_after
This approach does the following:
It introduces a _max_memory_usage parameter (default 80%) to determine when to trigger a save based on memory usage.
The _should_save method checks if the current memory usage exceeds the specified threshold.
The _adjust_save_after method dynamically adjusts the save_after value based on the current memory usage:
If memory usage is comfortably low, it increases save_after.
If memory usage is getting close to the threshold, it decreases save_after.
It also sets minimum and maximum values for save_after to prevent extreme values.
In the main loop, it checks both the save_after count and the _should_save condition.
This approach provides a balance between a fixed interval and dynamic adjustment based on memory usage. It starts with your initial value and then adjusts based on the observed memory usage.
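To make the grow/shrink behaviour concrete, here is a small standalone sketch of the same adjustment rule, decoupled from the class above (the simulated memory readings are made up for illustration):

def adjust_save_after(current: int, memory_usage: float, max_usage: float = 0.8) -> int:
    if memory_usage < max_usage * 0.8:      # plenty of headroom -> save less often
        return min(current * 2, 1000)
    if memory_usage > max_usage * 0.9:      # close to the limit -> save more often
        return max(current // 2, 10)
    return current

save_after = 50
for usage in (0.30, 0.45, 0.70, 0.76, 0.85):  # simulated memory readings
    save_after = adjust_save_after(save_after, usage)
    print(f"memory={usage:.0%} -> save_after={save_after}")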
Remember to install the psutil library (pip install psutil) if you decide to use this approach.
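For reference, the value the heuristic compares against can be inspected directly; psutil.virtual_memory() reports both a usage percentage and the number of available bytes:

import psutil

mem = psutil.virtual_memory()
print(f"used: {mem.percent}%  available: {mem.available / 1e9:.1f} GB")
# _should_save() above fires once mem.percent / 100 exceeds max_memory_usage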
Keep in mind that this is just one possible heuristic, and you may need to fine-tune it based on your specific use case and the characteristics of your data and system.
#96 introduced cached saving of embeddings to the hard disk during embedding calculation to free up RAM. At the moment, saving to disk happens at a fixed constant interval of 100 computed embeddings. This creates some overhead and might even fail for very long sequences in the first batches. Ideally, the point at which to save would be set dynamically based on available RAM and sequence lengths.
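One rough sketch of such a dynamic criterion (not part of the current code; the function name and the 25% fraction are placeholders for illustration) would be to flush based on the actual byte size of the accumulated embeddings, which already scales with sequence length, relative to the RAM that is currently available:

import psutil
import numpy as np
from typing import Dict

def should_flush(embeddings: Dict[str, np.ndarray], ram_fraction: float = 0.25) -> bool:
    # Per-residue embeddings grow with sequence length, so a few very long
    # sequences can trigger a save just as well as many short ones.
    accumulated_bytes = sum(e.nbytes for e in embeddings.values())
    available_bytes = psutil.virtual_memory().available
    return accumulated_bytes > ram_fraction * available_bytes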