Open KeremTurgutlu opened 1 year ago
I used the flaxformer library to create the Flax module, but now I am getting another error.
# Model hyperparameters. Most of these are handed to encoder_decoder(),
# either positionally by the caller or as its keyword defaults.
SCALE = 1.0  # not referenced by any of the code shown in this snippet
HEAD_DIM = 256  # dimension of each attention head
NUM_ENCODER_LAYERS = 32  # number of encoder layers to create
NUM_DECODER_LAYERS = 32  # number of decoder layers to create
EMBED_DIM = 4096  # size of the embedding for the stack
MLP_DIM = 16384  # dimension of the multilayer perceptron
NUM_HEADS = 16  # number of attention heads
VOCAB_SIZE = 32128  # size of the embedding vocabulary
DROPOUT_RATE = 0.0  # 0.0 turns dropout off
ACTIVATIONS = ('swish', 'linear')  # activations used for the MLP
def encoder_decoder(embedding_dim,
                    mlp_dim,
                    num_heads,
                    num_encoder_layers,
                    num_decoder_layers,
                    head_dim=HEAD_DIM,
                    vocabulary_size=VOCAB_SIZE,
                    dropout_rate=DROPOUT_RATE,
                    activations=ACTIVATIONS,
                    dtype=jnp.bfloat16):
    """Create a T5-1.1 style encoder-decoder stack.

    Args:
      embedding_dim: The size of the embedding for this stack.
      mlp_dim: The dimension of the multilayer perceptron.
      num_heads: The number of attention heads.
      num_encoder_layers: The number of encoder layers to create.
      num_decoder_layers: The number of decoder layers to create.
      head_dim: The dimension of the attention head.
      vocabulary_size: The size of the embedding vocabulary.
      dropout_rate: The dropout rate. Set to 0.0 to turn off dropout.
      activations: The activations to use for the MLP.
      dtype: The dtype for all layers in this encoder-decoder.

    Returns:
      A T5-style encoder-decoder.
    """
    # T5 1.1 has decoupled embeddings, so we create a separate output logits
    # factory.
    #
    # The kernel's logical axes are named explicitly. Without names, every
    # kernel dimension is labelled 'unmodeled', and the partitioner rejects
    # the parameter with "Dimensions ('unmodeled',) occur more than once in
    # array names" (reported for target/decoder/logits_dense/kernel).
    output_logits_factory = functools.partial(
        dense.DenseGeneral,
        use_bias=False,
        features=vocabulary_size,
        dtype='float32',
        kernel_init=t5_common_layers.MLP_KERNEL_INIT,
        bias_init=t5_common_layers.BIAS_INIT,
        kernel_axis_names=('embed', 'vocab'),
    )
    decoder_factory = functools.partial(
        t5_common_layers.decoder,
        num_heads=num_heads,
        head_dim=head_dim,
        mlp_dim=mlp_dim,
        num_layers=num_decoder_layers,
        dropout_rate=dropout_rate,
        activations=activations,
        output_logits_factory=output_logits_factory,
        dtype=dtype)
    encoder_factory = functools.partial(
        t5_common_layers.encoder,
        num_heads=num_heads,
        head_dim=head_dim,
        mlp_dim=mlp_dim,
        num_layers=num_encoder_layers,
        dropout_rate=dropout_rate,
        activations=activations,
        dtype=dtype)
    embedding_factory = functools.partial(
        t5_common_layers.embedding,
        vocabulary_size=vocabulary_size,
        embedding_dim=embedding_dim,
        dtype=dtype)
    return t5_architecture.EncoderDecoder(
        encoder_factory=encoder_factory,
        decoder_factory=decoder_factory,
        shared_token_embedder_factory=embedding_factory,
        dtype=dtype)  # pytype: disable=wrong-keyword-args
# Build the Flax module, wrap it in a t5x EncoderDecoderModel, and hand it to
# InteractiveModel together with the partitioner and checkpoint settings.
module = encoder_decoder(
    EMBED_DIM,
    MLP_DIM,
    NUM_HEADS,
    NUM_ENCODER_LAYERS,
    NUM_DECODER_LAYERS,
)
input_vocabulary = t5.data.get_default_vocabulary()
output_vocabulary = t5.data.get_default_vocabulary()

# NOTE(review): this optimizer is rebound further down before it is used, so
# only the second Adafactor instance reaches the model. Confirm which set of
# arguments is intended.
optimizer = t5x.adafactor.Adafactor(
    decay_rate=0.8,
    step_offset=0,
    logical_factor_rules=t5x.adafactor.standard_logical_factor_rules(),
)
decode_fn = functools.partial(
    t5x.decoding.temperature_sample,
    temperature=1.0,
    topk=40,
)

