Open vivpra89 opened 1 year ago
I am doing the same. Any leads will be appreciated.
@vivpra89 is there a specific reason you don't want to use a single transformer architecture with a single sequence? Is the reason accuracy? With a single sequence, the position embeddings would help the network learn which interactions are the most recent, or you could use time elapsed as side information: you can calculate the time elapsed between each interaction, or since the last user interaction.
If you really want to have two sequences, yes, you can have two transformer blocks and concatenate them without applying masking, which means the target (the last item in the user interaction sequence) should be in a different column and not in the input sequence.
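For the time-elapsed side information mentioned above, here is a rough sketch of one way to derive it per user; the column names and the pandas-based approach are assumptions about how the interaction log might look, not part of any library:

```python
import pandas as pd

# Hypothetical interaction log with user id, item id and timestamp columns
df = pd.DataFrame({
    "user_id": [1, 1, 1, 2, 2],
    "item_id": [10, 11, 12, 10, 13],
    "timestamp": pd.to_datetime([
        "2023-01-01 10:00", "2023-01-01 10:05", "2023-01-02 09:00",
        "2023-01-03 08:00", "2023-01-03 08:30",
    ]),
})
df = df.sort_values(["user_id", "timestamp"])

# Seconds elapsed since the user's previous interaction (0 for the first one),
# usable as a continuous side feature alongside the item-id sequence
df["secs_since_last"] = (
    df.groupby("user_id")["timestamp"].diff().dt.total_seconds().fillna(0)
)
```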
@rnyak In the two-sequences approach, we create two separate sequences: one for long-term behavior and another for short-term behavior. The long-term sequence captures the user's historical preferences and patterns, while the short-term sequence represents recent interactions or activities. Each sequence is trained individually using a transformer architecture, which allows the model to capture complex dependencies and patterns within the data. Once both models are trained, we can combine them using a fusion method. This approach can be effective because it considers both long-term and short-term preferences.
What is your opinion on this architecture?
Hi @NamartaVij and @vivpra89. Transformers4Rec provides some masking options for training sequential models: Causal Language Modeling and Masked Language Modeling, which you set in `TabularSequenceFeatures`, like here. Those masking approaches extract the labels from the sequence.
For that reason, in Transformers4Rec you can only have one `TabularSequenceFeatures` block with masking set. For example, you could have two `TabularSequenceFeatures` (one for each sequence), set `masking` only for the short-sequence `TabularSequenceFeatures` (from which the targets would be extracted), and then concatenate it with the long-sequence `TabularSequenceFeatures`, from which the model would have access to all long-sequence positions.
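A minimal sketch of that layout, assuming the long and short views are built from the same `schema` and concatenated along the time axis; the wrapper module below is my own illustration, not an official Transformers4Rec block, and the transformer body plus `NextItemPredictionTask` would still need to be wired to the masking of the short inputs:

```python
import torch
import transformers4rec.torch as tr

# Assumption: in practice the long and short views would use different columns
# (or differently truncated sequences) rather than the same schema selection.
long_inputs = tr.TabularSequenceFeatures.from_schema(
    schema, max_sequence_length=50, aggregation="concat", d_output=64,
    # no masking here: the long sequence only provides context
)
short_inputs = tr.TabularSequenceFeatures.from_schema(
    schema, max_sequence_length=20, aggregation="concat", d_output=64,
    masking="causal",  # targets are extracted from the short sequence only
)

class ConcatLongShort(torch.nn.Module):
    """Concatenates long- and short-sequence interaction embeddings along the
    time axis, so a single transformer can attend over positions of both."""

    def __init__(self, long_inputs, short_inputs):
        super().__init__()
        self.long_inputs = long_inputs
        self.short_inputs = short_inputs

    def forward(self, batch):
        long_emb = self.long_inputs(batch)    # [batch, long_len, 64]
        short_emb = self.short_inputs(batch)  # [batch, short_len, 64]
        return torch.cat([long_emb, short_emb], dim=1)
```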
@NamartaVij You proposed training two models separately, and that definitely is possible. Maybe you could train one of those models first and share the embedding weights with the second model, using the `TensorInitializer`. If you use weight tying (`NextItemPredictionTask(weight_tying=True)`), then the output of the model will be in the same vector space as the item embeddings. So you could maybe use that property to combine the session embedding output from both the short- and long-term models to look for similar items. Does that make sense?
What do you mean by "using a fusion method"?
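To make the "combine the session embeddings and look for similar items" idea concrete, here is a small PyTorch sketch; the tensors are random stand-ins for the outputs of two weight-tied models and their shared item-embedding table, and the convex combination is just one possible fusion, not something prescribed by the library:

```python
import torch

n_items, dim, batch = 50_000, 64, 32
item_embeddings = torch.randn(n_items, dim)    # tied item-embedding table (assumed shared)
long_term_session = torch.randn(batch, dim)    # session output of the long-term model
short_term_session = torch.randn(batch, dim)   # session output of the short-term model

# One simple fusion: a convex combination of the two session vectors
alpha = 0.7
fused = alpha * short_term_session + (1 - alpha) * long_term_session

# Because of weight tying, similar items can be retrieved with a dot product
scores = fused @ item_embeddings.T             # [batch, n_items]
top_scores, top_items = torch.topk(scores, k=20, dim=-1)
```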
@gabrielspmoreira @NamartaVij
What I have in mind is to train both transformers together, like a two-tower model. I am referencing this example to concatenate before the prediction task. Please take a look at the code below and share your thoughts.
```python
# Long-term (e.g. purchase history) tower inputs
long_seq_inputs = mm.InputBlockV2(
    schema_model.select_by_tag(Tags.SEQUENCE),
    categorical=mm.Embeddings(
        schema_model.select_by_tag(Tags.SEQUENCE), sequence_combiner=None, dim=manual_dims
    ),
)
# Short-term (e.g. recent activity) tower inputs
short_term_inputs = mm.InputBlockV2(
    schema_model.select_by_tag(Tags.SEQUENCE),
    categorical=mm.Embeddings(
        schema_model.select_by_tag(Tags.SEQUENCE), sequence_combiner=None, dim=manual_dims
    ),
)

mlp_block1 = mm.MLPBlock([128, dmodel], activation="relu", no_activation_last_layer=True, dropout=DROPOUT)
mlp_block2 = mm.MLPBlock([128, dmodel], activation="relu", no_activation_last_layer=True, dropout=DROPOUT)

# One XLNet encoder per tower, mean-pooled over the sequence
lt_dense_block = mm.SequentialBlock(
    long_seq_inputs, mlp_block1,
    mm.XLNetBlock(d_model=dmodel, n_head=4, n_layer=2, post="sequence_mean"),
)
st_dense_block = mm.SequentialBlock(
    short_term_inputs, mlp_block2,
    mm.XLNetBlock(d_model=dmodel, n_head=4, n_layer=2, post="sequence_mean"),
)

# Concatenate the two tower outputs
concats = mm.ParallelBlock(
    {"dense_block": lt_dense_block, "cat_inputs": st_dense_block}, aggregation="concat"
)

# Projection on top of the concatenated towers
mlp_block3 = mm.MLPBlock([128, dmodel], activation="relu", no_activation_last_layer=True, dropout=DROPOUT)

# Weight tying: reuse the item-id embedding table from the long-term input block
prediction_task = mm.CategoricalOutput(
    to_call=long_seq_inputs["categorical"][item_id_name],
    logits_temperature=TEMPERATURE_SCALING,
    target="purchase_id_first",
)

optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model_transformer = mm.Model(concats, mlp_block3, prediction_task)
model_transformer.compile(
    optimizer=optimizer,
    run_eagerly=False,
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=LABEL_SMOOTHING),
    metrics=mm.TopKMetricsAggregator.default_metrics(top_ks=[100]),
)
```
@gabrielspmoreira If we assume the long-term sequences are purchases and the short-term ones are add-to-carts (ATCs), do you think this works with causal masking for both transformers?
Also, I have a couple of questions:
1. What's the difference between using merlin-models and Transformers4Rec?
2. How do we modify the head part (number of layers and nodes in the MLP, etc.)?
```python
schema: tr.Schema = tr.data.tabular_sequence_testing_data.schema
max_sequence_length, d_model = 20, 64

# Long-term and short-term input modules
input_module_lt = tr.TabularSequenceFeatures.from_schema(
    schema, max_sequence_length=max_sequence_length, continuous_projection=d_model,
    aggregation="concat", masking="causal", d_output=200,
)
input_module_st = tr.TabularSequenceFeatures.from_schema(
    schema, max_sequence_length=max_sequence_length, continuous_projection=d_model,
    aggregation="concat", masking="causal", d_output=200,
)

transformer_config_lt = tr.XLNetConfig.build(
    d_model=d_model, n_head=4, n_layer=2, total_seq_length=max_sequence_length
)
transformer_config_st = tr.XLNetConfig.build(
    d_model=d_model, n_head=4, n_layer=2, total_seq_length=max_sequence_length
)

# One transformer body per tower
lt_body = tr.SequentialBlock(
    input_module_lt, tr.MLPBlock([d_model]),
    tr.TransformerBlock(transformer_config_lt, masking=input_module_lt.masking),
)
st_body = tr.SequentialBlock(
    input_module_st, tr.MLPBlock([d_model]),
    tr.TransformerBlock(transformer_config_st, masking=input_module_st.masking),
)

metrics = [NDCGAt(top_ks=[20, 40], labels_onehot=True), RecallAt(top_ks=[20, 40], labels_onehot=True)]

# Concatenate the two tower outputs before the prediction head
body_concats = mm.ParallelBlock({"lt_body": lt_body, "st_body": st_body}, aggregation="concat")

head = tr.Head(
    body_concats,
    tr.NextItemPredictionTask(weight_tying=True, metrics=metrics),
    inputs=input_module_st,  # input module the targets/masking come from (assumed: the short-term one)
)
model = tr.Model(head)
```
), then the output of the model will be in the same vector space as the item embeddings. So you could maybe use that property to combine the session embedding output from both short and long term models to look for similar items. Does that make sense? What do you mean by "using a fusion method"? ) @gabrielspmoreira , Thankyou for your response. I got it about TabularSequenceFeature for one sequence. However, Could you please explain the second approach which you suggested regarding sharing the embedding weights with the second model.
@gabrielspmoreira Can you please help me with a snippet on how to concatenate `TabularSequenceFeatures`, or how to use `TensorInitializer` to share embedding weights? I couldn't find them in the examples.
@gabrielspmoreira
```python
# Define input block
sequence_length, d_model = 20, 64
pretrained_dim = 256
np_emb_item_id = np.random.rand(27929, pretrained_dim)

# Attach pretrained item embeddings to the data loader
embeddings_op = EmbeddingOperator(
    np_emb_item_id, lookup_key="item-list", embedding_name="pretrained_item_embeddings"
)
data_loader = MerlinDataLoader.from_schema(
    schema, train, batch_size=256, max_sequence_length=sequence_length,
    transforms=[embeddings_op], shuffle=False,
)
model_schema = data_loader.output_schema

inputs = tr.TabularSequenceFeatures.from_schema(
    model_schema, max_sequence_length=sequence_length, continuous_projection=64,
    aggregation="concat", d_output=d_model, masking="mlm",
)
transformer_config = tr.XLNetConfig.build(
    d_model=d_model, n_head=8, n_layer=2, total_seq_length=sequence_length
)
body = tr.SequentialBlock(
    inputs, tr.MLPBlock([256]),
    tr.TransformerBlock(transformer_config, masking=inputs.masking),
)
prediction_task = tr.NextItemPredictionTask(
    weight_tying=True,
    metrics=[NDCGAt(top_ks=[2, 5, 10], labels_onehot=True),
             RecallAt(top_ks=[2, 5, 10], labels_onehot=True)],
)
head = tr.Head(body, prediction_task)
model_si = tr.Model(head)

training_args = T4RecTrainingArguments(
    output_dir="./tmp", max_sequence_length=10, data_loader_engine="merlin",
    num_train_epochs=10, dataloader_drop_last=False,
    per_device_train_batch_size=64, per_device_eval_batch_size=32,
    gradient_accumulation_steps=1, learning_rate=0.000666,
    report_to=[], logging_steps=200,
)
trainer = Trainer(model=model_si, args=training_args, schema=schema, compute_metrics=True)

OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./data/sessions_by_day_si/")
from transformers4rec.torch.utils.examples_utils import fit_and_evaluate
OT_results = fit_and_evaluate(trainer, start_time_index=1, end_time_index=2, input_dir=OUTPUT_DIR)
```
The above code fails with `RuntimeError: mat1 and mat2 shapes cannot be multiplied (640x704 and 960x64)`. Am I missing something here?
@gabrielspmoreira @rnyak How do we concatenate both outputs at the end to get the final prediction?
❓ Questions & Help
Details
I'm working on a project that requires me to produce a stable long-term user and item representation and also use short-term user behavior for next-action prediction. Is it possible to create a custom architecture that trains two transformer towers together with different inputs and then concatenates them at a later point? What is the recommended architecture for problems like this?