j-towns / craystack

Compression tools for machine learning researchers

Craystack (hierarchical) logistic distributions for discrete samples #7

Closed · vatj closed this issue 4 years ago

vatj commented 4 years ago

Hello,

I would like to use craystack to perform lossless compression on audio samples. I have a TensorFlow model which takes in discrete np.int16 values mapped to [-1, 1] for stability purposes. The model outputs a sample to encode as well as the corresponding parameters of a (mixture of) logistic distribution. I am not interested in the bits-back method, only in vanilla rANS encoding/decoding.

I am struggling a bit to understand what I have to modify in your examples to adapt them to my case. Since the sample has the same discrete nature as the input, is it easier to use the categorical distribution or the already implemented LogisticMixture, generating the probabilities from the locs and scales produced by my code? These parameters are set to give probabilities for symbols in [-1, 1]. Will the codec handle float32 values as symbols, or should I map them to integers first? Maybe it has something to do with the coding_prec argument of the LogisticMixture_UnifBins codec?
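For concreteness, here is roughly the mapping I have in mind, assuming the *_UnifBins codecs expect integer bin indices in [0, 2**bin_prec) over [bin_lb, bin_ub] (that is just my reading of the codec, and the function names are placeholders):

import numpy as np

def int16_to_bin_indices(waveform_int16, bin_prec=16, bin_lb=-1., bin_ub=1.):
    # Map raw np.int16 samples to integer bin indices for a *_UnifBins codec,
    # assuming 2**bin_prec uniform bins covering [bin_lb, bin_ub].
    x = waveform_int16.astype(np.float64) / 32768.  # int16 -> [-1, 1)
    n_bins = 1 << bin_prec
    idx = np.floor((x - bin_lb) / (bin_ub - bin_lb) * n_bins)
    return np.clip(idx, 0, n_bins - 1).astype('uint64')

def bin_indices_to_int16(idx, bin_prec=16):
    # Inverse map; exact for bin_prec=16, where each int16 value gets its own bin.
    return (idx.astype(np.int64) - (1 << (bin_prec - 1))).astype(np.int16)

With bin_prec=16 this is lossless, since the index is simply waveform_int16 + 32768.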

The second issue for me is decoding. When encoding, I use fixed loc, scale parameters for the last part of the sample. When decoding, I need to first decode that last part using the fixed loc, scale, then generate the next loc, scale from the decoded output, and so on. I am guessing this inverse cascade needs to be specified in the codec using the AutoRegressive codec defined in codecs.py, but I am a bit unsure how to proceed.

Any chance you would have time to help? Please feel free to ask any questions if you need more information.

Best regards, Victor

vatj commented 4 years ago

My solution so far:

import tensorflow as tf
from craystack import Codec, LogisticMixture_UnifBins


def Waveglow_codec(model, hparams):

  coding_prec = hparams['coding_prec']
  bin_precision = hparams['bin_precision']
  bin_lowerbound = hparams['bin_lower_bound']
  bin_upperbound = hparams['bin_upper_bound']

  def WaveglowLogisticMixture(all_params, block):
    return LogisticMixture_UnifBins(logit_probs=all_params[0], means=all_params[1], log_scales=all_params[2], 
                                    coding_prec=coding_prec, bin_prec=bin_precision, 
                                    bin_lb=bin_lowerbound, bin_ub=bin_upperbound)

  def AutoRegressiveIDF(model, elem_codec):
    """
    Codec for data from distributions which are calculated autoregressively.
    That is, the data can be partitioned into n elements such that the
    distribution/codec for an element is only known when all previous
    elements are known. This does not affect the push step, but does
    affect the pop step, which must be done in sequence (so is slower).
    elem_param_fn maps data to the params for the respective codecs.
    elem_idxs defines the ordering over elements within data.
    We assume that the indices within elem_idxs can also be used to index
    the params from elem_param_fn. These indexed params are then used in
    the elem_codec to actually code each element.
    """
    def push(message, data):
      encodable, all_params = model.infer_craystack(data)
#       tf.print(encodable[0].shape)
      for block in range(model.n_blocks + 1):
        tf.print(f'block {block}')
        elem_params = all_params[block]
        tf.print(elem_params[0].shape)
        elem_push, _ = elem_codec(elem_params, block) # block potentially useless here but good to have the option
        message = elem_push(message, encodable[block].astype('uint16'))
        tf.print(encodable[block])
      return message

    def pop(message):
      elem = None
      data = None  # no blocks decoded yet; regenerated block by block below
      for block in reversed(range(model.n_blocks + 1)):
        tf.print(f'block {block}')
        data, all_params = model.generate_craystack(x=None if block + 1 > model.n_blocks else data,
                                                    z=elem, block=block+1)
        _, elem_pop = elem_codec(all_params=all_params, block=block)
        message, elem = elem_pop(message)
        tf.print(elem)

      data, all_params = model.generate_craystack(x=data, z=elem, block=0)
      return message, data

    return Codec(push, pop)

  return AutoRegressiveIDF(model=model, elem_codec=WaveglowLogisticMixture)

I have tried to encode a random discrete signal which is not distributed according to the logistic mixture, and I get about 26 bits per dim on average. On the other hand, when I encode a signal which has actually been sampled from the logistic mixture, I also get something like 26 bits per dim. For values that are supposed to be 16-bit integers that is unfortunate. I am guessing the 26 comes from coding_prec, which is set to 27; anything below that asks for a higher precision to rebalance the buckets. I think the issue might have to do with the logistic parameters being provided to predict values in the [-1, 1] range. Any chance you have time to look into that?

vatj commented 4 years ago

After some debugging I have written a minimal example of the issue. It seems to me that the logistic, logistic-mixture and Gaussian codecs with uniform bins are not functioning properly in the craystack core implementation.

import craystack as cs
import numpy as np
import time

hparams = dict()
# rANS precision
hparams['coding_prec'] = 14
# Bin precision for LogisticMixture_UnifBins codec
hparams['bin_precision'] = 8
# Lower bound for LogisticMixture_UnifBins codec
hparams['bin_lower_bound'] = -1.
# Upper bound for LogisticMixture_UnifBins codec
hparams['bin_upper_bound'] = 1.
# dims
hparams['dims'] = pow(2, 4)
# Encoding shape
hparams['target_shape'] = (1, hparams['dims'])
# Parameter shape
hparams['parameter_shape'] = (1, 10, hparams['dims']) # middle axis is n_logistics_in_mixture

# Instantiate parameters for logistic distribution. A strong peak around the mean should be enough to encode the audio with very few bits
peak = -3
logits = np.ones(hparams['parameter_shape']) 
means = np.zeros(hparams['parameter_shape'])
log_scales = np.zeros(hparams['parameter_shape']) + peak

# Instantiate data whose values equal the mean of the logistic distribution
audio = np.zeros(hparams['target_shape']) + pow(2, hparams['bin_precision'] - 1)

print(audio) 

## Encode
encode_t0 = time.time()
init_message = cs.empty_message(shape=hparams['target_shape'])

# Codec
elem_push, elem_pop = cs.LogisticMixture_UnifBins(
  logit_probs=logits, means=means, log_scales=log_scales,
  bin_lb=hparams['bin_lower_bound'], bin_ub=hparams['bin_upper_bound'],
  bin_prec=hparams['bin_precision'], coding_prec=hparams['coding_prec'])
# elem_push, elem_pop = cs.Logistic_UnifBins(
#   means=means, log_scales=log_scales,
#   bin_lb=hparams['bin_lower_bound'], bin_ub=hparams['bin_upper_bound'],
#   bin_prec=hparams['bin_precision'], coding_prec=hparams['coding_prec'])
# elem_push, elem_pop = cs.DiagGaussian_UnifBins(
#   mean=means, stdd=np.exp(log_scales),
#   bin_min=hparams['bin_lower_bound'], bin_max=hparams['bin_upper_bound'],
#   n_bins=pow(2,hparams['bin_precision']), coding_prec=hparams['coding_prec'])

# Encode the audio array
message = elem_push(init_message, audio.astype('uint64'))

flat_message = cs.flatten(message)
print(flat_message.dtype)
encode_t = time.time() - encode_t0

print("All encoded in {:.2f}s.".format(encode_t))

message_len = 32 * len(flat_message)
print("Used {} bits.".format(message_len))
print("This is {:.2f} bits per dim.".format(message_len / hparams['dims']))

## Decode
decode_t0 = time.time()
message = cs.unflatten(flat_message, shape=hparams['target_shape'])

message, audio_ = elem_pop(message)
decode_t = time.time() - decode_t0

print(f'decoded audio shape : {audio_.shape}')

print('All decoded in {:.2f}s.'.format(decode_t))

np.testing.assert_equal(audio, audio_)
np.testing.assert_equal(message, init_message)

Running this script outputs a whopping 40 bits per dimension, which is clearly wrong. Encoding an array whose values all equal the mean of the distribution should be very efficient. Could you please investigate this issue?

j-towns commented 4 years ago

Hi Victor. Sorry for being unresponsive. I’m very busy at the moment writing up my PhD thesis and may not get round to looking at this properly in the next couple of months. It’s possible @tom-bird or @JuliusKunze may be able to take a look.

vatj commented 4 years ago

I solved the issue by looking at the rANS test function. Indeed, for a single data point it performs very poorly, but performance improves rapidly as more points are added. There seems to be a fairly constant overhead that quickly becomes negligible as new data points are added. Best of luck with writing up.
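In case it helps anyone else, here is the kind of sweep I used to convince myself. It just re-runs the minimal example above for increasing dims; the codec calls mirror my script, and the sweep itself is only illustrative:

import craystack as cs
import numpy as np

peak, bin_prec, coding_prec = -3, 8, 14

for dims in [2**k for k in range(2, 12, 2)]:
    target_shape = (1, dims)
    param_shape = (1, 10, dims)
    logits = np.ones(param_shape)
    means = np.zeros(param_shape)
    log_scales = np.zeros(param_shape) + peak
    audio = (np.zeros(target_shape) + 2**(bin_prec - 1)).astype('uint64')

    elem_push, elem_pop = cs.LogisticMixture_UnifBins(
        logit_probs=logits, means=means, log_scales=log_scales,
        bin_lb=-1., bin_ub=1., bin_prec=bin_prec, coding_prec=coding_prec)

    message = elem_push(cs.empty_message(shape=target_shape), audio)
    n_bits = 32 * len(cs.flatten(message))
    print(f'dims={dims:5d}  bits/dim={n_bits / dims:.2f}')

The per-dim cost falls rapidly as dims grows, so the 40 bits per dim above was just the constant per-message overhead dominating a 16-element message.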