# Rules are (regex, flag) pairs; 'relpos_bias' is listed before the catch-all.
scale_rules = [('relpos_bias', False), ('.*', True)]
optimizer = t5x.adafactor.Adafactor(
    step_offset=0,
    multiply_by_parameter_scale=t5x.adafactor.HParamMap(scale_rules),
)

model = t5x.models.EncoderDecoderModel(
    module=module,
    input_vocabulary=input_vocabulary,
    output_vocabulary=output_vocabulary,
    optimizer_def=optimizer,
    decode_fn=decode_fn,
)

checkpoint_path = 'gs://scenic-bucket/ul2/ul220b/checkpoint_2650000'
dtype = 'bfloat16'
restore_mode = 'specific'
partitioner = t5x.partitioning.PjitPartitioner(
    num_partitions=8,
    model_parallel_submesh=None,
)
# NOTE(review): restore_ckpt_config is built here but is not passed to
# InteractiveModel below.
restore_ckpt_config = t5x.utils.RestoreCheckpointConfig(
    dtype=dtype,
    mode=restore_mode,
    path=checkpoint_path,
    use_gda=False,
)

# Randomly initialize model
batch_size = 8
task_feature_lengths = {'inputs': 512, 'targets': 512}
output_dir = '/tmp/output_dir'
# Every model feature is declared with the same [8, 512] shape array.
input_shapes = {
    feature_name: np.array([8, 512])
    for feature_name in (
        'encoder_input_tokens',
        'decoder_target_tokens',
        'decoder_input_tokens',
        'decoder_loss_weights',
    )
}
interactive_model = InteractiveModel(
    batch_size=batch_size,
    task_feature_lengths=task_feature_lengths,
    output_dir=output_dir,
    partitioner=partitioner,
    model=model,
    dtype=dtype,
    restore_mode=restore_mode,
    checkpoint_path=checkpoint_path,
    input_shapes=input_shapes,
)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File ~/t5x/t5x/partitioning.py:903, in PjitPartitioner.get_mesh_axes.<locals>._logical_to_mesh_axes(param_name, logical_axes)
902 try:
--> 903 return flax_partitioning.logical_to_mesh_axes(logical_axes,
904 self._logical_axis_rules)
905 except ValueError as e:
File ~/t5_venv/lib/python3.9/site-packages/flax/linen/spmd.py:172, in logical_to_mesh_axes(array_dim_names, rules)
140 """Compute layout for an array.
141
142 The rules are in order of precedence, and consist of pairs:
(...)
170 PartitionSpec for the parameter.
171 """
--> 172 result = _logical_to_mesh_axes(array_dim_names, rules)
173 if result is None:
File ~/t5_venv/lib/python3.9/site-packages/flax/linen/spmd.py:120, in _logical_to_mesh_axes(array_dim_names, rules)
119 if dups:
--> 120 raise ValueError(
121 f'Unsupported: Dimensions {dups} occur more than once in array names.')
122 if not isinstance(rules, (tuple, list)):
ValueError: Unsupported: Dimensions ('unmodeled',) occur more than once in array names.
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
Cell In[80], line 23
15 output_dir='/tmp/output_dir'
16 input_shapes = {
17 'encoder_input_tokens': np.array([8, 512]),
18 'decoder_target_tokens': np.array([8, 512]),
19 'decoder_input_tokens': np.array([8, 512]),
20 'decoder_loss_weights': np.array([8, 512])
21 }
---> 23 interactive_model = InteractiveModel(
24 batch_size=batch_size,
25 task_feature_lengths=task_feature_lengths,
26 output_dir=output_dir,
27 partitioner=partitioner,
28 model=model,
29 dtype=dtype,
30 restore_mode=restore_mode,
31 checkpoint_path=checkpoint_path,
32 input_shapes=input_shapes
33 )
File ~/t5x/t5x/interactive_model.py:191, in InteractiveModel.__init__(self, batch_size, task_feature_lengths, output_dir, partitioner, model, dtype, restore_mode, checkpoint_path, input_shapes, input_types, init_random_seed, add_eos, eval_names)
188 self._restore_checkpoint_cfg = None
189 self._save_checkpoint_cfg = utils.SaveCheckpointConfig(
190 dtype=dtype, keep=5, save_dataset=False, use_gda=False, period=1000)
--> 191 self._train_state_initializer = utils.TrainStateInitializer(
192 optimizer_def=self._model.optimizer_def,
193 init_fn=self._model.get_initial_variables,
194 input_shapes=self._input_shapes,
195 input_types=self._input_types,
196 partitioner=self._partitioner)
198 # Initialize checkpoint manager.
199 self._checkpoint_manager = utils.LegacyCheckpointManager(
200 save_cfg=self._save_checkpoint_cfg,
201 restore_cfg=self._restore_checkpoint_cfg,
(...)
206 model_dir=self._output_dir,
207 use_gda=False)
File ~/t5x/t5x/utils.py:1152, in TrainStateInitializer.__init__(self, optimizer_def, init_fn, input_shapes, partitioner, input_types)
1148 self._partitioner = partitioner
1149 self.global_train_state_shape = jax.eval_shape(
1150 initialize_train_state, rng=jax.random.PRNGKey(0)
1151 )
-> 1152 self.train_state_axes = partitioner.get_mesh_axes(
1153 self.global_train_state_shape
1154 )
1155 self._initialize_train_state = initialize_train_state
1157 # Currently scanned layers require passing annotations through to the
1158 # point of the scan transformation to resolve an XLA SPMD issue.
1159
1160 # init_fn is always(?) equal to model.get_initial_variables, fetch the model
1161 # instance from the bound method.
File ~/t5x/t5x/partitioning.py:910, in PjitPartitioner.get_mesh_axes(self, train_state)
906 raise ValueError(f'Failed to map logical axes for {param_name}') from e
908 flat_logical_axes = traverse_util.flatten_dict(
909 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
--> 910 flat_mesh_axes = {
911 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
912 }
914 return logical_axes.restore_state(
915 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))
File ~/t5x/t5x/partitioning.py:911, in <dictcomp>(.0)
906 raise ValueError(f'Failed to map logical axes for {param_name}') from e
908 flat_logical_axes = traverse_util.flatten_dict(
909 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
910 flat_mesh_axes = {
--> 911 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
912 }
914 return logical_axes.restore_state(
915 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))
File ~/t5x/t5x/partitioning.py:906, in PjitPartitioner.get_mesh_axes.<locals>._logical_to_mesh_axes(param_name, logical_axes)
903 return flax_partitioning.logical_to_mesh_axes(logical_axes,
904 self._logical_axis_rules)
905 except ValueError as e:
--> 906 raise ValueError(f'Failed to map logical axes for {param_name}') from e
ValueError: Failed to map logical axes for target/decoder/logits_dense/kernel
Okay, after some trial and error I think I am able to load the UL2 checkpoints as a t5x InteractiveModel
and run inference with it. I used the utilities from the flaxformer package and only modified the decoder so that it does not use a scaled layer norm, which is the decoder_norm
seen in the first error message.
def decoder(num_heads,
            head_dim,
            mlp_dim,
            num_layers,
            shared_token_embedder,
            dropout_rate,
            activations,
            output_logits_factory=None,
            dtype=jnp.bfloat16):
    """Assemble a T5-style decoder from per-component factories.

    The layer-norm factory is not built here; the module-level
    `layer_norm_factory` is passed through as-is.

    Args:
      num_heads: Forwarded to the decoder layers and the relative position
        bias.
      head_dim: Forwarded to the decoder layers.
      mlp_dim: Forwarded to the decoder layers.
      num_layers: Number of layers given to `t5_architecture.Decoder`.
      shared_token_embedder: Token embedder given to the decoder.
      dropout_rate: Rate used by the decoder layers and the dropout factory.
      activations: Forwarded to the decoder layers.
      output_logits_factory: Optional factory given to the decoder unchanged.
      dtype: dtype forwarded to the layers, the position bias and the decoder.

    Returns:
      The constructed `t5_architecture.Decoder`.
    """
    make_layer = functools.partial(
        decoder_layer,
        num_heads=num_heads,
        head_dim=head_dim,
        mlp_dim=mlp_dim,
        activations=activations,
        dropout_rate=dropout_rate,
        dtype=dtype,
    )
    make_relpos_bias = functools.partial(
        relative_position_bias,
        num_heads=num_heads,
        dtype=dtype,
    )
    make_dropout = functools.partial(dropout, rate=dropout_rate)

    return t5_architecture.Decoder(
        layer_factory=make_layer,
        dropout_factory=make_dropout,
        layer_norm_factory=layer_norm_factory,
        num_layers=num_layers,
        shared_token_embedder=shared_token_embedder,
        shared_relative_position_bias_factory=make_relpos_bias,
        output_logits_factory=output_logits_factory,
        dtype=dtype)  # pytype: disable=wrong-keyword-args
