omoindrot / tensorflow-triplet-loss

Implementation of triplet loss in TensorFlow
https://omoindrot.github.io/triplet-loss
MIT License
1.12k stars 284 forks source link

Embeddings Collapse very fast #50

Open AtinAngrish opened 4 years ago

AtinAngrish commented 4 years ago

Hi Omoindrot.

Thanks for the great explanation of triplet loss on your blog. It is one of the few resources I found online which talked about the implementation. I have been trying to simulate a dataset for embedding tests. I am attaching the full code for your reference. Maybe you could have a look if you have some time and let me know what I have been doing wrong:

I am sampling data points from 2 different distributions and wanting to see if I can generate lower dimensional embeddings for both.

`# -*- coding: utf-8 -*-
"""
Created on Tue Sep 24 18:44:04 2019

@author: aangris
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import datasets
from sklearn.decomposition import PCA
from omnidrot_loss_utils import batch_all_triplet_loss

# import some data to play with
iris = datasets.load_iris()

def load_data():
    X = iris.data #[:, :2]  # we only take the first two features.
    y = iris.target.reshape(-1,1)
    XY_labeled = np.hstack((X,y))
    np.random.shuffle(XY_labeled)
    return XY_labeled[:,:-1] , XY_labeled[:,-1]

def make_shared_branch(input_tensor):
    layer_1_units = 5
    layer_2_units = 3
    with tf.variable_scope('layer1_scope',reuse = tf.AUTO_REUSE):
        weights = tf.get_variable('weights1',
                                  shape = [input_tensor.shape[1].value,layer_1_units],
                                  initializer = tf.truncated_normal_initializer(1/np.sqrt(layer_1_units)))
        biases = tf.get_variable('b1',
                                 initializer=tf.ones_initializer(),
                                 shape = [layer_1_units])
        layer1 = tf.nn.relu(tf.matmul(input_tensor,weights)+biases)

        drop_out1 = tf.nn.dropout(layer1,rate = 0.5)

    with tf.variable_scope("layer2_scope",reuse = tf.AUTO_REUSE):
        weights = tf.get_variable(
            'weights2', 
            shape=[layer_1_units, layer_2_units],
            initializer=tf.truncated_normal_initializer(
                stddev=1.0/np.sqrt(float(layer_2_units))))
        biases = tf.get_variable(
            'biases2', 
            initializer=tf.ones_initializer(),
            shape=[layer_2_units])
        layer2 = tf.nn.relu(tf.add(tf.matmul(drop_out1, weights),biases))
    #this is the only thing that has been changed.
    #sigmoid gives some problems with embeddings sticking to each other
    #this (with relu) gives pretty good results
#    with tf.variable_scope("lambda_layer"):
#        lambda_layer  = tf.nn.l2_normalize(layer2,axis=1)
#        
        return layer2

num_imgs = 200
dims = 10  

anchor = tf.placeholder(tf.float32,(None,dims),"anchor")
labels = tf.placeholder(tf.int32,shape = (None),name = "labels")

learning_rate = tf.placeholder(tf.float32,shape=[])

anchor_branch = make_shared_branch(anchor)

#X,y = load_data()
np.random.seed(1)
X1 = np.random.multivariate_normal(mean = np.array([40]*dims),cov = np.eye(10),size = num_imgs)
X2 = np.random.multivariate_normal(mean = np.array([50]*dims),cov = np.eye(10),size = num_imgs)
X = np.vstack((X1,X2))
y = np.repeat([1,2],200).reshape(-1,1)
XY = np.hstack((X,y))
np.random.shuffle(XY)
X = XY[:,:-1]
y = XY[:,-1]
#test
#trips = generate_triplets(X,y)
loss_mean_collector = []
val_loss_mean_collector = []
train_cutoff = len(X)*0.8
test_cutoff = len(X)*0.2
EPOCHS = 500
lr = 1e-3
#A very high learning rate collapses the embeddings. Try smaller learning rates if you want to prevent collapse
#Seems like starting from 0.1 for triplet loss (NON BPR) is a bad strategy.
#Margin fixed at 0.5 

loss = batch_all_triplet_loss(labels,anchor_branch,margin=0.05,squared=False)
train_step = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(loss[0])
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
batch_size = 30
from tqdm import tqdm
loss_collector = []
frac_pos =[]
loss_mean_collector = []
with sess.as_default():
    for i in range(EPOCHS):
        for b in tqdm(range(int(train_cutoff//batch_size))):
            X_train= X[b:batch_size+b]
            y_train = y[b:batch_size+b]
            '''   
            shape1, text1, im_labels, shape_text_margin = args[0],args[1],args[2],args[3]
            im_im_labels = args[4]
            shape_margin = args[5]
            '''
            custom_dict = {anchor : X_train,
                           labels : y_train,
                           learning_rate : lr}
            _ = sess.run(train_step,feed_dict = custom_dict)
            loss_value,fraction_pos = sess.run(loss,feed_dict = custom_dict)            
            frac_pos.append(fraction_pos)
            loss_collector.append(loss_value)
        if i%500==0:
            lr = lr*0.5
            print("Epoch #",i)
        print(np.mean(loss_collector))
        print(np.mean(frac_pos))
        loss_mean_collector.append(np.mean(loss_collector))
    frac_pos = []

plt.plot(loss_mean_collector)
def rand_jitter(arr):
    stdev = .01*(max(arr)-min(arr))
    return arr + np.random.randn(len(arr)) * stdev
with sess.as_default():
    Xt,yt = X,y
    newX = sess.run(anchor_branch,feed_dict={anchor : Xt})
    newX2 = newX - np.mean(newX,axis=0)
    fig = plt.figure()
    ax = plt.axes(projection="3d")
#    ax.scatter3D(newX[:,0][y==0],
#                 newX[:,1][y==0],
#                 newX[:,2][y==0],
#                 c = "r",marker = "o")
    #with mean shift
    ax.scatter3D(rand_jitter(newX[:,0][np.where(yt==1)[0]] - np.mean(newX[:,0][np.where(yt==1)[0]])),
                 rand_jitter(newX[:,1][np.where(yt==1)[0]] - np.mean(newX[:,0][np.where(yt==1)[0]])),
                 rand_jitter(newX[:,2][np.where(yt==1)[0]] - np.mean(newX[:,0][np.where(yt==1)[0]])),
                 c = "g",marker = "o")
    ax.scatter3D(rand_jitter(newX[:,0][np.where(yt==2)[0]] - np.mean(newX[:,0][np.where(yt==2)[0]])),
                 rand_jitter(newX[:,1][np.where(yt==2)[0]] - np.mean(newX[:,0][np.where(yt==2)[0]])),
                 rand_jitter(newX[:,2][np.where(yt==2)[0]] - np.mean(newX[:,0][np.where(yt==2)[0]])),
                 c = "b",marker = "^")
    #without mean shift
    fig = plt.figure()
    ax = plt.axes(projection="3d")
    ax.scatter3D(rand_jitter(newX[:,0][np.where(yt==1)[0]]),
                 rand_jitter(newX[:,1][np.where(yt==1)[0]]),
                 rand_jitter(newX[:,2][np.where(yt==1)[0]]),
                 c = "g",marker = "o")
    ax.scatter3D(rand_jitter(newX[:,0][np.where(yt==2)[0]]),# - np.mean(newX[:,0][np.where(yt==2)[0]])),
                 rand_jitter(newX[:,1][np.where(yt==2)[0]]),# - np.mean(newX[:,0][np.where(yt==2)[0]])),
                 rand_jitter(newX[:,2][np.where(yt==2)[0]]),# - np.mean(newX[:,0][np.where(yt==2)[0]])),
                 c = "b",marker = "^")

The hope is that newX parameter is able to allow me to visualize the newly generated embeddings and i can see the differences visually. Of course I could be doing something wrong. If you could have a look at my code and let me know where my mistake is, that would be awesome. Thanks for your time.

Please let me know if I need to share more details.

AA