Bihaqo / t3f

Tensor Train decomposition on TensorFlow
https://t3f.readthedocs.io/en/latest/index.html
MIT License

Support eager mode with Riemannian autodiff #193

Closed. Bihaqo closed this issue 4 years ago.

Bihaqo commented 4 years ago

#195

faysou commented 4 years ago

I get a strange error in deltas_to_tangent_space when using t3f.gradients for tensor completion with the rank increased incrementally. At some point, for a rank of [1 4 5 4 1], the cores that should be returned don't have the correct structure, so the code stops. These are the dimensions of the cores I get:

[TensorShape([Dimension(1), Dimension(4), Dimension(8)]), TensorShape([Dimension(8), Dimension(4), Dimension(10)]), TensorShape([Dimension(10), Dimension(4), Dimension(9)]), TensorShape([Dimension(8), Dimension(4), Dimension(1)])]

I didn't have this error when I computed the gradient the other way, by projection.

I'll try to reproduce the bug in a small example.

faysou commented 4 years ago

Here is a self-contained example that fails. I hope this helps you.

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.enable_eager_execution()
import t3f

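# Four TT-cores of a tensor of shape [4, 4, 4, 4] with TT-ranks [1, 4, 5, 5, 1]
# (core shapes: (1, 4, 4), (4, 4, 5), (5, 4, 5), (5, 4, 1)).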
cores = [np.array([[[ 5.6620991e+01, -7.8090101e-01, -4.4375276e-03,  5.0257673e-03],
         [ 6.1290928e+01, -4.8074251e-01, -5.3017284e-04, -6.8296650e-03],
         [ 7.1359612e+01,  2.5267056e-01,  1.5182286e-02,  1.4828241e-03],
         [ 7.6741325e+01,  7.2516447e-01, -1.0420068e-02,  3.6771200e-04]]],
       dtype="float32"),
 np.array([[[-4.22641635e-01,  4.77946457e-03,  4.26539964e-06,
          -1.49097632e-05,  3.75683976e-06],
         [-4.57700491e-01,  2.88415444e-03, -1.25210572e-04,
           1.07652795e-05,  1.41972930e-06],
         [-5.32576919e-01, -1.61362393e-03, -3.11460935e-05,
          -1.02584318e-05, -4.71308658e-06],
         [-5.72881937e-01, -4.32803482e-03,  1.26194151e-04,
           1.23360305e-05,  5.04817308e-07]],

        [[-3.78389478e-01, -3.52809697e-01, -3.90482717e-03,
          -4.33888612e-03,  1.34184942e-04],
         [-2.43150383e-01, -3.88812065e-01, -1.38323745e-02,
          -4.97070723e-04, -1.49993954e-04],
         [ 1.34765476e-01, -4.31203961e-01, -2.91475654e-03,
          -2.53561581e-03,  2.78464344e-04],
         [ 3.47801298e-01, -4.43502545e-01, -5.71238156e-03,
           4.18569148e-03,  5.30998979e-04]],

        [[ 6.00875951e-02,  1.07708059e-01, -1.79333284e-01,
           2.60745268e-02, -1.06033236e-02],
         [-6.23276293e-01,  3.24752659e-01, -3.60669166e-01,
          -6.24450780e-02, -1.56603419e-04],
         [ 1.58923477e-01,  1.76855788e-01, -3.14616978e-01,
           1.09622560e-01, -1.24879237e-02],
         [ 3.07694256e-01,  4.79354113e-02,  1.92505583e-01,
          -1.46695822e-01,  2.59475149e-02]],

        [[ 2.22208619e-01, -4.54327077e-01, -1.71980724e-01,
          -1.58617407e-01, -7.33328378e-03],
         [-6.74762726e-02,  1.27745539e-01, -3.32307726e-01,
           1.31681010e-01, -5.37965726e-03],
         [-1.51836932e-01, -4.11450416e-01,  3.07844400e-01,
          -5.63377962e-02, -8.76249447e-02],
         [ 2.55400240e-02,  4.83355403e-01, -6.41870350e-02,
           3.21240574e-02, -5.78595772e-02]]], dtype="float32"),
 np.array([[[-4.22633171e-01,  3.34328739e-03,  5.11623621e-05,
           2.85330589e-06, -1.31908640e-09],
         [-4.57639784e-01,  2.04730919e-03, -2.44270268e-05,
           2.24198538e-06,  4.04761824e-09],
         [-5.32588124e-01, -1.02958886e-03, -5.59781654e-07,
          -8.39729000e-06,  2.23843566e-09],
         [-5.72949052e-01, -3.14126676e-03, -1.66910104e-05,
           3.98870952e-06,  3.29622973e-09]],

        [[-4.69937772e-01, -3.12006742e-01, -6.42035203e-03,
          -9.31801682e-04,  1.28598403e-08],
         [-2.85116911e-01, -3.33732754e-01, -1.14002172e-03,
           9.51640599e-04, -1.50699844e-08],
         [ 1.60229728e-01, -3.55661392e-01,  9.87656694e-03,
          -8.00549111e-04,  6.76460754e-09],
         [ 4.25231278e-01, -3.94975513e-01,  1.42964988e-03,
           9.01863677e-04, -3.82008958e-09]],

        [[ 6.22857690e-01,  2.40874007e-01,  6.73491210e-02,
           3.36995721e-02, -2.24258945e-09],
         [-6.07307971e-01, -1.08960025e-01, -3.84900812e-03,
           2.13157758e-02, -3.02249736e-09],
         [ 4.64363992e-02, -4.80658673e-02, -6.62848726e-02,
          -3.80313350e-03, -3.75147113e-09],
         [-1.44542158e-02, -3.57691258e-01, -1.73237696e-01,
           1.24059096e-02, -1.22619621e-08]],

        [[ 1.09044366e-01,  1.20226040e-01, -3.24483782e-01,
          -2.27006480e-01,  1.83339199e-09],
         [ 2.79833108e-01,  2.88795441e-01, -2.59605736e-01,
           3.50165814e-02,  1.67094303e-08],
         [-2.81293273e-01, -5.32351732e-01,  2.83058286e-01,
          -8.82405639e-02, -5.61330205e-10],
         [-3.79741676e-02, -3.32546711e-01, -9.93924364e-02,
           1.19003721e-01, -1.38504277e-11]],

        [[ 7.02710301e-02, -1.77677590e-02,  3.94130349e-01,
          -1.69657379e-01, -6.87299018e-09],
         [-5.51697984e-03,  1.81056783e-02,  5.88977784e-02,
           1.21434540e-01, -1.17474541e-09],
         [-5.76172292e-01,  7.29691088e-02,  6.92215189e-03,
          -3.96031290e-01,  4.66166439e-09],
         [ 4.87267405e-01,  1.38517559e-01, -1.36327237e-01,
          -1.50535524e-01, -3.70242437e-09]]], dtype="float32"),
 np.array([[[ 4.2267996e-01],
         [ 4.5752168e-01],
         [ 5.3272551e-01],
         [ 5.7290399e-01]],

        [[ 6.5505135e-01],
         [ 4.0487936e-01],
         [-2.2581206e-01],
         [-5.9664857e-01]],

        [[-6.1988062e-01],
         [ 6.8176746e-01],
         [ 2.3720609e-01],
         [-3.0769211e-01]],

        [[ 8.9428566e-02],
         [-4.0241730e-01],
         [ 7.8034973e-01],
         [-4.7023094e-01]],

        [[-4.5380406e-09],
         [ 4.0326453e-09],
         [-9.1800478e-09],
         [ 2.5249662e-09]]], dtype="float32")]

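# Multi-indices of the observed entries of the tensor (one row per observation).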
training_idx = np.array([[2, 2, 2, 2],
       [0, 0, 0, 0],
       [3, 3, 0, 0],
       [0, 3, 0, 3],
       [0, 3, 3, 0],
       [3, 0, 0, 2],
       [0, 0, 3, 2],
       [3, 0, 2, 0],
       [0, 0, 0, 3],
       [3, 0, 3, 3],
       [3, 3, 3, 0],
       [3, 3, 0, 3],
       [0, 3, 1, 0],
       [1, 3, 3, 3],
       [3, 0, 0, 0],
       [0, 0, 3, 1],
       [3, 3, 3, 3],
       [1, 0, 0, 3],
       [2, 3, 0, 0],
       [3, 0, 3, 3],
       [0, 1, 0, 0],
       [0, 3, 3, 2],
       [3, 2, 3, 0],
       [3, 1, 0, 3],
       [0, 2, 0, 3],
       [0, 1, 3, 3],
       [1, 0, 2, 0],
       [3, 3, 0, 2],
       [0, 3, 2, 0],
       [3, 0, 1, 0],
       [3, 0, 3, 3],
       [3, 3, 3, 2],
       [0, 1, 0, 0],
       [0, 0, 1, 3],
       [3, 3, 1, 0],
       [2, 3, 0, 3],
       [1, 3, 3, 0],
       [1, 0, 3, 3],
       [0, 3, 1, 2],
       [3, 1, 2, 0],
       [0, 0, 1, 0],
       [3, 2, 0, 0],
       [2, 0, 1, 3],
       [2, 3, 3, 0],
       [1, 0, 3, 1],
       [3, 3, 1, 3],
       [0, 2, 0, 1],
       [0, 1, 0, 3],
       [1, 2, 3, 3],
       [1, 3, 3, 0],
       [2, 0, 3, 3],
       [3, 1, 0, 1],
       [0, 0, 0, 2],
       [3, 3, 0, 1],
       [0, 1, 2, 0],
       [1, 3, 2, 3],
       [0, 3, 1, 1],
       [3, 1, 0, 2],
       [3, 2, 3, 1],
       [0, 1, 0, 0],
       [1, 0, 3, 1],
       [3, 2, 2, 3],
       [3, 0, 2, 2],
       [0, 3, 2, 2],
       [1, 2, 0, 0],
       [2, 1, 0, 3],
       [0, 0, 2, 2],
       [2, 2, 3, 0],
       [3, 3, 2, 2],
       [1, 0, 3, 1],
       [1, 3, 1, 0],
       [3, 0, 2, 2],
       [0, 1, 1, 3],
       [1, 0, 1, 0],
       [0, 2, 3, 1],
       [3, 2, 0, 1],
       [2, 2, 3, 3],
       [1, 1, 0, 3],
       [3, 3, 1, 2],
       [1, 1, 3, 3],
       [2, 2, 3, 0],
       [1, 0, 1, 0],
       [3, 2, 1, 0],
       [3, 1, 3, 2],
       [0, 2, 2, 0],
       [1, 1, 0, 3],
       [3, 3, 1, 2],
       [1, 1, 3, 3],
       [2, 0, 1, 1],
       [0, 2, 2, 0],
       [2, 2, 3, 3],
       [0, 0, 2, 2],
       [2, 3, 3, 1],
       [0, 0, 1, 1],
       [3, 3, 1, 1],
       [3, 1, 2, 3],
       [0, 1, 1, 3],
       [1, 2, 3, 0],
       [1, 2, 1, 3],
       [3, 1, 2, 1]], dtype="int64")

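# Observed values at the indices above.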
training_vals = np.array([10.726894 ,  3.9559312,  7.9788456,  7.9788456,  7.9788456,
        7.3694243,  7.3694243,  7.3694243,  5.726894 , 10.726894 ,
       10.726894 , 10.726894 ,  6.243689 , 11.490298 ,  5.726894 ,
        6.243689 , 13.955931 ,  6.243689 ,  7.3694243, 10.726894 ,
        4.3554955,  9.99369  ,  9.99369  ,  8.619424 ,  7.3694243,
        8.619424 ,  5.726894 ,  9.99369  ,  7.3694243,  6.243689 ,
       10.726894 , 13.105495 ,  4.3554955,  6.243689 ,  8.619424 ,
        9.99369  ,  8.619424 ,  8.619424 ,  7.9788456,  7.9788456,
        4.3554955,  7.3694243,  7.9788456,  9.99369  ,  6.791099 ,
       11.490298 ,  5.726894 ,  6.243689 , 10.726894 ,  8.619424 ,
        9.99369  ,  6.791099 ,  5.240298 ,  8.619424 ,  5.726894 ,
       10.726894 ,  6.791099 ,  7.9788456, 10.726894 ,  4.3554955,
        6.791099 , 12.283375 ,  9.2911   ,  9.2911   ,  5.726894 ,
        7.9788456,  6.791099 ,  9.2911   , 12.283375 ,  6.791099 ,
        6.791099 ,  9.2911   ,  6.791099 ,  4.7833753,  7.9788456,
        7.9788456, 12.283375 ,  6.791099 , 10.726894 ,  9.2911   ,
        9.2911   ,  4.7833753,  7.9788456, 10.726894 ,  6.791099 ,
        6.791099 , 10.726894 ,  9.2911   ,  6.243689 ,  6.791099 ,
       12.283375 ,  6.791099 , 10.726894 ,  4.7833753,  9.2911   ,
       10.726894 ,  6.791099 ,  7.9788456,  8.619424 ,  8.619424 ],
      dtype="float32")

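# Assemble the raw numpy cores into a t3f TensorTrain.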
X = t3f.TensorTrain(cores)

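# Squared-error loss on the observed entries only (the tensor completion objective).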
def loss(x):
    estimated_vals = t3f.gather_nd(x, training_idx)
    return 0.5 * tf.reduce_sum((training_vals - estimated_vals) ** 2)

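# Riemannian autodiff; this call fails with the deltas_to_tangent_space error.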
gradF = t3f.gradients(loss, X, runtime_check=False)
Bihaqo commented 4 years ago

Hi, thanks for the self-contained example! It will take some time to fix this properly, but for now I would recommend the following workaround.

Your TensorTrain X has shape [4, 4, 4, 4] and TT-ranks [1, 4, 5, 5, 1]. The maximal sensible TT-ranks for a tensor of this shape are [1, 4, 16, 4, 1], and in your example the last 5 is higher than the corresponding 4. If any of the TT-ranks is higher than this maximum (elementwise), then you a) trigger the bug and b) spend extra memory and computation for nothing, since growing a rank above this maximum doesn't give the tensor any extra capacity.

So my workaround would be to avoid using excessive TT-ranks for now. One (not very fast) way of trimming the ranks is to orthogonalize the tensor like this:

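# A left-to-right sweep followed by a right-to-left sweep caps each TT-rank at its maximal sensible value.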
X = t3f.orthogonalize_tt_cores(X)
X = t3f.orthogonalize_tt_cores(X, left_to_right=False)

If you do that before computing t3f.gradients, your example works.

faysou commented 4 years ago

Great, thank you Alexander. I really admire tensor trains and your work on them. I've seen that you are porting the library to TensorFlow 2; that's great, as it will allow the library to stay up to date and be useful to non-specialists in tensor trains, and I don't think TensorFlow will make breaking changes as big as this one again.

Thanks to your workaround, my code with t3f.gradients now works. I don't think orthogonalizing the cores takes too much time anyway; the cores are often small.

faysou commented 4 years ago

Another question: is there any way to know "the maximal sensible TT-rank for a tensor" of a given shape?

Bihaqo commented 4 years ago

For a tensor of size n_1 x ... x n_d with TT-ranks r_0, r_1, ..., r_d (where r_0 = r_d = 1), each r_k should be less than or equal to min(n_1 * ... * n_k, n_{k+1} * ... * n_d).
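As a quick illustration, a minimal numpy sketch of this formula (my own hypothetical helper, not the function from the pull request linked below):

import numpy as np

def max_tt_ranks(shape):
    # Maximal sensible TT-ranks [r_0, ..., r_d] for a tensor of the given shape:
    # r_k <= min(n_1 * ... * n_k, n_{k+1} * ... * n_d).
    d = len(shape)
    return [min(int(np.prod(shape[:k], dtype=np.int64)),
                int(np.prod(shape[k:], dtype=np.int64)))
            for k in range(d + 1)]

print(max_tt_ranks([4, 4, 4, 4]))  # [1, 4, 16, 4, 1], matching the ranks quoted above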

I have a really old pull request that adds a function computing exactly that to the library, but I was too lazy to finish it properly. See here: https://github.com/Bihaqo/t3f/blob/prune_ranks/t3f/utils.py#L59

faysou commented 4 years ago

OK, thank you.

Bihaqo commented 4 years ago

Done in #201