def encoder_decoder(embedding_dim,
                    mlp_dim,
                    num_heads,
                    num_encoder_layers,
                    num_decoder_layers,
                    head_dim,
                    vocabulary_size=VOCAB_SIZE,
                    dropout_rate=DROPOUT_RATE,
                    activations=ACTIVATIONS,
                    dtype=jnp.bfloat16):
    """Build a T5-1.0 style encoder-decoder stack.

    No separate output-logits factory is handed to the decoder here.

    Args:
      embedding_dim: Embedding size for the stack.
      mlp_dim: Dimension of the multilayer perceptron.
      num_heads: Number of attention heads.
      num_encoder_layers: How many encoder layers to create.
      num_decoder_layers: How many decoder layers to create.
      head_dim: Dimension of each attention head.
      vocabulary_size: Size of the embedding vocabulary.
      dropout_rate: Dropout rate; 0.0 turns dropout off.
      activations: Activations to use for the MLP.
      dtype: dtype for all layers in this encoder-decoder.

    Returns:
      A T5-style encoder-decoder.
    """
    # Keyword arguments that the encoder and decoder factories have in common.
    shared_stack_kwargs = dict(
        num_heads=num_heads,
        head_dim=head_dim,
        mlp_dim=mlp_dim,
        dropout_rate=dropout_rate,
        activations=activations,
        dtype=dtype,
    )

    # NOTE(review): this binds t5_common_layers.decoder; confirm that is where
    # the modified `decoder` is meant to live.
    make_decoder = functools.partial(
        t5_common_layers.decoder,
        num_layers=num_decoder_layers,
        **shared_stack_kwargs)
    make_encoder = functools.partial(
        t5_common_layers.encoder,
        num_layers=num_encoder_layers,
        **shared_stack_kwargs)
    make_embedding = functools.partial(
        t5_common_layers.embedding,
        vocabulary_size=vocabulary_size,
        embedding_dim=embedding_dim,
        dtype=dtype)

    return t5_architecture.EncoderDecoder(
        encoder_factory=make_encoder,
        decoder_factory=make_decoder,
        shared_token_embedder_factory=make_embedding,
        dtype=dtype)  # pytype: disable=wrong-keyword-args
