kundajelab / fastISM

In-silico Saturation Mutagenesis implementation with 10x or more speedup for certain architectures.
MIT License

Basenji #2

Closed. davek44 closed this issue 3 years ago

davek44 commented 3 years ago

Hi Surag, Avanti, and Anshul,

Love the project here! I wonder if you wouldn't mind giving me a hand getting the Basenji model to work. I expect it's similar to BPNet in principle, but I've implemented several of my own layers that might be breaking within your software.

I’m working with the latest model available here: https://storage.googleapis.com/basenji_barnyard/model_human.tf

Currently, I’m getting an AssertionError here.

~/anaconda3/envs/py38/lib/python3.8/site-packages/fastism/fast_ism_utils.py in compute_segment_change_ranges(model, nodes, edges, inbound_edges, node_to_segment, input_seqlen, input_filters, input_change_ranges, seq_input_idx)
    308     # initialise with input tensor, which has segment idx 0
    309     # only sequence input tensor should be in segment 0
--> 310     assert(sum([node_to_segment[x] == 0 for x in node_to_segment]) == 1)
    311     segments_to_process.append((0, input_tensor))
    312     segments_to_process_input_seqlens[0] = [input_seqlen]

Trying another model that I have around produces a different error. I’m not sure why they would be different.

~/anaconda3/envs/py38/lib/python3.8/site-packages/fastism/fast_ism_utils.py in segment_subgraph(current_node, nodes, edges, inbound_edges, node_to_segment, stop_segment_idxs, segment_idx, num_convs_in_cur_segment)
    179 
    180         if len(edges[current_node]) > 1:
--> 181             raise NotImplementedError(
    182                 "Layer with multiple outputs, what to do?")
    183 
NotImplementedError: Layer with multiple outputs, what to do?

One simple challenge that I'm facing in debugging is that I can't figure out how to install your package from a GitHub clone so that I can add pdb breakpoints and make changes. Would you be able to advise on how to do that?

Thanks for any help you can offer!

suragnair commented 3 years ago

Hi David, thanks! The current version supports a handful of commonly used Keras layers (Supported Layers). I'd be eager to get it to work with Basenji. Should be possible to simplify the architecture so that it's compatible with fastISM, or modify fastISM code where required. For now, I'm not able to open the model link you shared (NoSuchKey error).

For debugging, one option would be to clone the repo, do a sys.path.append to the location of the clone, and import it like in this notebook.

import sys
sys.path.append("../") # path to repo 

import fastISM
from fastISM.models.basset import basset_model

...

model_fism = fastISM.FastISM(model)
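
Alternatively, an editable install from the clone (pip install -e /path/to/fastISM) should also pick up local edits and pdb breakpoints, assuming you use the repo's own setup.py.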

Let me know if that works.

davek44 commented 3 years ago

Great, that worked! I discovered that the assertion here fails: https://github.com/kundajelab/fastISM/blob/master/fastISM/fast_ism_utils.py#L310

The problem is that I apply the nonlinearity at the beginning of my convolution blocks (which is typically how residual blocks operate), and my model has a nonlinearity in between the sequence input tensor and the first convolution layer. Thus, that nonlinearity is getting lumped into "segment idx 0", causing the assertion line to find more than one node. That nonlinearity accomplishes nothing, but it would be a bit awkward in my code to remove. Would it be possible to make FastISM robust to this?

suragnair commented 3 years ago

Got it, I think that can be handled. I don't think there are other parts that rely on that assumption, but I'll check it.

I suspect there will be other issues as well since the model uses many custom layers. I downloaded model_human.h5 from https://console.cloud.google.com/storage/browser/basenji_barnyard, but I'm getting the error ValueError: Unknown layer: StochasticShift. Presumably it needs some imports to work. Could you guide me on how I could load the model?

davek44 commented 3 years ago

You should be able to use tf.keras.models.load_model to load the .tf suffix version I linked to above. To load the h5 weights, you'll need to grab the basenji codebase and follow a path like this one https://github.com/calico/basenji/blob/master/bin/basenji_test.py. For what it's worth, StochasticShift is an augmentation layer that doesn't need to be in the Model for predicting.
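
For reference, a minimal loading sketch (the local path is an assumption, not taken verbatim from Basenji's docs):

import tensorflow as tf

# load the .tf (SavedModel) version directly; if any custom layers such as
# StochasticShift still come up, they can be registered via custom_objects
model = tf.keras.models.load_model('model_human.tf')
model.summary()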

For that initial nonlinearity, I can reproduce the problem with the following minimal example.

import tensorflow as tf

sequence = tf.keras.Input(shape=(1024, 4), name='sequence')
current = sequence
current = tf.keras.layers.ReLU()(current)
current = tf.keras.layers.Conv1D(filters=64, kernel_size=16, padding='same')(current)
current = tf.keras.layers.GlobalAveragePooling1D()(current)
model = tf.keras.Model(inputs=sequence, outputs=current)
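
And, for completeness, a sketch of the call that reproduces the failure (same constructor as in the notebook snippet above):

import fastISM

# expected to hit the assertion at fast_ism_utils.py line 310, since the initial
# ReLU node gets lumped into segment 0 along with the sequence input
model_fism = fastISM.FastISM(model)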

I manually removed that initial nonlinearity so that I could find the next hurdle. I used padding='same' in my MaxPool1D layers, so next I'm seeing a NotImplementedError for that. It sounds like the padding shouldn't be relevant because I'm using pool_size=2 and stride=2; however, when I tried simply changing the padding on a restored model, it produced terrible results. So even if the two are theoretically the same, the implementations seem to differ. If it's easy for you to implement this, that'd be great. Otherwise, I'll try training a new model with padding='valid' to see if that works OK.

Next, I've been using the GELU nonlinearity. I have my own implementation here https://github.com/calico/basenji/blob/master/basenji/layers.py#L81, but TF has it ready to add next release https://www.tensorflow.org/api_docs/python/tf/keras/activations/gelu. I solved that one by simply adding it to your list of SEE_THROUGH_LAYERS. Let me know if that's unsafe.

Next, I use Cropping1D layers to trim the edges, since those positions can really only see in one direction. That encounters a NotImplementedError.

I got rid of that to see what problems remain, and encountered the following that's a bit harder to decipher.

Traceback (most recent call last):
  File "./test_fastism.py", line 54, in <module>
    main()
  File "./test_fastism.py", line 47, in main
    model_fism = fastISM.FastISM(seqnn_model.model, test_correctness=True)
  File "/home/drk/code/fastISM/fastISM/fast_ism.py", line 13, in __init__
    self.fast_ism_model, self.input_specs = generate_models(
  File "/home/drk/code/fastISM/fastISM/fast_ism_utils.py", line 792, in generate_models
    intout_model, intout_output_tensors = generate_intermediate_output_model(
  File "/home/drk/code/fastISM/fastISM/fast_ism_utils.py", line 449, in generate_intermediate_output_model
    node_to_tensor, output_tensor_names = generate_intermediate_output_subgraph(
  File "/home/drk/code/fastISM/fastISM/fast_ism_utils.py", line 502, in generate_intermediate_output_subgraph
    node_to_tensor, output_tensor_names = generate_intermediate_output_subgraph(
  File "/home/drk/code/fastISM/fastISM/fast_ism_utils.py", line 502, in generate_intermediate_output_subgraph
    node_to_tensor, output_tensor_names = generate_intermediate_output_subgraph(
  File "/home/drk/code/fastISM/fastISM/fast_ism_utils.py", line 502, in generate_intermediate_output_subgraph
    node_to_tensor, output_tensor_names = generate_intermediate_output_subgraph(
  [Previous line repeated 37 more times]
  File "/home/drk/code/fastISM/fastISM/fast_ism_utils.py", line 506, in generate_intermediate_output_subgraph
    layer = nodes[parent_layer].__class__(**config)
TypeError: __init__() got an unexpected keyword argument 'name'

If you have any ideas for what that might be, let me know, and I'll pursue it. Otherwise, I can pass along this stripped down model that reaches that error.

Thanks again for working with me on this!

suragnair commented 3 years ago

Hi David, thanks for the digging! Adding GELU to SEE_THROUGH_LAYERS should be fine. Regarding MaxPool1D in your case (size 2), setting padding to "same" would essentially pad by 1 position if the length of the sequence is odd, else no padding. That might be why you're seeing bad results.
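
As a quick sanity check of that (not from the original benchmarks): for stride-2 pooling on an even-length input, 'same' and 'valid' padding should give identical outputs.

import numpy as np
import tensorflow as tf

x = np.random.rand(1, 1024, 4).astype(np.float32)
same_out = tf.keras.layers.MaxPool1D(pool_size=2, strides=2, padding='same')(x)
valid_out = tf.keras.layers.MaxPool1D(pool_size=2, strides=2, padding='valid')(x)
print(np.allclose(same_out, valid_out))  # True; no padding is added when the length is even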

I was able to load and inspect the model (graph). What is striking is the sheer size of the model and the fact that it operates on 131kb of sequence input, much longer than the 2kb inputs on which I benchmarked fastISM. This is a bit alarming since fastISM stores intermediate outputs (after all conv-maxpool blocks, or after every conv if there is no maxpool between 2 convs) on the GPU itself, which takes an increasing amount of space with increasing batch size. I've observed that fastISM works best when the GPU memory is maxed out; the actual fastISM computation tends to have a lot of overhead at small batch sizes.

I did a quick back-of-the-envelope calculation for the above model and summed the total elements after all initial maxpools and all subsequent convolutions (this could be optimised somewhat). This comes to around 58M floats, so perhaps ~200MB per single example. That likely means we'll only be able to handle a small number of examples at a time, I suspect in the 10-50 range in the best case. It's possible that overheads might dominate the fastISM computation and dampen speedups.
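
For reference, the arithmetic behind that estimate (a rough sketch, assuming float32, i.e. 4 bytes per element, and ignoring memory needed for weights and the forward pass itself):

per_example_mb = 58e6 * 4 / 1e6   # ~232 MB of cached intermediates per example
print(12000 / per_example_mb)     # ~50 examples would already fill a 12GB GPU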

I'll do a proof of concept with a stripped-down architecture consisting only of convs and maxpools to see if the speedups are worthwhile and let you know how it goes. If it looks promising, we can get back to making it work for the specific layers that aren't working now.

davek44 commented 3 years ago

You're totally right about the small batch size requirement. I've been training with batch size 2, and then doubling it to batch size 4 for forward passes only. Maybe you can go a bit higher, but I expect 50 is out of the question.

I'm not sure I follow the logic for why this would perform worse (relative to standard ISM) than larger batches of smaller sequences; I'd expect this arrangement to have a greater proportion of the sequence (and thus of the convolutions) unaffected by a central mutation. I'll be interested to see how your experiments go. Thanks!

suragnair commented 3 years ago

I started with the small model to get a feel for things. I suppose that model might not be relevant for real data, but it seemed easier to tackle. I removed the StochasticShift and the initial activation, changed activations to relu, and changed the max pool padding to valid. After doing this, the model worked with fastISM. The human model also worked with the same changes plus removing Cropping1D. I'm benchmarking the models currently and should have updates soon.

I didn't encounter the error you faced above. Essentially what's happening is that the code is reconstructing a similar model from your input model, layer by layer:

config = deepcopy(nodes[parent_layer].get_config()) # copy the configuration of the layer
...
layer = nodes[parent_layer].__class__(**config) # initialise layer from config
...
layer.set_weights(nodes[parent_layer].get_weights()) # copy over weights

So it's probably bugging out for a specific layer in your case. You could just print nodes[parent_layer].__class__.__name__ to see which layer it is, or even just print parent_layer or nodes[parent_layer]. Feel free to upload the model in case you can't figure it out.

To generate a visualisation of the flattened model graph (I should put this in a debugging suite), you could do:

nodes, edges, inbound_edges, _, outputs = fastISM.flatten_model.get_flattened_graph(model) 
fastISM.flatten_model.viz_graph(nodes, edges, '/path/arch.png') 

Out of curiosity, would your goal be to do a complete ISM at each of the 131k positions, or at selected positions for each sequence? fastISM is suited for the former but not the latter since it makes the same mutations for all sequences in a batch.

suragnair commented 3 years ago

Ran a bunch of benchmarks and the results are interesting. It's a bit involved so I'll try to break it down as best as possible.

TLDR: Speedup of 2-3x seems the most likely case for the full human model.

I was able to go up to a batch size of 100 for the small model and got a speedup of about 10x on TITAN Xp 12GB with some optimisations.

The case for the full human model is somewhat tricky. Essentially the outputs are huge and GPU->CPU transfer of the outputs takes a significant amount of time. Consider the full human model and a modified version with negligible output dimensions:

inp = tf.keras.layers.Input((131072,4))
x = model(inp)
x = tf.keras.layers.GlobalAveragePooling1D()(x)
x = tf.keras.layers.Dense(1)(x)
model_so = tf.keras.models.Model(inputs=inp, outputs=x) # small output version

When running on a batch of size 1, model takes 54ms while model_so takes 50ms; on a batch of size 25, model takes 1.65s and model_so takes 1.11s. The GPU->CPU transfer seems to be pretty slow when the batch size is larger in this case (batch size 25 would imply a (25, 1024, 5313) output, which is ~550MB for float32). This already limits the max possible speedup to something around 54/(54-50)~13x for batch size 1 and 1.65/(1.65-1.11)~3x for batch size 25.

So it seems like the optimal strategy would be to choose a small batch size for fastISM. But the catch is that fastISM has a significant fixed-cost offset that scales sublinearly with batch size, possibly related to structures and mechanisms utilised by fastISM, though I haven't profiled them extensively. fastISM on model_so takes only 1.5x more time to run on a batch of size 20 than on a batch of size 2. No speedup is obtained at batch size 2 (~1.05x), which increases to a speedup of ~7x at batch size 20.

Unfortunately this means fastISM needs to run at the highest batch size possible, which means the max possible speedup is closer to ~3x because of the slow GPU->CPU transfer. I'm currently getting a speedup of 2x on the model. This may go up slightly depending on your GPU.

davek44 commented 3 years ago

Wow, thanks so much for doing those experiments! I would happily take a 2-3x speedup, although I bet there is some additional tuning that I could do to further improve it. For example, I'm almost always taking the sum across the length of the sequence to compute a single score for each target. Moving that into the model itself would mean ~100-1000x less data needs to be transferred from the GPU -> CPU. Thanks for pointing that bottleneck out.
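
A rough sketch of that idea, mirroring the model_so construction above (model_summed is a hypothetical name; model is the loaded human model):

import tensorflow as tf

# keep the reduction over the length axis on the GPU, so only a (batch, 5313)
# tensor has to be transferred to the CPU instead of (batch, 1024, 5313)
inp = tf.keras.layers.Input((131072, 4))
x = model(inp)                                   # per-position predictions: (batch, 1024, 5313)
x = tf.keras.layers.GlobalAveragePooling1D()(x)  # mean over length; multiply by 1024 to recover the sum
model_summed = tf.keras.models.Model(inputs=inp, outputs=x)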

I did some quick and dirty experiments with batch size and my naive ISM code, and I'm not seeing any change in performance. Are you saying that the batch size dependence is unique to fastISM? The computational graph doesn't change batch to batch, right?

I fixed the __init__ name bug above. (My GELU layer didn't properly handle kwargs like 'name'.) The next crash comes in the following form.

  File "/home/drk/code/fastISM/fastISM/fast_ism.py", line 21, in __init__
    if not self.test_correctness():
  File "/home/drk/code/fastISM/fastISM/fast_ism.py", line 129, in test_correctness
    naive_out = naive_ism(x, replace_with=replace_with)
  File "/home/drk/code/fastISM/fastISM/ism_base.py", line 60, in __call__
    ism_outputs = np.repeat(np.expand_dims(unperturbed_output.numpy(), 1),
  File "<__array_function__ internals>", line 5, in repeat
  File "/home/drk/anaconda3/envs/py38/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 482, in repeat
    return _wrapfunc(a, 'repeat', repeats, axis=axis)
  File "/home/drk/anaconda3/envs/py38/lib/python3.8/site-packages/numpy/core/fromnumeric.py", line 61, in _wrapfunc
    return bound(*args, **kwds)
MemoryError: Unable to allocate 3.65 TiB for an array with shape (10, 49152, 384, 5313) and data type float32

I believe the values in that tuple refer to batch_size=10, input_length=49152, output_length=384, num_tasks=5313. That's way intractable. When I compute ISM scores, I'm typically doing a few hundred nucleotides in the center of the big sequence. Furthermore, I'm never initializing a full array, but instead processing the nucleotides across multiple batches and writing the output to an HDF5 as I go. More often, I'm scoring genetic variants, so just mutating one center nucleotide. Are there any barriers to operating that way using fastISM in order to avoid this massive array allocation?

suragnair commented 3 years ago

The array whose allocation was attempted essentially stores (on CPU), for each sequence, the outputs after mutating each position in the input. You can explicitly specify which positions you'd like to perturb by passing in change_ranges. So if you want to perturb the central 100 bases, it would look something like this (more examples here):

mid_point = model.input_shape[1]//2

fast_ism_model = FastISM(model,
                         change_ranges = [(mid_point+i, mid_point+i+1) for i in range(-50, 50)])

This would initialise an array (on CPU) with dimensions (batch_size, 100, output_len, num_tasks) for every batch and perturb the central 100 positions (to 0 by default, but you can run for each base like here). Note that the background sequence stays the same for each of the 100 mutations and is not centered at each mutation separately.
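
A rough usage sketch of the batched pattern (seq_batches and the HDF5 step are placeholders for your own pipeline):

for seq_batch in seq_batches:               # each batch: (batch_size, 131072, 4) one-hot
    ism_batch = fast_ism_model(seq_batch)   # scores for the 100 perturbed positions
    # ... write ism_batch out (e.g. append to an HDF5 dataset) before the next batch ...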

fastISM would not be useful if you only perturb one base per sequence. The reason for this is that fastISM necessarily performs one full forward propagation on the original input sequence and caches some of the intermediate outputs. It is much faster for subsequent mutations (on same fixed background sequence) as it only recomputes what is required and uses the cached intermediate outputs for the rest. So if instead you perturb 100 positions in the center, fastISM would do one normal forward pass and 100 fast ones, which would be much faster than doing 100 normal forward passes. I hope that's a clear explanation, feel free to probe further if not.

If you can indeed reduce the model output by ~100-1000x, then it should be possible to get 6-7x from fastISM easily (like in model_so above), and perhaps more with a few simple optimisations!

davek44 commented 3 years ago

Thanks for clarifying why variant scoring isn’t as interesting with this approach.

Overall, it sounds like there’s a promising path here! The three things I see that I could still use some help with are:

  1. Modifying how the first segment is determined to allow SEE_THROUGH_LAYERS that arrive before the first convolution.
  2. Handling valid max pooling, at least for even pooling parameters.
  3. Handling cropping layers.

Would you be able to help implement those in the next version?

suragnair commented 3 years ago

Great! Glad to know it would fit your use case. Regarding the features:

  1. I'll get this done
  2. I'm a little hesitant since I'm unsure what will need to be modified, and I believe same padding is rarely used for pooling. I will definitely look into it, but can't promise. Would it be too much work to retrain with valid padding? Performance-wise it shouldn't matter since, with a pool size of 2, it essentially boils down to an edge effect when sequences are of odd length.
  3. This is a request from others as well and I'll work on it.

I expect to look at them in the next week and implement them as soon as possible. I'll also do a feasibility analysis for 2, and if it's not too tricky, I'll try to work it in. Hope that works!

suragnair commented 3 years ago

1 and 3 are now implemented in v0.4.1 (6110e4a6a). I've done a few basic checks and added some tests, and it seems to be working. Let me know if you can give it a shot.

davek44 commented 3 years ago

Hey Surag, it seems to be working well for me now! I sorted out the valid padding on my end, so don't worry as much about that. I'll keep pushing forward and let you know how the benchmarks play out. Thanks for your help!

suragnair commented 3 years ago

Glad to hear that! Do let me know how it pans out and if you need me to implement any other features. I'll leave this Issue open for a while so feel free to add here.