generatebio / chroma

A generative model for programmable protein design
Apache License 2.0
696 stars 90 forks source link

Long-term persistence of model parameters #15

Open cyrushu opened 1 year ago

cyrushu commented 1 year ago

Dear developers,

Here is a solution for persisting the model parameters long-term, which would save some network traffic. It would also be better to have a SHA-256 check inside the cache-check process.

diff --git a/chroma/utility/api.py b/chroma/utility/api.py
index 902b776..ce996c8 100644
--- a/chroma/utility/api.py
+++ b/chroma/utility/api.py
@@ -21,7 +21,11 @@ import requests

 import chroma

-ROOT_DIR = os.path.dirname(os.path.dirname(chroma.__file__))
+# Set CHROMA_ROOT_DIR to override the default directory: ~/.config/chroma
+ROOT_DIR = os.environ.get(
+    "CHROMA_ROOT_DIR",
+    os.path.join(os.path.expanduser("~"), ".config", "chroma"))
+os.makedirs(ROOT_DIR, exist_ok=True)

 def register_key(key: str, key_directory=ROOT_DIR) -> None:
@@ -92,11 +96,8 @@ def download_from_generate(

     # Create a hash of the URL + weight name to determine the path for the cached/temporary file
     url_hash = hashlib.md5((base_url + weights_name).encode()).hexdigest()
-    temp_dir = os.path.join(tempfile.gettempdir(), "chroma_weights", url_hash)
-    destination = os.path.join(temp_dir, "weights.pt")
-
-    # Ensure the directory exists
-    os.makedirs(temp_dir, exist_ok=True)
+    os.makedirs(os.path.join(ROOT_DIR, "weights"), exist_ok=True)
+    destination = os.path.join(ROOT_DIR, "weights", f"{url_hash}.pt")

     # Check if cache exists
     cache_exists = os.path.exists(destination)
@@ -117,8 +118,14 @@ def download_from_generate(
     response = requests.get(base_url, params=params)
     response.raise_for_status()  # Raise an error for HTTP errors

-    with open(destination, "wb") as file:
-        file.write(response.content)
+    # Write into temp_file
+    temp_file = tempfile.TemporaryFile()
+    temp_file.write(response.content)
+
+    # Write into cached destination
+    with open(destination, "wb") as f:
+        temp_file.seek(0)
+        f.write(temp_file.read())

     print(f"Data saved to {destination}")
     return destination