def layer_norm_factory(*args, **kwargs):
    """Return a `layer_norm.T5LayerNorm` built with `use_scale=False`.

    All positional and keyword arguments are forwarded unchanged; the caller
    must not supply `use_scale` itself.
    """
    return layer_norm.T5LayerNorm(*args, use_scale=False, **kwargs)
# Model hyperparameters for the encoder-decoder stack; VOCAB_SIZE,
# DROPOUT_RATE and ACTIVATIONS also serve as encoder_decoder() defaults.
SCALE = 1.0  # not referenced by any of the code shown in this snippet
HEAD_DIM = 256  # dimension of each attention head
NUM_ENCODER_LAYERS = 32  # number of encoder layers to create
NUM_DECODER_LAYERS = 32  # number of decoder layers to create
EMBED_DIM = 4096  # size of the embedding for the stack
MLP_DIM = 16384  # dimension of the multilayer perceptron
NUM_HEADS = 16  # number of attention heads
VOCAB_SIZE = 32128  # size of the embedding vocabulary
DROPOUT_RATE = 0.0  # 0.0 turns dropout off
ACTIVATIONS = ('swish', 'linear')  # activations used for the MLP
# Placeholder: replace with the real UL2 checkpoint location before running.
checkpoint_path = '<UL2 checkpoint path>'
batch_size = 8
task_feature_lengths = {'inputs': 512, 'targets': 512}
output_dir = '/tmp/output_dir'
# Shape arrays for each model feature, all [8, 512].
input_shapes = {
    'encoder_input_tokens': np.array([8, 512]),
    'decoder_target_tokens': np.array([8, 512]),
    'decoder_input_tokens': np.array([8, 512]),
    'decoder_loss_weights': np.array([8, 512]),
}
# partitioner, model, dtype and restore_mode are expected to be defined
# earlier in the session.
interactive_model = InteractiveModel(
    batch_size=batch_size,
    task_feature_lengths=task_feature_lengths,
    output_dir=output_dir,
    partitioner=partitioner,
    model=model,
    dtype=dtype,
    restore_mode=restore_mode,
    checkpoint_path=checkpoint_path,
    input_shapes=input_shapes,
)
Some examples (temp=0.7, topk=40):
`----` is used to separate the input from the model-generated output.
1)[NLU]I am having a great day because I am finally able to load this amazing large language model ---- <extra_id_0> into a neural network. I have been working on this for days, and it really is such a great feeling when it works. The model is made up of a million tokens and is about 6 megabytes. This is the model that was made for Google’s Project Adam. <extra_id_10> this model into a neural network, I had to do a little bit of tweaking
2)[NLG]I am having a great day because I am finally able to load this amazing large language model ---- <extra_id_0> and my ML is able to recognize and recognize a lot of words now. I can’t wait to share this great news with you guys. So here are the instructions to install language models on TensorFlow. I know it is a bit difficult to follow. I tried to explain it as simple as possible. It took me a long time to get this one. So I hope it helps. TLDR:
3)[S2S]I am having a great day because I am finally able to load this amazing large language model ---- which has been kindly provided by
4)[NLU]Why did chicken cross the road? ---- <extra_id_0> Chicken got hit by a car! -.- Why did the chicken cross the road? Because he was a stupid chicken! -.- Why did the chicken cross the road? Because he was a stupid chicken! -.- <extra_id_10> the chicken cross the road? Because he was
5)[NLG]Why did chicken cross the road? ---- <extra_id_0>. Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road? . . . . Why did the chicken cross the road? . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the <extra_id_7> . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . . . Why did the chicken cross the road?. . .
6)[S2S]Why did chicken cross the road? ---- .. Why did chicken cross the road?.. Why did chicken cross the road?.. ..because he was going to the other side of the road to get a good burger. Why did chicken cross the road?.. ..because he was going to the other side of the road to get a good burger. Why did chicken cross the road?.. Is this the same chicken that is in the bathroom stall at the hospital? Why did chicken cross the road?.. Why did chicken cross the road?.. It was going to the other side of the road to get a good burger. Why did chicken cross the road?.. ..because he was going to the other side of the road to get a good burger. Why did chicken cross the road?.. Why did chicken cross the road?.. Why did chicken cross the road?.. Why did chicken cross the road?.. Why did chicken cross the road?.. Why did chicken cross the road?.. Why did chicken cross the road?.. ..because he was going to the other side of the road to get a good burger. Why did chicken cross the road?.. ..because he was going to the other side of the road to get a <extra_id_7>. Why did chicken cross the road?.. Why did chicken cross the road?.. Why did chicken cross the road?.. ..because he was going to the other side of the road to get a good burger. Why did chicken cross the road?.. ..because he was <extra_id_56> to
7)[NLU]What is 2 x 10 ? Think step by step. ---- <extra_id_0> What is 2 x 10 ? Think step by step. What is 2 x 10 ? What is 2 x 10 ? What is 2 x 10 ? Think step by step. What is 2 x 10 ? Think step by step. What <extra_id_8> 6 <extra_id_9> What is 2 x 10 ? <extra_id_10> 2 x 10 ? Think step by step. What is 2 x 10 ? Think step by step. What is
8)[NLG]What is 2 x 10 ? Think step by step. ---- <extra_id_0>. Add. What is 2 x 10 ? Think step by step. Add. Multiply. Find the difference. What is 2 x 10 ? Think step by step. Add. Multiply. What is 2 x 10 ? Think step by step. Add. Divide. Find the difference. What is 2 x 10 ? Think step by step. Add. Multiply. Add the numbers together. What is 2 x 10 ? Think step by step. Add. Multiply. How many are there altogether? What is 2 x 10 ? Think step by step. Add. Multiply. What is 2 x 10 ? Think step by step. Add. Multiply. Add the numbers together. What is 2 x 10 ? Think step by step. Add. Divide. What is 2 x 10 ? Think step by step. Add. Multiply. Find the difference. What is 2 x 10 ? Think step by step. Add. Multiply. What is 2 x 10 ? Think step by step. Add. Multiply. Count the numbers in the pair and write the sum . What is 2 x 10 ? Think step by step. Add. Multiply. Find the difference. What is 2 x 10 ? Think step by step. Add. Multiply. Find the difference. How many are there altogether? What
9)[S2S]What is 2 x 10 ? Think step by step. ---- First, think 10 times 2 = 20. Then, think 2 times 10 is <extra_id_2> Then, 2 times
I am not able to figure out how to set up the InteractiveModel config to load
checkpoint_path='gs://scenic-bucket/ul2/ul220b/checkpoint_2650000'
from here. I followed the gin config file as best I could,
but I am getting an error about
decoder_norm,
probably because of the last two lines of the following config (UL2 config):