Frightera / Medium_Notebooks_English

This repo contains the notebooks used in the Medium posts.

Is it possible to convert a point cloud model to a TensorFlow Probability deep learning model? #1

Open kuangzy2011 opened 1 year ago

kuangzy2011 commented 1 year ago

Hi Frightera: I am running some tests with PointNet to classify point cloud objects, and at inference time there may be unknown objects that are not in the training set. So I tried to convert the PointNet model (https://keras.io/examples/vision/pointnet/) to a TensorFlow Probability model by referring to your Simple Fully Probabilistic Bayesian CNN, hoping the model could say "I don't know" when the target object is not among the training classes, but the training results were less than ideal. Is it possible to convert a point cloud model to a TensorFlow Probability deep learning model?
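For context, the rejection logic I have in mind looks roughly like this (a minimal sketch; the trained model, the number of Monte Carlo passes, and the entropy threshold are all placeholders):

import numpy as np

def predict_with_uncertainty(model, points, n_samples=50, entropy_threshold=1.5):
    # Each forward pass resamples the weight posteriors, so repeated calls give
    # different categorical distributions; averaging them approximates the
    # posterior predictive. OneHotCategorical.mean() returns the class probabilities.
    probs = np.stack([model(points).mean().numpy() for _ in range(n_samples)])
    mean_probs = probs.mean(axis=0)
    # High predictive entropy -> the model is unsure -> answer "I don't know" (-1).
    entropy = -np.sum(mean_probs * np.log(mean_probs + 1e-9), axis=-1)
    labels = mean_probs.argmax(axis=-1)
    return np.where(entropy > entropy_threshold, -1, labels), entropy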

Here is my testing model:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ReduceLROnPlateau

tfd = tfp.distributions
tfpl = tfp.layers

NUM_EXAMPLES = 83820
NUM_CLASSES = 10
BATCH_SIZE = 64
NUM_POINTS = 231

# Demonstrate the reinterpreted_batch_ndims arg. Needed for custom priors & posteriors.
shape = (3, )
dtype = tf.float64
distribution = tfd.Normal(loc = tf.zeros(shape, dtype), scale = tf.ones(shape, dtype))
# batch_ndims = tf.size(distribution.batch_shape_tensor())

for i in range(len(shape) + 1):
    print('reinterpreted_batch_ndims: %d:' %(i))
    independent_dist = tfd.Independent(distribution, reinterpreted_batch_ndims = i)
    samples = independent_dist.sample()
    print('batch_shape: {}'
          ' event_shape: {}'
          ' sample shape: {}'.format(independent_dist.batch_shape,
                                     independent_dist.event_shape,
                                     samples.shape))
    print('Samples:', samples.numpy(), '\n')

# The default posterior is Normal; if we use a Laplace prior, we need to
# approximate the KL divergence. Trying to compute the analytic KL between
# those two distributions raises:
#   No KL(distribution_a || distribution_b) registered for distribution_a type Normal and distribution_b type Laplace
#   Call arguments received:
#     • inputs=tf.Tensor(shape=(None, 28, 28, 1), dtype=float32)
def approximate_kl(q, p, q_tensor):
    # Monte Carlo estimate of KL(q || p), using q_tensor (a sample drawn from q).
    return tf.reduce_mean(q.log_prob(q_tensor) - p.log_prob(q_tensor))

# Scale the KL term by the number of training examples (standard ELBO weighting).
divergence_fn = lambda q, p, q_tensor: approximate_kl(q, p, q_tensor) / NUM_EXAMPLES
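As a quick sanity check, TFP has no registered analytic KL for the Normal/Laplace pair, while the Monte Carlo estimate goes through (a self-contained example with standard Normal and Laplace):

q = tfd.Normal(loc=0., scale=1.)
p = tfd.Laplace(loc=0., scale=1.)

try:
    tfd.kl_divergence(q, p)  # no analytic KL registered for this pair
except NotImplementedError as err:
    print('Analytic KL unavailable:', err)

# Monte Carlo estimate: E_q[log q(x) - log p(x)] over samples drawn from q.
x = q.sample(10000)
print('MC KL estimate:', tf.reduce_mean(q.log_prob(x) - p.log_prob(x)).numpy())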

def conv_reparameterization_layer(filters, kernel_size, padding):
    # For simplicity, we use the default prior and posterior here.
    # Later parts use custom mixture priors and posteriors.
    return tfpl.Convolution1DReparameterization(
        filters = filters,
        kernel_size = kernel_size,
        activation = None, 
        padding = padding,
        kernel_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),
        kernel_prior_fn = tfpl.default_multivariate_normal_fn,

        bias_prior_fn = tfpl.default_multivariate_normal_fn,
        bias_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),

        kernel_divergence_fn = divergence_fn,
        bias_divergence_fn = divergence_fn)

def dense_reparameterization_layer(units):
    return tfpl.DenseReparameterization(
        units = units, activation = None,
        kernel_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),
        kernel_prior_fn = tfpl.default_multivariate_normal_fn,

        bias_prior_fn = tfpl.default_multivariate_normal_fn,
        bias_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),

        kernel_divergence_fn = divergence_fn,
        bias_divergence_fn = divergence_fn)

def nll(y_true, y_pred):
    # y_pred is a TFP distribution, so the negative log-likelihood is just -log_prob.
    return -y_pred.log_prob(y_true)
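Since y_pred here is a TFP distribution rather than a plain tensor, the loss is just the negative log-probability of the label; a minimal check with made-up logits:

# NLL of a one-hot label under a categorical predictive distribution.
dist = tfd.OneHotCategorical(logits=[1.0, 2.0, 0.5], dtype=tf.float32)
y_true = tf.constant([0., 1., 0.])
print(nll(y_true, dist).numpy())  # equals -log P(class 1)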

def conv_bn(x, filters):
    x = conv_reparameterization_layer(filters, kernel_size=1, padding='valid')(x)
    x = layers.BatchNormalization(momentum=0.0)(x)
    return layers.Activation("relu")(x)

def dense_bn(x, filters):
    x = dense_reparameterization_layer(filters)(x)
    x = layers.BatchNormalization(momentum=0.0)(x)
    return layers.Activation("relu")(x)

class OrthogonalRegularizer(keras.regularizers.Regularizer):
    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        self.eye = tf.eye(num_features)

    def __call__(self, x):
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        xxt = tf.tensordot(x, x, axes=(2, 2))
        xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
        return tf.reduce_sum(self.l2reg * tf.square(xxt - self.eye))
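As a sanity check on the regularizer, a perfectly orthogonal feature transform (the identity) should incur a near-zero penalty:

# X·Xᵀ = I for the identity, so the regularizer returns ~0.
print(OrthogonalRegularizer(num_features=3)(tf.eye(3)).numpy())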

def tnet(inputs, num_features):

    # Initialise the bias as the identity matrix
    bias = keras.initializers.Constant(np.eye(num_features).flatten())
    reg = OrthogonalRegularizer(num_features)

    x = conv_bn(inputs, 32)
    x = conv_bn(x, 64)
    x = conv_bn(x, 512)
    x = layers.GlobalMaxPooling1D()(x)
    x = dense_bn(x, 256)
    x = dense_bn(x, 128)
    x = layers.Dense(
            num_features * num_features,
            kernel_initializer="zeros",
            bias_initializer=bias,
            activity_regularizer=reg,
        )(x)
    feat_T = layers.Reshape((num_features, num_features))(x)
    # Apply affine transformation to input features
    return layers.Dot(axes=(2, 1))([inputs, feat_T])

def create_model_normal(num_points):
    inputs = keras.Input(shape=(num_points, 3))

    x = tnet(inputs, 3)
    x = conv_bn(x, 32)
    x = conv_bn(x, 32)
    x = tnet(x, 32)
    x = conv_bn(x, 32)
    x = conv_bn(x, 64)
    x = conv_bn(x, 512)
    x = layers.GlobalMaxPooling1D()(x)
    x = dense_bn(x, 256)
    x = layers.Dropout(0.25)(x)
    x = dense_bn(x, 128)
    x = layers.Dropout(0.25)(x)

    dense = tfpl.DenseReparameterization(
            units = tfpl.OneHotCategorical.params_size(NUM_CLASSES), activation = None,
            kernel_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),
            kernel_prior_fn = tfpl.default_multivariate_normal_fn,

            bias_prior_fn = tfpl.default_multivariate_normal_fn,
            bias_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),

            kernel_divergence_fn = divergence_fn,
            bias_divergence_fn = divergence_fn)(x)
    outputs = tfpl.OneHotCategorical(NUM_CLASSES)(dense)
    model = keras.Model(inputs=inputs, outputs=outputs, name="pointnet")
    return model

######################################################################################

def create_model(num_points):
    return create_model_normal(num_points)

model = create_model(NUM_POINTS)
model.summary()

>>Model: "pointnet"
>>__________________________________________________________________________________________________
>>Layer (type)                    Output Shape         Param #     Connected to                     
>>==================================================================================================
>>input_3 (InputLayer)            [(None, 231, 3)]     0                                            
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_22 (C (None, 231, 32)      256         input_3[0][0]                    
>>__________________________________________________________________________________________________
>>batch_normalization_34 (BatchNo (None, 231, 32)      128         conv1d_reparameterization_22[0][0
>>__________________________________________________________________________________________________
>>activation_34 (Activation)      (None, 231, 32)      0           batch_normalization_34[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_23 (C (None, 231, 64)      4224        activation_34[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_35 (BatchNo (None, 231, 64)      256         conv1d_reparameterization_23[0][0
>>__________________________________________________________________________________________________
>>activation_35 (Activation)      (None, 231, 64)      0           batch_normalization_35[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_24 (C (None, 231, 512)     66560       activation_35[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_36 (BatchNo (None, 231, 512)     2048        conv1d_reparameterization_24[0][0
>>__________________________________________________________________________________________________
>>activation_36 (Activation)      (None, 231, 512)     0           batch_normalization_36[0][0]     
>>__________________________________________________________________________________________________
>>global_max_pooling1d_6 (GlobalM (None, 512)          0           activation_36[0][0]              
>>__________________________________________________________________________________________________
>>dense_reparameterization_18 (De (None, 256)          262656      global_max_pooling1d_6[0][0]     
>>__________________________________________________________________________________________________
>>batch_normalization_37 (BatchNo (None, 256)          1024        dense_reparameterization_18[0][0]
>>__________________________________________________________________________________________________
>>activation_37 (Activation)      (None, 256)          0           batch_normalization_37[0][0]     
>>__________________________________________________________________________________________________
>>dense_reparameterization_19 (De (None, 128)          65792       activation_37[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_38 (BatchNo (None, 128)          512         dense_reparameterization_19[0][0]
>>__________________________________________________________________________________________________
>>activation_38 (Activation)      (None, 128)          0           batch_normalization_38[0][0]     
>>__________________________________________________________________________________________________
>>dense (Dense)                   (None, 9)            1161        activation_38[0][0]              
>>__________________________________________________________________________________________________
>>reshape_4 (Reshape)             (None, 3, 3)         0           dense[0][0]                      
>>__________________________________________________________________________________________________
>>dot_4 (Dot)                     (None, 231, 3)       0           input_3[0][0]                    
>>                                                                 reshape_4[0][0]                  
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_25 (C (None, 231, 32)      256         dot_4[0][0]                      
>>__________________________________________________________________________________________________
>>batch_normalization_39 (BatchNo (None, 231, 32)      128         conv1d_reparameterization_25[0][0
>>__________________________________________________________________________________________________
>>activation_39 (Activation)      (None, 231, 32)      0           batch_normalization_39[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_26 (C (None, 231, 32)      2112        activation_39[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_40 (BatchNo (None, 231, 32)      128         conv1d_reparameterization_26[0][0
>>__________________________________________________________________________________________________
>>activation_40 (Activation)      (None, 231, 32)      0           batch_normalization_40[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_27 (C (None, 231, 32)      2112        activation_40[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_41 (BatchNo (None, 231, 32)      128         conv1d_reparameterization_27[0][0
>>__________________________________________________________________________________________________
>>activation_41 (Activation)      (None, 231, 32)      0           batch_normalization_41[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_28 (C (None, 231, 64)      4224        activation_41[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_42 (BatchNo (None, 231, 64)      256         conv1d_reparameterization_28[0][0
>>__________________________________________________________________________________________________
>>activation_42 (Activation)      (None, 231, 64)      0           batch_normalization_42[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_29 (C (None, 231, 512)     66560       activation_42[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_43 (BatchNo (None, 231, 512)     2048        conv1d_reparameterization_29[0][0
>>__________________________________________________________________________________________________
>>activation_43 (Activation)      (None, 231, 512)     0           batch_normalization_43[0][0]     
>>__________________________________________________________________________________________________
>>global_max_pooling1d_7 (GlobalM (None, 512)          0           activation_43[0][0]              
>>__________________________________________________________________________________________________
>>dense_reparameterization_20 (De (None, 256)          262656      global_max_pooling1d_7[0][0]     
>>__________________________________________________________________________________________________
>>batch_normalization_44 (BatchNo (None, 256)          1024        dense_reparameterization_20[0][0]
>>__________________________________________________________________________________________________
>>activation_44 (Activation)      (None, 256)          0           batch_normalization_44[0][0]     
>>__________________________________________________________________________________________________
>>dense_reparameterization_21 (De (None, 128)          65792       activation_44[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_45 (BatchNo (None, 128)          512         dense_reparameterization_21[0][0]
>>__________________________________________________________________________________________________
>>activation_45 (Activation)      (None, 128)          0           batch_normalization_45[0][0]     
>>__________________________________________________________________________________________________
>>dense_1 (Dense)                 (None, 1024)         132096      activation_45[0][0]              
>>__________________________________________________________________________________________________
>>reshape_5 (Reshape)             (None, 32, 32)       0           dense_1[0][0]                    
>>__________________________________________________________________________________________________
>>dot_5 (Dot)                     (None, 231, 32)      0           activation_40[0][0]              
>>                                                                 reshape_5[0][0]                  
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_30 (C (None, 231, 32)      2112        dot_5[0][0]                      
>>__________________________________________________________________________________________________
>>batch_normalization_46 (BatchNo (None, 231, 32)      128         conv1d_reparameterization_30[0][0
>>__________________________________________________________________________________________________
>>activation_46 (Activation)      (None, 231, 32)      0           batch_normalization_46[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_31 (C (None, 231, 64)      4224        activation_46[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_47 (BatchNo (None, 231, 64)      256         conv1d_reparameterization_31[0][0
>>__________________________________________________________________________________________________
>>activation_47 (Activation)      (None, 231, 64)      0           batch_normalization_47[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_32 (C (None, 231, 512)     66560       activation_47[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_48 (BatchNo (None, 231, 512)     2048        conv1d_reparameterization_32[0][0
>>__________________________________________________________________________________________________
>>activation_48 (Activation)      (None, 231, 512)     0           batch_normalization_48[0][0]     
>>__________________________________________________________________________________________________
>>global_max_pooling1d_8 (GlobalM (None, 512)          0           activation_48[0][0]              
>>__________________________________________________________________________________________________
>>dense_reparameterization_22 (De (None, 256)          262656      global_max_pooling1d_8[0][0]     
>>__________________________________________________________________________________________________
>>batch_normalization_49 (BatchNo (None, 256)          1024        dense_reparameterization_22[0][0]
>>__________________________________________________________________________________________________
>>activation_49 (Activation)      (None, 256)          0           batch_normalization_49[0][0]     
>>__________________________________________________________________________________________________
>>dropout_4 (Dropout)             (None, 256)          0           activation_49[0][0]              
>>__________________________________________________________________________________________________
>>dense_reparameterization_23 (De (None, 128)          65792       dropout_4[0][0]                  
>>__________________________________________________________________________________________________
>>batch_normalization_50 (BatchNo (None, 128)          512         dense_reparameterization_23[0][0]
>>__________________________________________________________________________________________________
>>activation_50 (Activation)      (None, 128)          0           batch_normalization_50[0][0]     
>>__________________________________________________________________________________________________
>>dropout_5 (Dropout)             (None, 128)          0           activation_50[0][0]              
>>__________________________________________________________________________________________________
>>dense_reparameterization_24 (De (None, 10)           2580        dropout_5[0][0]                  
>>__________________________________________________________________________________________________
>>one_hot_categorical_2 (OneHotCa multiple             0           dense_reparameterization_24[0][0]
>>==================================================================================================
>>Total params: 1,352,541
>>Trainable params: 1,346,461
>>Non-trainable params: 6,080

Train the model:
EPOCHS = 30

# Reduce the learning rate when validation accuracy plateaus.
# Note: with metrics=["accuracy"], the logged metric name is "val_accuracy", not "val_acc".
learning_rate_reduction = ReduceLROnPlateau(monitor='val_accuracy',
                                            patience=3,
                                            verbose=1,
                                            factor=0.5,
                                            min_lr=0.00001)

model.compile(
        loss=nll,
        optimizer=keras.optimizers.Adam(learning_rate=0.0001),
        metrics=["accuracy"]
    )

history = model.fit(train_dataset, epochs=EPOCHS, validation_data=test_dataset, verbose=2, callbacks=[learning_rate_reduction])

>>train_dataset <BatchDataset shapes: ((None, 231, 3), (None, 10)), types: (tf.float64, tf.float64)>
>>test_dataset <BatchDataset shapes: ((None, 231, 3), (None, 10)), types: (tf.float64, tf.float64)>
>>Epoch 1/30
>>1310/1310 - 60s - loss: 21.9061 - accuracy: 0.3590 - val_loss: 182635040.0000 - val_accuracy: 0.0833
>>Epoch 2/30
>>1310/1310 - 44s - loss: 20.6475 - accuracy: 0.7531 - val_loss: 60451084.0000 - val_accuracy: 0.1429
>>Epoch 3/30
>>1310/1310 - 44s - loss: 20.2906 - accuracy: 0.8938 - val_loss: 14749431.0000 - val_accuracy: 0.1071
>>Epoch 4/30
>>1310/1310 - 44s - loss: 20.0421 - accuracy: 0.9462 - val_loss: 22719518.0000 - val_accuracy: 0.1024
>>Epoch 5/30
>>1310/1310 - 43s - loss: 19.8038 - accuracy: 0.9664 - val_loss: 26609368.0000 - val_accuracy: 0.1381
>>Epoch 6/30
>>1310/1310 - 44s - loss: 19.5439 - accuracy: 0.9753 - val_loss: 2488617.5000 - val_accuracy: 0.0905
>>Epoch 7/30
>>1310/1310 - 44s - loss: 19.2616 - accuracy: 0.9819 - val_loss: 5732172.0000 - val_accuracy: 0.1048
>>Epoch 8/30
>>1310/1310 - 43s - loss: 18.9594 - accuracy: 0.9854 - val_loss: 3599270.0000 - val_accuracy: 0.1095
>>Epoch 9/30
>>1310/1310 - 44s - loss: 18.6469 - accuracy: 0.9876 - val_loss: 1626159.1250 - val_accuracy: 0.0976
>>Epoch 10/30
>>1310/1310 - 44s - loss: 18.3281 - accuracy: 0.9896 - val_loss: 2010936.1250 - val_accuracy: 0.1048
>>Epoch 11/30
>>1310/1310 - 44s - loss: 18.0074 - accuracy: 0.9907 - val_loss: 2303319.7500 - val_accuracy: 0.1190
>>Epoch 12/30
>>1310/1310 - 45s - loss: 17.6917 - accuracy: 0.9915 - val_loss: 6043947.5000 - val_accuracy: 0.0905
>>Epoch 13/30
>>1310/1310 - 44s - loss: 17.3721 - accuracy: 0.9932 - val_loss: 7441842.0000 - val_accuracy: 0.1071
>>Epoch 14/30
>>1310/1310 - 44s - loss: 17.0469 - accuracy: 0.9932 - val_loss: 8680777.0000 - val_accuracy: 0.1071
>>Epoch 15/30
>>1310/1310 - 44s - loss: 16.7201 - accuracy: 0.9942 - val_loss: 5293146.0000 - val_accuracy: 0.0810
>>Epoch 16/30
>>1310/1310 - 44s - loss: 16.3949 - accuracy: 0.9944 - val_loss: 3354950.5000 - val_accuracy: 0.1119
>>Epoch 17/30
>>1310/1310 - 44s - loss: 16.0781 - accuracy: 0.9944 - val_loss: 89280024.0000 - val_accuracy: 0.1167
>>Epoch 18/30
>>1310/1310 - 44s - loss: 15.7646 - accuracy: 0.9953 - val_loss: 2574807.5000 - val_accuracy: 0.0905
>>Epoch 19/30
>>1310/1310 - 43s - loss: 15.4577 - accuracy: 0.9952 - val_loss: 1522962.6250 - val_accuracy: 0.1143
>>Epoch 20/30
>>1310/1310 - 45s - loss: 15.1654 - accuracy: 0.9950 - val_loss: 26839674.0000 - val_accuracy: 0.0714
>>Epoch 21/30
>>1310/1310 - 43s - loss: 14.8791 - accuracy: 0.9954 - val_loss: 68192664.0000 - val_accuracy: 0.0976
>>Epoch 22/30
>>1310/1310 - 43s - loss: 14.5917 - accuracy: 0.9952 - val_loss: 69933488.0000 - val_accuracy: 0.0976
>>Epoch 23/30
>>1310/1310 - 44s - loss: 14.3119 - accuracy: 0.9955 - val_loss: 916111680.0000 - val_accuracy: 0.1286
>>Epoch 24/30
>>1310/1310 - 44s - loss: 14.0397 - accuracy: 0.9953 - val_loss: 1315758208.0000 - val_accuracy: 0.0952
>>Epoch 25/30
>>1310/1310 - 44s - loss: 13.7770 - accuracy: 0.9954 - val_loss: 176787792.0000 - val_accuracy: 0.1095
>>Epoch 26/30
>>1310/1310 - 44s - loss: 13.5258 - accuracy: 0.9958 - val_loss: 209021744.0000 - val_accuracy: 0.1000
>>Epoch 27/30
>>1310/1310 - 44s - loss: 13.2780 - accuracy: 0.9955 - val_loss: 217535568.0000 - val_accuracy: 0.1238
>>Epoch 28/30
>>1310/1310 - 44s - loss: 13.0325 - accuracy: 0.9960 - val_loss: 62926640.0000 - val_accuracy: 0.1024
>>Epoch 29/30
>>1310/1310 - 44s - loss: 12.7926 - accuracy: 0.9960 - val_loss: 111481688.0000 - val_accuracy: 0.0905
>>Epoch 30/30
>>1310/1310 - 44s - loss: 12.5660 - accuracy: 0.9951 - val_loss: 61252780.0000 - val_accuracy: 0.0810
Frightera commented 1 year ago

The implementation seems correct, but the validation loss is skyrocketing. I assume the model is dealing with extreme values and is overfitting the training dataset.

Can you remove the line outputs = tfpl.OneHotCategorical(NUM_CLASSES)(dense)

and work with the logits directly instead: tfpl.DistributionLambda(lambda x: tfd.OneHotCategorical(logits = 0.01 * x))? Scaling the logits down will show whether the model is dealing with extreme values.

kuangzy2011 commented 1 year ago

I updated the model and trained again:

def create_model_normal(num_points):
    inputs = keras.Input(shape=(num_points, 3))

    x = tnet(inputs, 3)
    x = conv_bn(x, 32)
    x = conv_bn(x, 32)
    x = tnet(x, 32)
    x = conv_bn(x, 32)
    x = conv_bn(x, 64)
    x = conv_bn(x, 512)
    x = layers.GlobalMaxPooling1D()(x)
    x = dense_bn(x, 256)
    x = layers.Dropout(0.25)(x)
    x = dense_bn(x, 128)
    x = layers.Dropout(0.25)(x)

    dense = tfpl.DenseReparameterization(
            units = tfpl.OneHotCategorical.params_size(NUM_CLASSES), activation = None,
            kernel_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),
            kernel_prior_fn = tfpl.default_multivariate_normal_fn,

            bias_prior_fn = tfpl.default_multivariate_normal_fn,
            bias_posterior_fn = tfpl.default_mean_field_normal_fn(is_singular=False),

            kernel_divergence_fn = divergence_fn,
            bias_divergence_fn = divergence_fn)(x)
    outputs = tfpl.DistributionLambda(lambda x: tfd.OneHotCategorical(logits = 0.01 * x))(dense)
    model = keras.Model(inputs=inputs, outputs=outputs, name="pointnet")
    return model

train_dataset <BatchDataset shapes: ((None, 231, 3), (None, 10)), types: (tf.float64, tf.float64)>
test_dataset <BatchDataset shapes: ((None, 231, 3), (None, 10)), types: (tf.float64, tf.float64)>
2023-02-17 13:33:57.712078: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/30
2023-02-17 13:34:12.895018: I tensorflow/stream_executor/cuda/cuda_dnn.cc:369] Loaded cuDNN version 8005
1310/1310 - 67s - loss: 22.3201 - accuracy: 0.1045 - val_loss: 757354368.0000 - val_accuracy: 0.1000
Epoch 2/30
1310/1310 - 44s - loss: 21.3708 - accuracy: 0.1198 - val_loss: 317258464.0000 - val_accuracy: 0.1000
Epoch 3/30
1310/1310 - 44s - loss: 20.4662 - accuracy: 0.1460 - val_loss: 50089232.0000 - val_accuracy: 0.1190
Epoch 4/30
1310/1310 - 44s - loss: 19.5705 - accuracy: 0.1738 - val_loss: 56752180.0000 - val_accuracy: 0.1167
Epoch 5/30
1310/1310 - 44s - loss: 18.6980 - accuracy: 0.2097 - val_loss: 31248118.0000 - val_accuracy: 0.1048
Epoch 6/30
1310/1310 - 44s - loss: 17.8547 - accuracy: 0.2476 - val_loss: 503188.3125 - val_accuracy: 0.1071
Epoch 7/30
1310/1310 - 44s - loss: 17.0483 - accuracy: 0.2886 - val_loss: 623102.0000 - val_accuracy: 0.1429
Epoch 8/30
1310/1310 - 44s - loss: 16.2785 - accuracy: 0.3381 - val_loss: 2514874.5000 - val_accuracy: 0.0905
Epoch 9/30
1310/1310 - 44s - loss: 15.5474 - accuracy: 0.3927 - val_loss: 303895.8438 - val_accuracy: 0.0976
Epoch 10/30
1310/1310 - 44s - loss: 14.8597 - accuracy: 0.4460 - val_loss: 725294.3125 - val_accuracy: 0.1024
Epoch 11/30
1310/1310 - 43s - loss: 14.2143 - accuracy: 0.5031 - val_loss: 1745893.1250 - val_accuracy: 0.0881
Epoch 12/30
1310/1310 - 45s - loss: 13.6149 - accuracy: 0.5614 - val_loss: 1540100.3750 - val_accuracy: 0.1286
Epoch 13/30
1310/1310 - 43s - loss: 13.0604 - accuracy: 0.6151 - val_loss: 4525690.0000 - val_accuracy: 0.1071
Epoch 14/30
1310/1310 - 44s - loss: 12.5467 - accuracy: 0.6683 - val_loss: 3689776.5000 - val_accuracy: 0.1119
Epoch 15/30
1310/1310 - 44s - loss: 12.0814 - accuracy: 0.7103 - val_loss: 6925624.0000 - val_accuracy: 0.1286
Epoch 16/30
1310/1310 - 44s - loss: 11.6500 - accuracy: 0.7498 - val_loss: 6733758.0000 - val_accuracy: 0.1190
Epoch 17/30
1310/1310 - 44s - loss: 11.2548 - accuracy: 0.7785 - val_loss: 99263632.0000 - val_accuracy: 0.0952
Epoch 18/30
1310/1310 - 44s - loss: 10.8870 - accuracy: 0.8032 - val_loss: 6472049.0000 - val_accuracy: 0.0714
Epoch 19/30
1310/1310 - 44s - loss: 10.5297 - accuracy: 0.8363 - val_loss: 19297026.0000 - val_accuracy: 0.0976
Epoch 20/30
1310/1310 - 43s - loss: 10.1925 - accuracy: 0.8700 - val_loss: 43124540.0000 - val_accuracy: 0.1167
Epoch 21/30
1310/1310 - 44s - loss: 9.8820 - accuracy: 0.8970 - val_loss: 124067400.0000 - val_accuracy: 0.0833
Epoch 22/30
1310/1310 - 44s - loss: 9.5937 - accuracy: 0.9173 - val_loss: 79799040.0000 - val_accuracy: 0.1000
Epoch 23/30
1310/1310 - 44s - loss: 9.3247 - accuracy: 0.9325 - val_loss: 446929312.0000 - val_accuracy: 0.1262
Epoch 24/30
1310/1310 - 44s - loss: 9.0735 - accuracy: 0.9442 - val_loss: 732394368.0000 - val_accuracy: 0.0976
Epoch 25/30
1310/1310 - 43s - loss: 8.8402 - accuracy: 0.9531 - val_loss: 776593152.0000 - val_accuracy: 0.0810
Epoch 26/30
1310/1310 - 44s - loss: 8.6207 - accuracy: 0.9588 - val_loss: 140933664.0000 - val_accuracy: 0.0905
Epoch 27/30
1310/1310 - 44s - loss: 8.4134 - accuracy: 0.9637 - val_loss: 246708416.0000 - val_accuracy: 0.1214
Epoch 28/30
1310/1310 - 44s - loss: 8.2224 - accuracy: 0.9681 - val_loss: 1348214400.0000 - val_accuracy: 0.1048
Epoch 29/30
1310/1310 - 44s - loss: 8.0444 - accuracy: 0.9707 - val_loss: 136303104.0000 - val_accuracy: 0.1143
Epoch 30/30
1310/1310 - 43s - loss: 7.8739 - accuracy: 0.9740 - val_loss: 709458240.0000 - val_accuracy: 0.0976
Frightera commented 1 year ago

Hmm, this is strange. I checked the original Keras example, which also has extreme validation loss values, so it will be hard for a Bayesian NN to converge.

Can you try Flipout layers? If I am not mistaken, they should give lower variance during training. They have the same signatures as the Reparameterization layers.

For example, try changing Convolution1DReparameterization to Convolution1DFlipout, and likewise for the Dense layers. It's also worth experimenting with RMSprop.
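Since the signatures match, the swap is mechanical. A minimal sketch of the Flipout variants (reusing divergence_fn and nll from above; the RMSprop learning rate is a guess):

def conv_flipout_layer(filters, kernel_size, padding):
    # Drop-in replacement for conv_reparameterization_layer. Flipout decorrelates
    # the weight perturbations across a batch, which typically lowers gradient variance.
    return tfpl.Convolution1DFlipout(
        filters = filters,
        kernel_size = kernel_size,
        activation = None,
        padding = padding,
        kernel_divergence_fn = divergence_fn,
        bias_divergence_fn = divergence_fn)

def dense_flipout_layer(units):
    # Drop-in replacement for dense_reparameterization_layer.
    return tfpl.DenseFlipout(
        units = units, activation = None,
        kernel_divergence_fn = divergence_fn,
        bias_divergence_fn = divergence_fn)

model.compile(
    loss=nll,
    optimizer=keras.optimizers.RMSprop(learning_rate=0.0001),
    metrics=["accuracy"])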