Genentech / gReLU

gReLU is a python library to train, interpret, and apply deep learning models to DNA sequences.
https://genentech.github.io/gReLU/
MIT License

Incompatibility between modisco with ISM and the Specificity transform with subtract function #80

Open dagarfield opened 2 weeks ago

dagarfield commented 2 weeks ago

The function run_modisco currently calculates the attributions that are fed to modisco with:

attrs = -np.log2(np.divide(ism_preds, ref_preds))

This runs into issues, however, if zero or negative values are involved. A fix could be:

ref_preds_safe = np.where(ref_preds == 0, np.finfo(float).eps, ref_preds)  # guard against division by zero
division_result = np.divide(ism_preds, ref_preds_safe)
division_result_safe = np.where(division_result <= 0, np.finfo(float).eps, division_result)  # guard against log2 of non-positive values
attrs = -np.log2(division_result_safe)
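
For illustration, a minimal reproduction of the failure mode these guards address (the prediction values here are made up):

import numpy as np

ism_preds = np.array([0.5, -0.2, 1.0])  # made-up values; one prediction is negative
ref_preds = np.array([1.0, 0.4, 0.0])   # made-up values; one reference prediction is zero
attrs = -np.log2(np.divide(ism_preds, ref_preds))  # emits RuntimeWarnings; attrs is [1., nan, -inf]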

A more elegant fix, for someone more fluent with numpy arrays, would be to take the absolute values of ism_preds, compute the log ratio from those, and then re-apply the negative signs after the fact (or apply some other kind of scaling).
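
As a minimal sketch of that idea (the helper name signed_log2_attrs and the exact sign convention are my assumptions, not part of the gReLU API):

import numpy as np

def signed_log2_attrs(ism_preds, ref_preds, eps=np.finfo(float).eps):
    # Hypothetical helper: take the log ratio of absolute values so that
    # negative predictions cannot produce NaN, then re-apply the sign of
    # the raw ratio afterwards.
    ref_safe = np.where(ref_preds == 0, eps, ref_preds)
    ratio = np.divide(ism_preds, ref_safe)
    sign = np.where(ratio < 0, -1.0, 1.0)       # remember where the ratio was negative
    magnitude = np.maximum(np.abs(ratio), eps)  # clamp away from zero before log2
    return -sign * np.log2(magnitude)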

dagarfield commented 2 weeks ago

Note: Changing from ism to saliency does not work either, and appears to trigger an error from within modisco itself:

Running modisco
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File <timed eval>:1

File ~/scratch/conda/envs/gRelu_v1/lib/python3.10/site-packages/grelu/interpret/score.py:325, in run_modisco(model, seqs, genome, prediction_transform, window, meme_file, out_dir, devices, num_workers, batch_size, n_shuffles, seed, method, **kwargs)
    323 one_hot_arr = one_hot_arr.transpose(0, 2, 1).astype("float32")
    324 attrs = attrs.transpose(0, 2, 1).astype("float32")
--> 325 pos_patterns, neg_patterns = modiscolite.tfmodisco.TFMoDISco(
    326     hypothetical_contribs=attrs,
    327     one_hot=one_hot_arr,
    328     **kwargs,
    329 )
    331 print("Writing modisco output")
    332 if not os.path.exists(out_dir):

File ~/scratch/conda/envs/gRelu_v1/lib/python3.10/site-packages/modiscolite/tfmodisco.py:281, in TFMoDISco(one_hot, hypothetical_contribs, sliding_window_size, flank_size, min_metacluster_size, weak_threshold_for_counting_sign, max_seqlets_per_metacluster, target_seqlet_fdr, min_passing_windows_frac, max_passing_windows_frac, n_leiden_runs, n_leiden_iterations, min_overlap_while_sliding, nearest_neighbors_to_compute, affmat_correlation_threshold, tsne_perplexity, frac_support_to_trim_to, min_num_to_trim_to, trim_to_window_size, initial_flank_to_add, prob_and_pertrack_sim_merge_thresholds, prob_and_pertrack_sim_dealbreaker_thresholds, subcluster_perplexity, merging_max_seqlets_subsample, final_min_cluster_size, min_ic_in_window, min_ic_windowsize, ppm_pseudocount, verbose)
    275 contrib_scores = np.multiply(one_hot, hypothetical_contribs)
    277 track_set = core.TrackSet(one_hot=one_hot, 
    278     contrib_scores=contrib_scores,
    279     hypothetical_contribs=hypothetical_contribs)
--> 281 seqlet_coords, threshold = extract_seqlets.extract_seqlets(
    282     attribution_scores=contrib_scores.sum(axis=2),
    283     window_size=sliding_window_size,
    284     flank=flank_size,
    285     suppress=(int(0.5*sliding_window_size) + flank_size),
    286     target_fdr=target_seqlet_fdr,
    287     min_passing_windows_frac=min_passing_windows_frac,
    288     max_passing_windows_frac=max_passing_windows_frac,
    289     weak_threshold_for_counting_sign=weak_threshold_for_counting_sign) 
    291 seqlets = track_set.create_seqlets(seqlet_coords) 
    293 pos_seqlets, neg_seqlets = [], []

File ~/scratch/conda/envs/gRelu_v1/lib/python3.10/site-packages/modiscolite/extract_seqlets.py:163, in extract_seqlets(attribution_scores, window_size, flank, suppress, target_fdr, min_passing_windows_frac, max_passing_windows_frac, weak_threshold_for_counting_sign)
    158 pos_null_values, neg_null_values = _laplacian_null(track=smoothed_tracks, 
    159     window_size=window_size, num_to_samp=10000)
    161 pos_threshold = _isotonic_thresholds(pos_values, pos_null_values, 
    162     increasing=True, target_fdr=target_fdr)
--> 163 neg_threshold = _isotonic_thresholds(neg_values, neg_null_values,
    164     increasing=False, target_fdr=target_fdr)
    166 pos_threshold, neg_threshold = _refine_thresholds(
    167       vals=np.concatenate([pos_values, neg_values], axis=0),
    168       pos_threshold=pos_threshold,
    169       neg_threshold=neg_threshold,
    170       min_passing_windows_frac=min_passing_windows_frac,
    171       max_passing_windows_frac=max_passing_windows_frac) 
    173 distribution = np.array(sorted(np.abs(np.concatenate(smoothed_tracks,
    174     axis=0))))

File ~/scratch/conda/envs/gRelu_v1/lib/python3.10/site-packages/modiscolite/extract_seqlets.py:117, in _isotonic_thresholds(values, null_values, increasing, target_fdr, min_frac_neg)
    114 sample_weight = np.concatenate([np.ones(n1), np.ones(n2)*w], axis=0)
    116 model = IsotonicRegression(out_of_bounds='clip', increasing=increasing)
--> 117 model.fit(X, y, sample_weight=sample_weight)
    119 min_prec_x = model.X_min_ if increasing else model.X_max_
    120 min_precision = model.transform([min_prec_x])[0]

File ~/scratch/conda/envs/gRelu_v1/lib/python3.10/site-packages/sklearn/base.py:1473, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
   1466     estimator._validate_params()
   1468 with config_context(
   1469     skip_parameter_validation=(
   1470         prefer_skip_nested_validation or global_skip_validation
   1471     )
   1472 ):
-> 1473     return fit_method(estimator, *args, **kwargs)

File ~/scratch/conda/envs/gRelu_v1/lib/python3.10/site-packages/sklearn/isotonic.py:387, in IsotonicRegression.fit(self, X, y, sample_weight)
    383 check_consistent_length(X, y, sample_weight)
    385 # Transform y by running the isotonic regression algorithm and
    386 # transform X accordingly.
--> 387 X, y = self._build_y(X, y, sample_weight)
    389 # It is necessary to store the non-redundant part of the training set
    390 # on the model to make it possible to support model persistence via
    391 # the pickle module as the object built by scipy.interp1d is not
    392 # picklable directly.
    393 self.X_thresholds_, self.y_thresholds_ = X, y

File ~/scratch/conda/envs/gRelu_v1/lib/python3.10/site-packages/sklearn/isotonic.py:322, in IsotonicRegression._build_y(self, X, y, sample_weight, trim_duplicates)
    319 unique_X, unique_y, unique_sample_weight = _make_unique(X, y, sample_weight)
    321 X = unique_X
--> 322 y = isotonic_regression(
    323     unique_y,
    324     sample_weight=unique_sample_weight,
    325     y_min=self.y_min,
    326     y_max=self.y_max,
    327     increasing=self.increasing_,
    328 )
    330 # Handle the left and right bounds on X
    331 self.X_min_, self.X_max_ = np.min(X), np.max(X)

File ~/scratch/conda/envs/gRelu_v1/lib/python3.10/site-packages/sklearn/utils/_param_validation.py:186, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    184 global_skip_validation = get_config()["skip_parameter_validation"]
    185 if global_skip_validation:
--> 186     return func(*args, **kwargs)
    188 func_sig = signature(func)
    190 # Map *args/**kwargs to the function signature

File ~/scratch/conda/envs/gRelu_v1/lib/python3.10/site-packages/sklearn/isotonic.py:157, in isotonic_regression(y, sample_weight, y_min, y_max, increasing)
    114 """Solve the isotonic regression model.
    115 
    116 Read more in the :ref:`User Guide <isotonic>`.
   (...)
    154        7.33..., 7.33..., 7.33..., 7.33..., 7.33...])
    155 """
    156 order = np.s_[:] if increasing else np.s_[::-1]
--> 157 y = check_array(y, ensure_2d=False, input_name="y", dtype=[np.float64, np.float32])
    158 y = np.array(y[order], dtype=y.dtype)
    159 sample_weight = _check_sample_weight(sample_weight, y, dtype=y.dtype, copy=True)

File ~/scratch/conda/envs/gRelu_v1/lib/python3.10/site-packages/sklearn/utils/validation.py:1087, in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_writeable, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)
   1085     n_samples = _num_samples(array)
   1086     if n_samples < ensure_min_samples:
-> 1087         raise ValueError(
   1088             "Found array with %d sample(s) (shape=%s) while a"
   1089             " minimum of %d is required%s."
   1090             % (n_samples, array.shape, ensure_min_samples, context)
   1091         )
   1093 if ensure_min_features > 0 and array.ndim == 2:
   1094     n_features = array.shape[1]

ValueError: Found array with 0 sample(s) (shape=(0,)) while a minimum of 1 is required.
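
Reading the traceback, the ValueError seems to mean that the negative-threshold fit in _isotonic_thresholds received an empty array, i.e. the saliency attributions apparently contain no negative values for modisco to threshold, though I am not certain that is the root cause.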