strongio / keras-bert

A simple technique to integrate BERT from tf hub to keras

Nested tensor dimensions #24

Closed kerstenj closed 5 years ago

kerstenj commented 5 years ago

I would like to use BERT for multi-class, multi-task classification. For each sentence to classify (say, with a fixed number of n tokens), BERT would (if I understand it correctly) provide a 768-element vector per token, i.e., a tensor of shape (n, 768). With batches, I would expect (None, n, 768). With keras-bert, however, I obtain ((None, n), 768). To feed this tensor into Keras' text YoonKimCNN I have to add a further dimension, but the nested structure remains, so even the final layers have shape ((None, n), m), whereas I would expect (None, m) in the end. Model summary:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_ids (InputLayer)          (None, 256)          0                                            
__________________________________________________________________________________________________
input_masks (InputLayer)        (None, 256)          0                                            
__________________________________________________________________________________________________
segment_ids (InputLayer)        (None, 256)          0                                            
__________________________________________________________________________________________________
bert_layer_1 (BertLayer)        ((None, 256), 768)   110104890   input_ids[0][0]                  
                                                                 input_masks[0][0]                
                                                                 segment_ids[0][0]                
__________________________________________________________________________________________________
reshape_1 (Reshape)             ((None, 256), 768, 1 0           bert_layer_1[0][0]               
__________________________________________________________________________________________________
consume_mask_1 (ConsumeMask)    ((None, 256), 768, 1 0           reshape_1[0][0]                  
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               ((None, 256), 766, 1 512         consume_mask_1[0][0]             
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               ((None, 256), 765, 1 640         consume_mask_1[0][0]             
__________________________________________________________________________________________________
conv1d_3 (Conv1D)               ((None, 256), 764, 1 768         consume_mask_1[0][0]             
__________________________________________________________________________________________________
global_max_pooling1d_1 (GlobalM ((None, 256), 128)   0           conv1d_1[0][0]                   
__________________________________________________________________________________________________
global_max_pooling1d_2 (GlobalM ((None, 256), 128)   0           conv1d_2[0][0]                   
__________________________________________________________________________________________________
global_max_pooling1d_3 (GlobalM ((None, 256), 128)   0           conv1d_3[0][0]                   
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     ((None, 256), 384)   0           global_max_pooling1d_1[0][0]     
                                                                 global_max_pooling1d_2[0][0]     
                                                                 global_max_pooling1d_3[0][0]     
__________________________________________________________________________________________________
dropout_1 (Dropout)             ((None, 256), 384)   0           concatenate_1[0][0]              
__________________________________________________________________________________________________
dense_4 (Dense)                 ((None, 256), 256)   98560       dropout_1[0][0]                  
__________________________________________________________________________________________________
dense_1 (Dense)                 ((None, 256), 128)   49280       dropout_1[0][0]                  
__________________________________________________________________________________________________
dropout_4 (Dropout)             ((None, 256), 256)   0           dense_4[0][0]                    
__________________________________________________________________________________________________
dropout_2 (Dropout)             ((None, 256), 128)   0           dense_1[0][0]                    
__________________________________________________________________________________________________
dense_5 (Dense)                 ((None, 256), 128)   32896       dropout_4[0][0]                  
__________________________________________________________________________________________________
dense_2 (Dense)                 ((None, 256), 64)    8256        dropout_2[0][0]                  
__________________________________________________________________________________________________
dropout_5 (Dropout)             ((None, 256), 128)   0           dense_5[0][0]                    
__________________________________________________________________________________________________
dropout_3 (Dropout)             ((None, 256), 64)    0           dense_2[0][0]                    
__________________________________________________________________________________________________
dense_6 (Dense)                 ((None, 256), 25)    3225        dropout_5[0][0]                  
__________________________________________________________________________________________________
dense_3 (Dense)                 ((None, 256), 1)     65          dropout_3[0][0]                  
==================================================================================================
Total params: 110,299,092
Trainable params: 71,072,922
Non-trainable params: 39,226,170
__________________________________________________________________________________________________

This looks different from what we can see here. Any suggestions on how to get rid of the nested structure are welcome.

kerstenj commented 5 years ago

Solved it. In BertLayer I had to change

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_size)

to

    def compute_output_shape(self, input_shape):
        return (input_shape[0][0], self.output_size)
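For anyone hitting the same issue: when a Keras layer has multiple inputs (here input_ids, input_masks, and segment_ids), Keras passes compute_output_shape a *list* of shape tuples rather than a single tuple. So input_shape[0] is the whole first shape tuple (None, 256), and using it as the batch dimension produces the nested ((None, 256), 768). A minimal sketch (plain Python, no TensorFlow required; output_size and the shapes are taken from the model summary above) illustrating the difference:

```python
# Hidden size of BERT-base, matching the 768 in the model summary above.
output_size = 768

# What Keras hands to compute_output_shape for a layer with three
# (None, 256) inputs: a list of shape tuples, not a single tuple.
input_shape = [(None, 256), (None, 256), (None, 256)]

def old_compute_output_shape(input_shape):
    # Bug: input_shape[0] is the entire first shape tuple (None, 256),
    # so the "batch dimension" ends up being a nested tuple.
    return (input_shape[0], output_size)

def new_compute_output_shape(input_shape):
    # Fix: input_shape[0][0] picks out just the batch dimension (None).
    return (input_shape[0][0], output_size)

print(old_compute_output_shape(input_shape))  # ((None, 256), 768) -- nested
print(new_compute_output_shape(input_shape))  # (None, 768) -- as expected
```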