calico / basenji

Sequential regulatory activity predictions with deep convolutional neural networks.
Apache License 2.0

Basset-style training failing due to shape mismatch #109

Open: twrightsman opened this issue 2 years ago

twrightsman commented 2 years ago

I'm attempting to train a Basset-style model on a single dataset of peak calls, but I get the following error after running:

    basenji_train.py -k -o tmp/train_basset basenji/manuscripts/basset/params_basset.json tmp/data_basset

Traceback (most recent call last):
  File "basenji/bin/basenji_train.py", line 174, in <module>
    main()
  File "basenji/bin/basenji_train.py", line 163, in main
    seqnn_trainer.fit_keras(seqnn_model)
  File "basenji/basenji/trainer.py", line 128, in fit_keras
    seqnn_model.model.fit(
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1095, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 888, in _call
    return self._stateless_fn(*args, **kwds)
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2942, in __call__
    return graph_function._call_flat(
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1918, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 555, in call
    outputs = execute.execute(
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError:  Input to reshape is a tensor with 64 values, but the requested shape requires a multiple of 164
         [[{{node ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_bool_Reshape_3}}]] [Op:__inference_train_function_5771]

Function call stack:
train_function

Changing the Basset batch size from 64 to 164 makes this error go away. Should this be changed in the repo, or am I doing something wrong on my end? This may be related to calico/basenji#73 and calico/basenji#75.

diff --git a/manuscripts/basset/params_basset.json b/manuscripts/basset/params_basset.json
index 5415e8c..7e3e3c9 100644
--- a/manuscripts/basset/params_basset.json
+++ b/manuscripts/basset/params_basset.json
@@ -1,6 +1,6 @@
 {
     "train": {
-        "batch_size": 64,
+        "batch_size": 164,
         "shuffle_buffer": 8192,
         "optimizer": "sgd",
         "loss": "bce",
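The reshape error is a divisibility check: reshaping to a shape like [-1, 164] only succeeds when the element count is a multiple of 164. The NumPy sketch below reproduces just that arithmetic. The target shape [-1, 164] is inferred from the error message, not taken from basenji's source, and the helper try_reshape is hypothetical, so treat this as an illustration rather than the actual failing code path. It also suggests why a batch size of 164 hides the symptom: 164 values happen to form exactly one row of 164, whereas a tensor holding one value per example per target would need 64 × 164 values at batch size 64.

```python
import numpy as np

# Number of targets in the Basset dataset, as reported in the error message.
NUM_TARGETS = 164


def try_reshape(num_values: int, num_targets: int = NUM_TARGETS):
    """Attempt a reshape to [-1, num_targets] (hypothetical stand-in for the failing op).

    Returns the resulting shape, or None if NumPy rejects it because
    num_values is not a multiple of num_targets.
    """
    flat = np.zeros(num_values, dtype=bool)  # the failing node was a bool Reshape
    try:
        return flat.reshape(-1, num_targets).shape
    except ValueError:
        return None


# 64 values (one per example at batch_size 64) cannot be split into rows of 164.
print(try_reshape(64))
# At batch_size 164 the count happens to be exactly one row of 164.
print(try_reshape(164))
# One value per example per target at batch_size 64 reshapes cleanly.
print(try_reshape(64 * NUM_TARGETS))
```

If this reading is right, a batch size of 164 masks a shape mismatch rather than fixing it.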
twrightsman commented 2 years ago

I noticed I had been training with the CPU version of TensorFlow, so I switched to the GPU version, but unfortunately I get a similar error:

2022-02-04 05:20:40.388956: W tensorflow/core/framework/op_kernel.cc:1692] OP_REQUIRES failed at segment_reduction_ops_impl.h:425 : Invalid argument: data.shape = [164,1] does not start with segment_ids.shape = [164,164]
Traceback (most recent call last):
  File "basenji/bin/basenji_train.py", line 174, in <module>
    main()
  File "basenji/bin/basenji_train.py", line 163, in main
    seqnn_trainer.fit_keras(seqnn_model)
  File "basenji/basenji/trainer.py", line 128, in fit_keras
    seqnn_model.model.fit(
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/keras/engine/training.py", line 1184, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 885, in __call__
    result = self._call(*args, **kwds)
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 950, in _call
    return self._stateless_fn(*args, **kwds)
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3039, in __call__
    return graph_function._call_flat(
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1963, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 591, in call
    outputs = execute.execute(
  File "/home/twrightsman/.conda/envs/basenji/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError:  data.shape = [164,1] does not start with segment_ids.shape = [164,164]
         [[node loop_body_2/UnsortedSegmentSum/pfor/UnsortedSegmentSum (defined at basenji/basenji/metrics.py:88) ]] [Op:__inference_train_function_6081]

Function call stack:
train_function

I can stick to training on the CPU version for now, but still thought I'd report this.
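The GPU traceback fails inside an UnsortedSegmentSum op created at basenji/basenji/metrics.py:88, inside a pfor-vectorized loop. TensorFlow's tf.math.unsorted_segment_sum requires the shape of segment_ids to be a prefix of data.shape, and the log reports exactly that check failing. The plain-Python sketch below restates the check with the shapes from the log. It does not reproduce basenji's metric code, and the helper name segment_ids_shape_ok is hypothetical.

```python
from typing import Sequence


def segment_ids_shape_ok(data_shape: Sequence[int], segment_ids_shape: Sequence[int]) -> bool:
    """Return True if segment_ids_shape is a leading prefix of data_shape.

    This mirrors the precondition tf.math.unsorted_segment_sum enforces
    before it sums data into segments.
    """
    k = len(segment_ids_shape)
    return len(data_shape) >= k and tuple(data_shape[:k]) == tuple(segment_ids_shape)


# Shapes from the GPU error: [164, 1] does not start with [164, 164].
print(segment_ids_shape_ok([164, 1], [164, 164]))
# Matching shapes, and a one-dimensional prefix, both pass.
print(segment_ids_shape_ok([164, 164], [164, 164]))
print(segment_ids_shape_ok([164, 164], [164]))
```

Here data carries a single column where the segment IDs expect 164, which points at the same per-target dimension as the CPU reshape error.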

davek44 commented 2 years ago

Hi, sorry for the delayed response. I'm unable to replicate this error. Are you using the latest code from the master branch? Which version of TensorFlow? I'm going to regenerate my TFRecords to check whether something changed there. It's possible the problem is in the data rather than the model.

davek44 commented 2 years ago

Regenerating the TFRecords likewise produced a model that trains without error. I'm using TensorFlow v2.6. Let me know if you find anything on your end that differs from this setup.

twrightsman commented 2 years ago

Thank you for the follow-up. I'll repeat my steps on the latest master soon and see whether the problem persists. Either way, I'll let you know whether I can reproduce it and, if so, share the exact commands and data.

gouthamatla commented 2 years ago

Hi, I had a similar problem. I suspect it has to do with the number of targets. How should I change the parameters in this repo to train on two or three targets (e.g., strong enhancers, weak enhancers, promoters) instead of a multiple of 164?

davek44 commented 2 years ago

There is no requirement that the number of targets be a multiple of 164. Just generate a training dataset with your 2-3 targets and change the final model layer to match the number of targets.
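A minimal sketch of the second step (changing the final layer), assuming the output layer in the params JSON is a block with "name": "final" and a "units" field. That layout is an assumption, so check it against your own copy of params_basset.json before relying on it. The helpers set_final_units and retarget_params_file are hypothetical, and the example params dict is made up for illustration; num_targets should equal the number of targets in the dataset you generated.

```python
import json
from typing import Any


def set_final_units(node: Any, num_targets: int) -> int:
    """Recursively set "units" on every block named "final"; return how many changed.

    ASSUMPTION: the output layer is a dict with "name": "final" and a "units" key.
    """
    changed = 0
    if isinstance(node, dict):
        if node.get("name") == "final" and "units" in node:
            node["units"] = num_targets
            changed += 1
        for value in node.values():
            changed += set_final_units(value, num_targets)
    elif isinstance(node, list):
        for item in node:
            changed += set_final_units(item, num_targets)
    return changed


def retarget_params_file(in_path: str, out_path: str, num_targets: int) -> None:
    """Write a copy of a params JSON whose final layer outputs num_targets units."""
    with open(in_path) as f:
        params = json.load(f)
    if set_final_units(params, num_targets) == 0:
        raise ValueError('no block with "name": "final" and "units" was found')
    with open(out_path, "w") as f:
        json.dump(params, f, indent=4)


# Hypothetical params layout, for illustration only.
params = {
    "train": {"batch_size": 64, "loss": "bce"},
    "model": {
        "trunk": [{"name": "conv_block", "filters": 300}],
        "head": {"name": "final", "units": 164, "activation": "sigmoid"},
    },
}
n = set_final_units(params, 3)  # e.g. strong enhancers, weak enhancers, promoters
print(n, params["model"]["head"]["units"])
```

The batch size is left untouched, since it does not need to relate to the number of targets.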