**Closed** · tonyhqanguyen closed this issue 5 years ago
I just ran this and it worked fine; make sure the tf-nightly build you are using is okay:
```python
import tensorflow as tf
from model import transformer


class Inputs:
    def __init__(self, units=512, d_model=256, num_heads=8, num_layers=2,
                 dropout=0.1, vocab_size=100):
        self.num_units = units
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout = dropout
        self.vocab_size = vocab_size
        self.activation = "relu"


if __name__ == '__main__':
    model = transformer(Inputs())
    tf.keras.utils.plot_model(
        model, to_file='transformer.png', show_shapes=True)
```
Thanks! Turns out I accidentally uninstalled tf-nightly without knowing it.
I was just trying to rewrite this in PyTorch (which I'm more familiar with). Do you know how I could get some insight into the training loop? For example:

Does each iteration take in one time step of inputs (i.e., at iteration i of a training batch, does it take in the i-th token of every sequence), or does the training loop process each sequence separately? Essentially, what would be the shape of the input at each iteration, and what would each dimension represent?
Would the output shape be (batch size, d_model)? If so, what would this mean?
I was looking at this tutorial: https://pytorch.org/tutorials/beginner/chatbot_tutorial.html. There, the input at each iteration of a training batch is an entire time step, so the first iteration gets the first token of every sequence, and so on. The output has shape (batch_size, vocab_size), where each row i of the matrix is the array of probabilities of the output being each word in the vocabulary. I just wasn't sure how to interpret the inputs and outputs here.
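If it helps, here is my understanding of the shape contrast between the two styles, as a toy walk-through in plain Python (the sizes are hypothetical and the tuples only illustrate shapes; this is not this repo's code):

```python
# Hypothetical toy sizes, just to illustrate the shapes involved.
batch_size, seq_len, vocab_size = 4, 10, 100

# RNN-style decoding loop (as in the PyTorch chatbot tutorial):
# each step consumes the i-th token of every sequence in the batch.
rnn_step_input = (1, batch_size)            # one time step of token ids
rnn_step_output = (batch_size, vocab_size)  # per-word probabilities for that step

# Transformer (Attention Is All You Need):
# the whole padded sequence is fed at once; attention replaces the recurrence.
transformer_input = (batch_size, seq_len)               # token ids
transformer_output = (batch_size, seq_len, vocab_size)  # logits for every position

print(rnn_step_output, transformer_output)
```

So instead of a Python loop over time steps, the transformer produces predictions for all positions of all sequences in one forward pass.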
The drawback of the other model is that it doesn't implement the attention mechanism from the Attention Is All You Need paper. The attention mechanism it uses in the decoder seems different, and it doesn't use attention in the encoder at all. I was hoping to get some insight from you, if you don't mind.
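For reference, the paper's attention, Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, can be sketched in plain, unbatched Python (a single-head toy version for intuition; the repo's actual implementation uses batched tensor ops and masking):

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def scaled_dot_product_attention(Q, K, V):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, one head, no batch."""
    d_k = len(K[0])
    # scores[i][j] = dot(q_i, k_j) / sqrt(d_k)
    scores = [[sum(q * k for q, k in zip(qi, kj)) / math.sqrt(d_k) for kj in K]
              for qi in Q]
    weights = [softmax(row) for row in scores]
    # output_i = sum_j weights[i][j] * v_j (convex combination of the values)
    return [[sum(w * vj[c] for w, vj in zip(wrow, V)) for c in range(len(V[0]))]
            for wrow in weights]

# One query that matches the first of two keys:
Q = [[1.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0], [20.0]]
out = scaled_dot_product_attention(Q, K, V)
# out[0][0] lies between 10 and 20, pulled toward V[0] because Q matches K[0].
```

The key point versus the RNN-style decoder attention is that every position attends to every other position in a single matrix operation, in both the encoder and the decoder.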
Hi, sorry to bother you again. I was playing around with your code, and when I try to initialize an instance of the transformer, I get the error in the title. Below is how I called your transformer, followed by the error.
Error:
Do you have any idea what can be causing this? Thanks.