aertslab / CREsted


custom loss function enhancer design #15

Closed SeppeDeWinter closed 4 weeks ago

SeppeDeWinter commented 1 month ago

For example:


import numpy as np

# use Heart, muscle and myoblast cells as contrast
classes_of_interest = [
    i for i, ct in enumerate(adata.obs_names)
    if "Heart" in ct or "muscle" in ct or "myoblast" in ct
]

# design enhancers that are high in heart but low in myoblast
target = np.array(
    [
        0 if "Cardiac muscle" not in x else 1
        for x in adata.obs_names
        if "Heart" in x or "muscle" in x or "myoblast" in x
    ]
)

# sanity check: every position marked 1 in `target` corresponds to a cardiac muscle cell type
selected = adata.obs_names[np.array(classes_of_interest)[np.where(target)[0]]]
assert all("Cardiac muscle" in x for x in selected)

from sklearn.metrics import pairwise
from crested.tl._utils import EnhancerOptimizer

def L2_distance(
    mutated_predictions: np.ndarray,
    original_prediction: np.ndarray,
    target: np.ndarray,
    classes_of_interest: list[int],
) -> int:
    """Return the index of the mutation that moves the prediction closest to the target."""
    def scale(X: np.ndarray) -> np.ndarray:
        # min-max scale each row to [0, 1]
        return ((X.T - X.min(1)) / (X.max(1) - X.min(1))).T

    L2_sat_mut = pairwise.euclidean_distances(
        scale(mutated_predictions)[:, classes_of_interest], target.reshape(1, -1)
    )
    L2_baseline = pairwise.euclidean_distances(
        scale(original_prediction)[:, classes_of_interest], target.reshape(1, -1)
    )
    # pick the mutation with the largest improvement over the baseline distance
    return np.argmax((L2_baseline - L2_sat_mut).squeeze())

L2_optimizer = EnhancerOptimizer(optimize_func=L2_distance)

intermediate_info_list, designed_sequences = evaluator.enhancer_design_in_silico_evolution(
    target_class=None,
    n_sequences=1,
    n_mutations=30,
    enhancer_optimizer=L2_optimizer,
    target=target,
    return_intermediate=True,
    no_mutation_flanks=(807, 807),
    classes_of_interest=classes_of_interest,
)
SeppeDeWinter commented 1 month ago

Note: the code for motif embedding has not been tested yet.

LukasMahieu commented 1 month ago

Okay, looks good and makes sense to me. In the near future we should really make a separate tutorial for enhancer design (including the information here), since it is currently just a one-liner in the introductory tutorial. @erceksi, could you take a look too, since you implemented the original function?

SeppeDeWinter commented 1 month ago

Added some extra changes.

Multiple sequences are now processed in parallel. Previously, a separate call to model.predict was made for each sequence in each iteration; now a single batched call is made per iteration.
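The effect of that batching change can be sketched with a toy stand-in for the model. Everything here (ToyModel, evolve_per_sequence, evolve_batched) is hypothetical illustration, not CREsted code; the point is only the difference in how often predict is invoked.

```python
import numpy as np

class ToyModel:
    """Hypothetical Keras-style model stand-in that counts predict calls."""
    def __init__(self, n_classes: int):
        self.n_classes = n_classes
        self.calls = 0

    def predict(self, batch: np.ndarray) -> np.ndarray:
        # batch: (n_sequences, length, 4) -> scores of shape (n_sequences, n_classes)
        self.calls += 1
        return np.tile(batch.mean(axis=(1, 2))[:, None], (1, self.n_classes))

def evolve_per_sequence(model, seqs, n_iterations: int):
    # old approach: one model.predict call per sequence, per iteration
    for _ in range(n_iterations):
        for seq in seqs:
            model.predict(seq[None])

def evolve_batched(model, seqs, n_iterations: int):
    # new approach: stack all sequences once, then predict once per iteration
    batch = np.stack(seqs)
    for _ in range(n_iterations):
        model.predict(batch)

# 8 random one-hot sequences of length 500
seqs = [np.eye(4)[np.random.randint(0, 4, 500)] for _ in range(8)]
old, new = ToyModel(10), ToyModel(10)
evolve_per_sequence(old, seqs, n_iterations=30)
evolve_batched(new, seqs, n_iterations=30)
print(old.calls, new.calls)  # 240 vs 30
```

With 8 sequences and 30 iterations, the per-sequence loop calls predict 240 times, while the batched version calls it 30 times; each batched call does more work, but it avoids the per-call overhead, which is where the speed-up comes from.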

SeppeDeWinter commented 1 month ago

Based on a quick-and-dirty benchmark, this change should be around 2x faster.