kundajelab / fastISM

In-silico Saturation Mutagenesis implementation with 10x or more speedup for certain architectures.
MIT License

Failing on Squeeze Excitation layers #12

Closed: louadi closed this issue 1 year ago

louadi commented 1 year ago

Hi there,

Thanks for the very helpful package!

I'm having trouble loading a model with custom layers and am getting the following error. Any feedback would be very helpful!

Best, Zakaria

import tensorflow as tf
import tensorflow_addons as tfa
import fastism
from fastism import FastISM

LayerNormalization = tf.keras.layers.LayerNormalization
StochasticDepth = tfa.layers.StochasticDepth

# register the custom layers as see-through so fastISM propagates through them
fastism.fast_ism_utils.SEE_THROUGH_LAYERS.add('LayerNormalization')
fastism.fast_ism_utils.SEE_THROUGH_LAYERS.add('StochasticDepth')

fast_ism_model = FastISM(model, test_correctness=False,)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Input In [28], in <cell line: 1>()
----> 1 fast_ism_model = FastISM(model, test_correctness=False,)

File ~/miniconda3/envs/tf/lib/python3.9/site-packages/fastism/fast_ism.py:14, in FastISM.__init__(self, model, seq_input_idx, change_ranges, early_stop_layers, test_correctness)
      9 def __init__(self, model, seq_input_idx=0, change_ranges=None,
     10              early_stop_layers=None, test_correctness=True):
     11     super().__init__(model, seq_input_idx, change_ranges)
     13     self.output_nodes, self.intermediate_output_model, self.intout_output_tensors, \
---> 14         self.fast_ism_model, self.input_specs = generate_models(
     15             self.model, self.seqlen, self.num_chars, self.seq_input_idx,
     16             self.change_ranges, early_stop_layers)
     18     self.intout_output_tensor_to_idx = {
     19         x: i for i, x in enumerate(self.intout_output_tensors)}
     21     if test_correctness:

File ~/miniconda3/envs/tf/lib/python3.9/site-packages/fastism/fast_ism_utils.py:959, in generate_models(model, seqlen, num_chars, seq_input_idx, change_ranges, early_stop_layers)
    945 node_to_segment, stop_segment_idxs, alternate_input_segment_idxs = segment_model(
    946     model, nodes, edges, inbound_edges, seq_input_idx, early_stop_layers)
    948 # stop_segment_idxs contains all the segments beyond which full computation
    949 # takes place, i.e. computations are equivalent to naive implementation.
    950 # By default segments including and downstream of those containing
   (...)
    957 # for each segment, compute metadata used for stitching together outputs
    958 # dict: segment_idx -> GraphSegment object
--> 959 segments = compute_segment_change_ranges(model, nodes, edges,
    960                                          inbound_edges,
    961                                          node_to_segment,
    962                                          stop_segment_idxs,
    963                                          seqlen, num_chars,
    964                                          change_ranges,
    965                                          seq_input_idx)
    966 # TODO: check if this makes sense
    967 # compute_segment_change_ranges does not process segments belonging to
    968 # alternate (non-sequence) inputs
   (...)
    972 # weaker version. Would not have an entry for segments in stop_segment_idxs
    973 # that are only connected to other segments in stop_segment_idxs
    974 assert(len(segments) >= len(set(node_to_segment.values())) -
    975        len(alternate_input_segment_idxs) - len(stop_segment_idxs) + 1)

File ~/miniconda3/envs/tf/lib/python3.9/site-packages/fastism/fast_ism_utils.py:403, in compute_segment_change_ranges(model, nodes, edges, inbound_edges, node_to_segment, stop_segment_idxs, input_seqlen, input_filters, input_change_ranges, seq_input_idx)
    398     assert(node_to_segment[cur_segment_tensor] in stop_segment_idxs)
    400 if len(segments_to_process_input_seqlens[cur_segment_to_process]) != \
    401         len(non_stop_segment_inbound):
    402     # should not be greater in any case
--> 403     assert(len(segments_to_process_input_seqlens[cur_segment_to_process]) <
    404            len(non_stop_segment_inbound))
    405     # hold off and wait till other input segments are populated
    406     assert(len(segments_to_process) > 0)

AssertionError: 
suragnair commented 1 year ago

Hi Zakaria, apologies that you had to encounter this error. Seems a bit tricky to debug. Would it be possible to share the architecture by any chance? If you prefer, you can also mail it to me at surag@stanford.edu. Even a minimal version with which you can replicate the error works.

louadi commented 1 year ago

Hi Surag,

Thanks for your fast reply!

I tried different versions of the architecture, and it seems like this is the part that triggers the error:


# squeeze-and-excitation block: shared dense layers applied to the
# average- and max-pooled descriptors, then channel-wise rescaling of x
se_avr = layers.Reshape(se_shape, name=name + "se_reshape_avr")(se_avr)
se_max = layers.Reshape(se_shape, name=name + "se_reshape_max")(se_max)

shared_layer_one = layers.Dense(filters // se_ratio,
                                activation='relu',
                                kernel_initializer='he_normal',
                                use_bias=True,
                                bias_initializer='zeros')

shared_layer_two = layers.Dense(filters,
                                activation='sigmoid',
                                kernel_initializer='he_normal',
                                use_bias=True,
                                bias_initializer='zeros')

se_avr = shared_layer_two(shared_layer_one(se_avr))
se_max = shared_layer_two(shared_layer_one(se_max))

se = layers.add([se_avr, se_max], name=name + "add_av_max")

x = layers.multiply([x, se], name=name + "se_excite")

So it's essentially a squeeze-and-excitation block. Do you think the error occurs because the multiply layer is not supported?

Best, Zakaria

suragnair commented 1 year ago

Is x an output of previous layers? I suspect it may be related to this issue: https://github.com/kundajelab/fastISM/blob/4fc1f44/test/test_unresolved.py. Essentially, something like this:

x = tf.keras.layers.Conv1D(10, 3, padding='same')(inp)
y = tf.keras.layers.Dense(10)(x)
x = tf.keras.layers.Add()([x, y])

So basically the Add layer (in your case multiply) has one input that is connected to the sequence (inp above) with no STOP LAYERS on the way, while the other input is also connected to the sequence but passes through a STOP LAYER (like your se above; STOP LAYERS are layers such as Reshape and Dense, after which fastISM falls back to full computation). This raises the same error you are seeing. Can you confirm that my description matches your architecture?
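
For reference, here is a minimal self-contained version that should reproduce the same AssertionError (a sketch; the input shape is arbitrary):

import tensorflow as tf
from fastism import FastISM

inp = tf.keras.Input((100, 4))                           # e.g. a one-hot sequence input
x = tf.keras.layers.Conv1D(10, 3, padding='same')(inp)   # local: a change affects only nearby positions
y = tf.keras.layers.Dense(10)(x)                         # Dense is a STOP LAYER
out = tf.keras.layers.Add()([x, y])                      # merges a see-through branch with a stopped branch

model = tf.keras.Model(inp, out)
fast_ism_model = FastISM(model, test_correctness=False)  # raises the AssertionError above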

suragnair commented 1 year ago

I read up on squeeze-and-excitation networks. I think a bigger issue is that fastISM won't speed up SE layers. That's because fastISM works for layers where a change at position i only affects a few positions around it. However, in the case of SE (if I understand correctly), a mutation at any position will affect the representation at all other positions (due to the global max pooling). As a result, fastISM (even after I debug it properly) will stop after the first SE layer and not give any speedups.

So if you are using SE along with your main conv blocks, then unfortunately fastISM will not be able to provide any speedup.
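
To make the locality point concrete, here is a hypothetical toy example (made-up shapes and layers, not your architecture) where the global pooling in an SE-style block makes a single-position mutation change the output at every position:

import numpy as np
import tensorflow as tf

# toy SE-like block: conv, global-pooling "squeeze", channel-wise "excite"
inp = tf.keras.Input((100, 4))
x = tf.keras.layers.Conv1D(8, 3, padding='same')(inp)
se = tf.keras.layers.GlobalAveragePooling1D()(x)   # mixes ALL positions into one vector
se = tf.keras.layers.Dense(8, activation='sigmoid')(se)
se = tf.keras.layers.Reshape((1, 8))(se)           # broadcast over the length axis
out = tf.keras.layers.Multiply()([x, se])
model = tf.keras.Model(inp, out)

a = np.random.rand(1, 100, 4).astype('float32')
b = a.copy()
b[0, 50] += 1.0                                    # mutate a single position

diff = np.abs(model(a).numpy() - model(b).numpy()).sum(axis=-1)[0]
print((diff > 1e-8).sum())   # ~100: every position changes, not just the ~3 near position 50

Without the squeeze step, only positions 49-51 would change, and fastISM would only have to recompute those.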

louadi commented 1 year ago

Right, it makes sense that both spatial and channel attention will be tricky. Thank you for the explanation!

One last thing: am I adding custom layers correctly, or is there anything else I need to define?

suragnair commented 1 year ago

You are adding them correctly. However, I'm not sure LayerNormalization is indeed a SEE THROUGH layer at test time, since, similar to SE, it may also be taking an average across all positions.
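
One way to check empirically whether a layer is see-through (a hypothetical helper, not part of fastISM) is to mutate a single input position and see which output positions change. For tf.keras.layers.LayerNormalization this depends on the axis argument: the default axis=-1 normalizes each position over its channels independently, whereas normalizing over the length axis couples all positions.

import numpy as np
import tensorflow as tf

def changed_positions(layer, seqlen=100, channels=8, pos=50):
    # which output positions move when a single input position is mutated?
    inp = tf.keras.Input((seqlen, channels))
    model = tf.keras.Model(inp, layer(inp))
    a = np.random.rand(1, seqlen, channels).astype('float32')
    b = a.copy()
    b[0, pos] += 1.0
    diff = np.abs(model(a).numpy() - model(b).numpy()).sum(axis=-1)[0]
    return np.nonzero(diff > 1e-6)[0]

print(changed_positions(tf.keras.layers.LayerNormalization()))             # [50]: per-position
print(changed_positions(tf.keras.layers.LayerNormalization(axis=[1, 2])))  # all positions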

louadi commented 1 year ago

Right, that would explain my other bug. Thanks a lot anyway for your help and for the cool project!

suragnair commented 1 year ago

Thanks for giving it a try and feel free to reach out if you have any other issues!