Open ekiefl opened 4 months ago
@ekiefl hey Evan, i'm actually interested in this as well, currently staring at this
could you confirm all the logic is contained there and maybe i'll attempt to transcribe it to numpy / pytorch? i need it for open sourced alphafold3, but can also do a pull request here
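For orientation, here is a minimal numpy sketch of a Shrake-Rupley style SASA calculation. It assumes that is the kind of calculation being transcribed, which may not match the linked code. The function names, the golden-spiral sphere sampling, and the per-atom loop are illustrative choices, not taken from any existing implementation.

```python
import numpy as np


def sphere_points(n):
    """Roughly uniform points on the unit sphere (golden-spiral lattice)."""
    i = np.arange(n) + 0.5
    polar = np.arccos(1.0 - 2.0 * i / n)
    azimuth = np.pi * (1.0 + 5.0 ** 0.5) * i
    return np.stack(
        [
            np.cos(azimuth) * np.sin(polar),
            np.sin(azimuth) * np.sin(polar),
            np.cos(polar),
        ],
        axis=-1,
    )


def shrake_rupley_sasa(coords, radii, probe=1.4, n_points=200):
    """Per-atom SASA: fraction of test points not buried by any other atom."""
    coords = np.asarray(coords, dtype=float)
    # Each atom's sphere is inflated by the probe radius.
    expanded = np.asarray(radii, dtype=float) + probe
    unit = sphere_points(n_points)
    sasa = np.zeros(len(coords))
    for i in range(len(coords)):
        surface = coords[i] + expanded[i] * unit  # (n_points, 3)
        others = np.arange(len(coords)) != i
        # Distance from every test point to every other atom centre.
        dist = np.linalg.norm(
            surface[:, None, :] - coords[others][None, :, :], axis=-1
        )
        buried = (dist < expanded[others][None, :]).any(axis=1)
        area = 4.0 * np.pi * expanded[i] ** 2
        sasa[i] = area * (~buried).sum() / n_points
    return sasa


# An isolated atom is fully exposed: area of a sphere of radius r + probe.
single = shrake_rupley_sasa([[0.0, 0.0, 0.0]], [1.6])
```

The loop is quadratic in the number of atoms, so a real port would likely want a neighbour list or a batched tensor version. The sketch only aims to make the per-point burial test explicit.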
@lucidrains I don't have intimate knowledge of the calculation, but I'm sure @ShintaroMinami would appreciate a PR--I know I would.
thanks Evan, will give it a shot and let Shintaro review it
@ekiefl would you like to do a code review, perhaps test out this function on your data? i can then attempt to translate that to a file in this repo if numbers are in the right ballpark
Hi @lucidrains, sure.
I think a comparison to biotite would be useful, but I'm having trouble finding a point of entry for your implementation. Could you help me write the missing function in this test?
import shutil
import numpy as np
from numpy.typing import NDArray
from pathlib import Path
from biotite.structure.io import load_structure
from biotite.structure.residues import apply_residue_wise
from biotite.structure.sasa import sasa

NUM_POINTS = 200
STRUCT_DIR = Path("./structures/")
OUT_DIR = Path("./sasas/")


def get_biotite_sasa(path: Path) -> NDArray:
    structure = load_structure(path)
    atomic_sasa = sasa(structure, vdw_radii="Single", point_number=NUM_POINTS)
    residue_sasa = apply_residue_wise(structure, atomic_sasa, np.sum)
    assert residue_sasa is not None
    return residue_sasa


def get_af3_sasa(path: Path) -> NDArray:
    raise NotImplementedError()


def get_r2(method_1, method_2):
    ss_res = np.sum((method_1 - method_2) ** 2)
    ss_tot = np.sum((method_1 - np.mean(method_1)) ** 2)
    r2 = 1 - (ss_res / ss_tot)
    return r2


def summarize_r2_results(data):
    data = np.array(data)
    n = len(data)
    mean = np.mean(data)
    median = np.median(data)
    data_min = np.min(data)
    data_max = np.max(data)
    std_dev = np.std(data, ddof=1)  # Sample standard deviation (n-1)
    print("Basic R2 Statistics:")
    print(f"Count : {n}")
    print(f"Mean : {mean:.4f}")
    print(f"Median : {median}")
    print(f"Minimum : {data_min}")
    print(f"Maximum : {data_max}")
    print(f"Standard Deviation : {std_dev:.4f}")


if __name__ == "__main__":
    if OUT_DIR.exists():
        shutil.rmtree(OUT_DIR)
    r2s: list[float] = []
    for path in STRUCT_DIR.glob("*.pdb"):
        biotite_sasa = get_biotite_sasa(path)
        af3_sasa = get_af3_sasa(path)
        r2s.append(get_r2(biotite_sasa, af3_sasa))
    summarize_r2_results(r2s)
If you want to run the script yourself, attached is structures.zip, which you should place in the same directory as the script and unzip. It's 75 relatively homologous structures.
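To give a feel for how the R2 in `get_r2` behaves, here is a small sketch on made-up numbers (the arrays are hypothetical, not real SASA values). It treats the first argument as the reference, so a constant offset between the two methods lowers R2 even when they are perfectly correlated.

```python
import numpy as np


def get_r2(method_1, method_2):
    # Same definition as in the script: method_1 is the reference.
    ss_res = np.sum((method_1 - method_2) ** 2)
    ss_tot = np.sum((method_1 - np.mean(method_1)) ** 2)
    return 1 - (ss_res / ss_tot)


# Hypothetical per-residue values, for illustration only.
reference = np.array([10.0, 20.0, 30.0, 40.0])
shifted = reference + 5.0  # perfectly correlated, but offset by 5

r2_identical = get_r2(reference, reference)         # 1.0
r2_shifted = get_r2(reference, shifted)             # 1 - 100/500 = 0.8
pearson_r = np.corrcoef(reference, shifted)[0, 1]   # close to 1.0
```

So an R2 well below 1 need not mean the two implementations disagree on trends. A systematic difference, for example from different radii or probe size, would show up the same way, and comparing against Pearson correlation can help tell the two apart.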
@ekiefl 🙏 will check it out tomorrow morning
DSSP is my go-to tool for calculating RSA and it would be great if the numpy/torch implementation you've provided for SS could be extended for RSA as well.
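As a rough sketch of what the RSA step could look like once per-residue SASA is available: RSA is the residue's SASA divided by a reference maximum for that residue type. The function name and the reference values below are hypothetical placeholders. A real version would substitute a published maximum-accessibility scale.

```python
import numpy as np


def relative_sasa(residue_sasa, residue_names, max_sasa):
    """Divide each residue's SASA by the reference maximum for its type."""
    reference = np.array([max_sasa[name] for name in residue_names], dtype=float)
    return np.asarray(residue_sasa, dtype=float) / reference


# Placeholder reference values in square angstroms, NOT a published scale.
PLACEHOLDER_MAX_SASA = {"ALA": 120.0, "GLY": 100.0, "TRP": 280.0}

rsa = relative_sasa(
    [60.0, 25.0, 70.0],
    ["ALA", "GLY", "TRP"],
    PLACEHOLDER_MAX_SASA,
)
```

Because the result depends directly on the chosen reference scale, two tools can report different RSA values from identical SASA inputs.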