lmnt-com / haste

Haste: a fast, simple, and open RNN library
Apache License 2.0

HASTE produces wrong gradients on K80 device #17

Closed: amurashov closed this issue 4 years ago

amurashov commented 4 years ago

I have tested HASTE on two different instance types on AWS (for reproducibility):

p2.xlarge (K80 instance)
p3.2xlarge (V100 instance)

Both instances used the stock Deep Learning AMI (Amazon Linux 2) Version 29.0 (ami-0b0b075706e19de29).

The following sequence of commands was used to install HASTE:

(0) Change the /usr/local/cuda symlink to point to /usr/local/cuda-10.1 instead of /usr/local/cuda-10.0 (see the separate issue: without this change, HASTE does not install properly).
(1) source activate tensorflow2_p36
(2) git clone https://github.com/lmnt-com/haste
(3) cd haste
(4) make haste_tf
(5) pip install haste_tf-*.whl
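Before running step (4), a quick check that step (0) took effect can save a failed build. This is a minimal sketch, assuming the default /usr/local/cuda location and that the 10.1 toolkit lives in a directory named cuda-10.1 as on this AMI. The helper names are hypothetical and not part of HASTE.

```python
import os

def cuda_symlink_target(link="/usr/local/cuda"):
    """Return the name of the directory the CUDA symlink finally resolves to."""
    # os.path.realpath follows the whole symlink chain to the final path.
    return os.path.basename(os.path.realpath(link))

def cuda_symlink_ok(link="/usr/local/cuda", expected="cuda-10.1"):
    """True if `link` resolves to a directory named `expected` (hypothetical check)."""
    return cuda_symlink_target(link) == expected

if __name__ == "__main__":
    print("/usr/local/cuda ->", cuda_symlink_target())
    if not cuda_symlink_ok():
        print("Symlink does not resolve to cuda-10.1; step (0) may still be needed.")
```

If /usr/local/cuda does not exist at all, realpath returns the path unchanged, so the check simply reports a mismatch.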

Then I ran the following in a Jupyter notebook:

%env CUDA_VISIBLE_DEVICES=0
import numpy as np
import pickle
import tensorflow as tf
#gpus = tf.config.experimental.list_physical_devices('GPU')
#tf.config.experimental.set_memory_growth(gpus[0], True)
import haste_tf as haste
from tensorflow.python.keras import layers as L
from tensorflow.python.keras import backend as K

embedding_size = 100 #n_channels
lstm_nunits = 200
ntimestamps = 300
batch_size = 16

class HasteLSTM(tf.keras.layers.Layer):
    def __init__(self, num_units, dropout, zoneout, shape):
        super(HasteLSTM, self).__init__()
        self.haste_lstm = haste.LSTM(num_units=num_units, dropout=dropout, zoneout=zoneout, direction='unidirectional')
        self.haste_lstm.build(shape)

    def call(self, inputs, training):
        return self.haste_lstm(inputs, training=training)

haste_lstm = HasteLSTM(lstm_nunits, 0.00, 0.00, [batch_size, ntimestamps, embedding_size])

#not really a CuDNN but a normal LSTM, so number of parameters matches
cudnn_lstm = L.LSTM(lstm_nunits, return_sequences = True, unit_forget_bias = False)

dummy_input  = tf.random.normal([batch_size, ntimestamps, embedding_size])
dummy_target = np.zeros(shape=(batch_size, ntimestamps, lstm_nunits))

for i in range(dummy_target.shape[0]):
    for j in range(dummy_target.shape[1]):
        dummy_target[i,j,np.random.randint(0, lstm_nunits)] = 1 #one in random position for each timestamp

input_ = L.Input(shape = [ntimestamps, embedding_size])
model_ = haste_lstm(input_, training = True)
if isinstance(model_, tuple): model_ = model_[0] #take only output, no states
model_ = K.softmax(model_) #simple classification task

model_haste = tf.keras.Model(inputs=input_, outputs=model_, name='haste_model')

input_ = L.Input(shape = [ntimestamps, embedding_size])
model_ = cudnn_lstm(input_, training = True)
if isinstance(model_, tuple): model_ = model_[0] #take only output, no states
model_ = K.softmax(model_) #simple classification task

model_cudnn = tf.keras.Model(inputs=input_, outputs=model_, name='cudnn_model')

total_trainable = 0
haste_trainable = []
for w in haste_lstm.haste_lstm.trainable_variables:
    K.set_value(w, np.zeros_like(w.numpy()))
    haste_trainable.append(w)
    total_trainable += w.numpy().flatten().shape[0]
print("HASTE has total %d trainable variables!" % total_trainable)

total_trainable = 0
cudnn_trainable = []
for w in cudnn_lstm.trainable_weights:
    K.set_value(w, np.zeros_like(w.numpy()))
    cudnn_trainable.append(w)
    total_trainable += w.numpy().flatten().shape[0]
print("CuDNN has total %d trainable variables!" % total_trainable)

#check HASTE gradients on the dummy example
with tf.GradientTape() as tape:
    prediction = model_haste(dummy_input, training=True)
    loss = tf.keras.losses.categorical_crossentropy(dummy_target, prediction)

gradients = tape.gradient(loss, haste_trainable)

print("HASTE maxabs of each grad:")
for grad in gradients:
    print(np.max(np.abs(grad)))

print("Non-HASTE maxabs of each grad:")
#check CuDNN (actually - plain LSTM) gradients on the dummy example
with tf.GradientTape() as tape:
    prediction = model_cudnn(dummy_input, training=True)
    loss = tf.keras.losses.categorical_crossentropy(dummy_target, prediction)

gradients = tape.gradient(loss, cudnn_trainable)
for grad in gradients:
    print(np.max(np.abs(grad)))

On p2.xlarge (K80) the following is the output:

env: CUDA_VISIBLE_DEVICES=0
HASTE has total 240800 trainable variables!
CuDNN has total 240800 trainable variables!
HASTE maxabs of each grad:
0.0
0.0
Non-HASTE maxabs of each grad:
6.3259706
0.0
7.397908

On p3.2xlarge (V100) the following is the output:

env: CUDA_VISIBLE_DEVICES=0
HASTE has total 240800 trainable variables!
CuDNN has total 240800 trainable variables!
HASTE maxabs of each grad:
7.004616
6.2311497
Non-HASTE maxabs of each grad:
6.231148
0.0
7.0048447

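Gradients appear to be broken on the K80: every HASTE gradient is zero there, while the plain Keras LSTM produces non-zero gradients for the same setup. On the V100, the HASTE and Keras values agree.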
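The side-by-side comparison above can be turned into an automatic check. The sketch below is a hypothetical helper (not part of HASTE or Keras): it tests whether every non-zero max-abs gradient from the reference LSTM has a close match among the HASTE values. It is order-insensitive because the two runs list their values in different orders, and it ignores zero reference entries. The sample inputs are the numbers reported above; each machine is compared against its own reference run, since the inputs were random.

```python
import numpy as np

def grads_match_reference(candidate, reference, rtol=1e-3):
    """Hypothetical check: does each non-zero reference max-abs gradient
    have a counterpart in `candidate` within relative tolerance `rtol`?"""
    cand = np.asarray(candidate, dtype=np.float64)
    ref = np.asarray(reference, dtype=np.float64)
    # Zero reference entries (the recurrent-kernel gradient here) carry no signal.
    for value in ref[ref != 0.0]:
        if not np.any(np.isclose(cand, value, rtol=rtol)):
            return False
    return True

# Max-abs gradient values copied from the outputs reported above.
k80_ok = grads_match_reference([0.0, 0.0], [6.3259706, 0.0, 7.397908])
v100_ok = grads_match_reference([7.004616, 6.2311497], [6.231148, 0.0, 7.0048447])
print(k80_ok, v100_ok)  # False True
```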

sharvil commented 4 years ago

I double-checked the compute capability of the K80, and it's actually 3.7, not >= 6.0. Since only atomicAdd<double> needed the higher compute capability, I've implemented it in terms of compare-and-swap and lowered the minimum required compute capability so you can use HASTE on the K80. Could you please try it out and let me know if it works for you?
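The compare-and-swap idea can be illustrated outside CUDA. The sketch below is a Python emulation, not HASTE's actual CUDA code. `Word64.compare_and_swap` stands in for a hardware 64-bit compare-and-swap (made atomic here with a lock), and `atomic_add_double` is the usual retry loop: reinterpret the stored double's bits as a 64-bit integer, compute the new sum, and retry if another thread changed the word in between. The compute-capability gate reflects that native double-precision atomicAdd needs compute capability 6.0 or higher (the K80 is 3.7; the V100 is 7.0). All names are hypothetical.

```python
import struct
import threading

def needs_cas_fallback(compute_capability):
    """Native atomicAdd on doubles needs compute capability >= 6.0."""
    return tuple(compute_capability) < (6, 0)

def double_to_bits(x):
    # Reinterpret a float64 as an unsigned 64-bit integer (cf. __double_as_longlong).
    return struct.unpack("<Q", struct.pack("<d", x))[0]

def bits_to_double(b):
    # Inverse reinterpretation (cf. __longlong_as_double).
    return struct.unpack("<d", struct.pack("<Q", b))[0]

class Word64:
    """Emulated 64-bit memory word whose compare_and_swap is atomic."""
    def __init__(self, value=0.0):
        self._bits = double_to_bits(value)
        self._lock = threading.Lock()

    def load(self):
        return self._bits

    def compare_and_swap(self, expected, new):
        # Store `new` only if the word still holds `expected`; return the old bits.
        with self._lock:
            old = self._bits
            if old == expected:
                self._bits = new
            return old

    def value(self):
        return bits_to_double(self._bits)

def atomic_add_double(word, val):
    """Add `val` to `word` using only compare-and-swap; return the previous value."""
    old = word.load()
    while True:
        assumed = old
        new = double_to_bits(bits_to_double(assumed) + val)
        old = word.compare_and_swap(assumed, new)
        if old == assumed:  # no other thread modified the word in between
            return bits_to_double(old)

if __name__ == "__main__":
    acc = Word64(0.0)

    def worker():
        for _ in range(1000):
            atomic_add_double(acc, 0.5)

    threads = [threading.Thread(target=worker) for _ in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(acc.value())  # 4000.0
```

Comparing raw bits rather than float values mirrors the CUDA pattern and keeps the loop well-defined even if the stored value is NaN.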

amurashov commented 4 years ago

Confirmed working! It may be a good idea to update the docs.

PS: The install still required some hacking of the Makefile paths to compilers, CUDA libraries, etc. I am fine with this, but it might turn off some potential users. It may be a good idea to note in the docs that the Makefile might need adjusting to get up and running: on my systems (including standard AWS instances with the out-of-the-box TF environments), some voodoo magic was needed for it to compile.