yell / boltzmann-machines

Boltzmann Machines in TensorFlow with examples
MIT License

Custom metrics #11

Open SrMouraSilva opened 5 years ago

SrMouraSilva commented 5 years ago

Continuing the discussion from https://github.com/monsta-hd/boltzmann-machines/issues/7#issuecomment-469763301

The library already has a lot of metrics, but in research one is often expected to try other metrics as well.

Example:

import tensorflow as tf

def my_custom_evaluate_function(metric_name, model, minibatch):
    """Mean of activated units after reconstruction"""
    # One Gibbs step: v0 -> h0 -> v1 (use `model`, there is no `self` here)
    h_means = model._means_h_given_v(minibatch)
    h0 = model._sample_h_given_v(h_means)
    v_means = model._means_v_given_h(h0)
    v1 = model._sample_v_given_h(v_means)

    with tf.name_scope(metric_name):
        # tf.summary.scalar needs a scalar, so average over all samples and units
        tf.summary.scalar(metric_name, tf.reduce_mean(v1))

rbm = BernoulliRBM(n_visible=784, n_hidden=args.n_hidden,
    metrics_config=dict(
        # The default metrics
        msre=True,
        pll=True,
        feg=True,
        train_metrics_every_iter=1000,
        val_metrics_every_epoch=2,
        feg_every_epoch=4,
        n_batches_for_feg=50,
        # New metric (proposed: a callable passed through metrics_config)
        my_custom_evaluate=my_custom_evaluate_function
    ),
    verbose=True,
)
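
To make the proposal a bit more concrete, here is a rough sketch of how the library side could treat callables in metrics_config as custom metrics. Everything in it is an assumption made for illustration: `split_metrics_config`, `run_custom_metrics`, and `ToyRBM` are hypothetical names, and `ToyRBM` is a small NumPy stand-in, not the library's `BernoulliRBM`. The metric returns a plain float here instead of writing a TensorFlow summary.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class ToyRBM:
    """Hypothetical NumPy stand-in for an RBM (not the library's BernoulliRBM)."""

    def __init__(self, n_visible, n_hidden, seed=0):
        self.rng = np.random.default_rng(seed)
        self.W = 0.01 * self.rng.standard_normal((n_visible, n_hidden))
        self.vb = np.zeros(n_visible)
        self.hb = np.zeros(n_hidden)

    def means_h_given_v(self, v):
        return sigmoid(v @ self.W + self.hb)

    def means_v_given_h(self, h):
        return sigmoid(h @ self.W.T + self.vb)

    def sample(self, means):
        # Bernoulli sampling from the given activation probabilities
        return (self.rng.random(means.shape) < means).astype(np.float64)


def mean_active_after_reconstruction(metric_name, model, minibatch):
    """Mean of activated visible units after one Gibbs step."""
    h0 = model.sample(model.means_h_given_v(minibatch))
    v1 = model.sample(model.means_v_given_h(h0))
    return float(v1.mean())


def split_metrics_config(metrics_config):
    """Separate custom metrics (callables) from the ordinary settings."""
    custom = {k: v for k, v in metrics_config.items() if callable(v)}
    builtin = {k: v for k, v in metrics_config.items() if not callable(v)}
    return builtin, custom


def run_custom_metrics(custom, model, minibatch):
    """Call every custom metric as fn(metric_name, model, minibatch)."""
    return {name: fn(name, model, minibatch) for name, fn in custom.items()}


metrics_config = dict(
    msre=True,
    pll=True,
    train_metrics_every_iter=1000,
    mean_active=mean_active_after_reconstruction,  # custom metric
)
builtin, custom = split_metrics_config(metrics_config)

model = ToyRBM(n_visible=784, n_hidden=16)
batch = (np.random.default_rng(1).random((32, 784)) < 0.5).astype(np.float64)
results = run_custom_metrics(custom, model, batch)
print(results)
```

With this split the existing flags keep working unchanged, and any extra key whose value is a callable is evaluated with the same (metric_name, model, minibatch) signature as in the example above.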