Jyonn / ONCE

(WSDM 2024) Official implementation of the paper "ONCE: Boosting Content-based Recommendation with Both Open- and Closed-source Large Language Models"
https://arxiv.org/abs/2305.06566

Train GENRE with frozen content-encoder #6

Open edervishaj opened 3 months ago

edervishaj commented 3 months ago

Hi,

I am trying to run experiments with only the GENRE framework using a frozen PLM content encoder. Could you point me to any of the existing configurations or give some suggestions on how to set the configuration parameters?

Thank you!

Jyonn commented 3 months ago

Thanks for your attention to our work!

To run the GENRE framework with a frozen PLM, we can prepare the following configurations:

GENRE-generated data

Store the following configuration in config/data/mind-cs.yaml:

name: MIND-${version}-Content-Summarizer
base_dir: data/MIND-${version}
item:
  filter_cache: true
  depot: ${data.base_dir}/news
  order:
    - summarizer-bert  # summarized title with GPT
    - cat-bert
  append:
    - nid
  lm_col: summarizer-bert
user:
  filter_cache: true
  depots:
    train:
      path: ${data.base_dir}/train
    dev:
      path: ${data.base_dir}/dev
    test:
      path: ${data.base_dir}/test
  filters:
    history:
      - x
  union:
    - ${data.base_dir}/user
  candidate_col: nid
  clicks_col: history
  label_col: click
  neg_col: neg
  group_col: imp
  user_col: uid
  index_col: index

Model Configuration

Please refer to config/model/llm/bert-naml.yaml for the model configuration.

Embed Configuration

Please refer to config/embed/bert-token.yaml for the embedding configuration. Please prepare the BERT word embeddings in NumPy format (30522 x 768) yourself and store them in data/bert-token.12L.npy.
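
In case it helps, here is a minimal sketch of how such an embedding matrix can be exported with the Hugging Face transformers library. This is one way to produce the file, not necessarily the exact procedure used by the authors; the target file name follows the path mentioned above, so please double-check it against config/embed/bert-token.yaml.

import numpy as np
from transformers import BertModel

# take the input word-embedding table of bert-base-uncased
# (vocab size 30522, hidden size 768)
model = BertModel.from_pretrained('bert-base-uncased')
word_embeddings = model.embeddings.word_embeddings.weight.detach().cpu().numpy()

assert word_embeddings.shape == (30522, 768)
np.save('data/bert-token.12L.npy', word_embeddings)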

Experiment Configuration

Please refer to config/exp/tt-llm.yaml for the experiment configuration, and replace freeze_emb: false with freeze_emb: true.

Additional Parameters

--lora 0: freeze the PLM; no LoRA is needed
--item_lr 0: set the learning rate of the PLM to 0

Running

Based on the above configurations, you can run your experiments with:

python worker.py --data config/data/mind-cs.yaml --version small --model config/model/llm/bert-naml.yaml --embed config/embed/bert-token.yaml --exp config/exp/tt-llm.yaml --lora 0 --item_lr 0 --lr 0.001 --batch_size 32 --embed_hidden_size 768 --page_size 64

edervishaj commented 3 months ago

Thank you for your quick reply and for providing the configurations!

Does this setup also run with the user profiler (besides the content summarizer)? Is there some way of controlling these components?

Jyonn commented 3 months ago

The setups for the user profiler and the personalized content generator are a bit more complex, as we only released the GPT-generated raw data and did not publicly release the processed data.

User Profiler

In the data configuration, you need to add the following two attributes, plugin and plugin_cols (user_col and index_col are shown for context and are already in the base configuration):

  user_col: uid
  index_col: index
  plugin: ${data.base_dir}/user-plugin
  plugin_cols:
    - topic

The related code is model/common/user_plugin.py.

Basically, you need to build a user_plugin data folder with the UniTok tool. Here is a script that may help you convert the raw data into that format:

import json

import pandas as pd
from UniTok import UniDep, UniTok, Column
from UniTok.tok import SeqTok

# GPT-generated user profiles (topics / regions) and the existing user depot
topics_path = 'data/mind/topics_v3.json'
region_path = 'data/mind/regions_v3.json'
user_path = 'data/mind/user'

user_depot = UniDep(user_path)
users = user_depot.sample_size  # number of users in the existing depot

topics_dict = json.load(open(topics_path, 'r'))
regions_dict = json.load(open(region_path, 'r'))

# align the generated profiles with the user ids of the depot
topics = [[] for _ in range(users)]
region = [[] for _ in range(users)]

for uid in topics_dict:
    topics[int(uid)] = topics_dict[uid]

for uid in regions_dict:
    region[int(uid)] = regions_dict[uid]

df = pd.DataFrame(data=dict(uid=list(range(users)), topic=topics, region=region))

# tokenize the topic/region sequences and store them as the user-plugin depot
ut = UniTok().add_index_col('uid').add_col(Column(
    name='topic',
    tok=SeqTok(name='topic'),
    max_length=10,
)).add_col(Column(
    name='region',
    tok=SeqTok(name='region'),
    max_length=3,
)).read(df).tokenize().store('data/mind/user-plugin')
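
The folder produced by store(...) is what the plugin attribute in the data configuration above should point to, with the generated columns listed under plugin_cols.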

Personalized Content Generator

It is the most complex one. Because the item vocabulary grows, all the data (news, user, train, valid, and test) needs to be rebuilt. A rough sketch of these steps is given after the list below.

First, assign new news ids to all the newly generated news and form a new item vocabulary.

Second, modify the metadata (item vocab size) and tok.nid.dat (item vocabulary) in every data folder.

Third, tokenize the new data and append it to news/data.npy.

Fourth, append the new news ids to the user histories in user/data.npy.

Note: Please back up all the data before doing these operations, so that the previous configurations can still be reproduced.

Fifth, rename the entire MIND-small folder to MIND-small-CG.

Sixth, change the running script from --version small to --version small-CG.
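
For orientation, here is a very rough sketch of how steps one to four line up, written against the depots referenced above. It assumes that each depot's data.npy stores a pickled dict mapping column names to arrays; that is an assumption about the UniTok on-disk format rather than something stated here, so please verify it against your own data (and the UniDep source) before running anything like this.

import numpy as np

base = 'data/MIND-small'  # work on a backup copy of the whole folder

# Assumption: data.npy is a pickled dict of column name -> array (verify first).
news = np.load(f'{base}/news/data.npy', allow_pickle=True).item()
user = np.load(f'{base}/user/data.npy', allow_pickle=True).item()

# Step 1: ids of the generated news continue the existing id space,
# i.e. num_old_items, num_old_items + 1, ...
num_old_items = len(news['nid'])

# Step 2: extend tok.nid.dat with the new ids and update the item vocab size
# recorded in the metadata of every data folder (news, user, train, dev, test).

# Step 3: tokenize the generated titles/categories with the original vocabularies
# (e.g. with UniTok, as in the user-plugin script above) and append the new rows
# to every column of `news`.

# Step 4: append the new ids to each user's history column in `user`.

# Then save both dicts back, rename MIND-small to MIND-small-CG, and pass
# --version small-CG to worker.py.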

edervishaj commented 2 months ago

Hi, thank you for the detailed reply. Sorry for my very long reply!

Unfortunately, I was not able to reproduce the results of GENRE for a combination of CS and UP (without CG) with a frozen BERT content encoder. I am using the tokenized data that you have shared.

I am using the following configuration files:

With the above configuration files, I am getting the following results:

[19:00:24] |Worker| Recall@5: 0.3445
[19:00:24] |Worker| Recall@10: 0.5259
[19:00:24] |Worker| Recall@20: 0.7116
[19:00:24] |Worker| NDCG@5: 0.2265
[19:00:24] |Worker| NDCG@10: 0.2894
[19:00:24] |Worker| NDCG@20: 0.3415
[19:00:24] |Worker| MRR: 0.1680

In exp.yaml I am tuning for NDCG@10 instead of AUC. However, I was not expecting the results to deviate this much from the ones reported in the paper (especially MRR). Do you think I am missing something in the configuration files?

Moreover, I tried to reproduce the NAML results reported in the paper. I changed the configurations from above as follows:

These configurations gave the following results (which are better than the ones reported in the paper):

[00:53:47] |Worker| Recall@5: 0.4927
[00:53:47] |Worker| Recall@10: 0.6694
[00:53:47] |Worker| Recall@20: 0.8212
[00:53:47] |Worker| NDCG@5: 0.3485
[00:53:47] |Worker| NDCG@10: 0.4107
[00:53:47] |Worker| NDCG@20: 0.4544
[00:53:47] |Worker| MRR: 0.3351

Thank you and sorry again for the long reply.

Jyonn commented 2 months ago

Hi, can you provide the entire configuration that is printed at the beginning of the running output? That would make it easier to figure out what is going on. One thing we noticed before is that, in the experiments reported in the paper, we fixed the hidden size to 64 whenever the model is trained without an LLM. When the hidden size is increased to 256, the performance grows.

edervishaj commented 2 months ago

Sure! Disregard any changes in the directories; I have moved the depots between runs.

BERT-NAML log:

[00:00:00] |Worker| {
    "data": {
        "name": "MIND-small",
        "base_dir": "ONCE/MIND-small/data",
        "item": {
            "filter_cache": true,
            "depot": "ONCE/MIND-small/data/news",
            "order": [
                "summarizer-bert",
                "cat-bert"
            ],
            "append": [
                "nid"
            ],
            "lm_col": "summarizer-bert"
        },
        "user": {
            "filter_cache": true,
            "depots": {
                "train": {
                    "path": "ONCE/MIND-small/data/train"
                },
                "dev": {
                    "path": "ONCE/MIND-small/data/dev"
                },
                "test": {
                    "path": "ONCE/MIND-small/data/test"
                }
            },
            "filters": {
                "history": [
                    "x"
                ]
            },
            "union": [
                "ONCE/MIND-small/data/user"
            ],
            "candidate_col": "nid",
            "clicks_col": "history",
            "label_col": "click",
            "neg_col": "neg",
            "group_col": "imp",
            "user_col": "uid",
            "index_col": "index",
            "plugin": "ONCE/MIND-small/data/user-plugin",
            "plugin_cols": [
                "topic",
                "region"
            ]
        }
    },
    "embed": {
        "name": "bert-token",
        "embeddings": [
            {
                "vocab_name": "bert",
                "vocab_type": "numpy",
                "path": "ONCE/MIND-small/data/bert.npy",
                "frozen": true
            }
        ]
    },
    "model": {
        "name": "BERT-NAML",
        "meta": {
            "item": "Bert",
            "user": "Ada",
            "predictor": "Dot"
        },
        "config": {
            "use_neg_sampling": true,
            "use_item_content": true,
            "use_fast_eval": 1,
            "max_item_content_batch_size": 0,
            "same_dim_transform": false,
            "embed_hidden_size": 768,
            "hidden_size": 64,
            "page_size": 512,
            "neg_count": 4,
            "item_config": {
                "llm_dir": "bert-base-uncased",
                "lora": 0,
                "lora_lr": 0
            },
            "user_config": {
                "inputer_config": {
                    "use_cls_token": false,
                    "use_sep_token": false
                }
            }
        }
    },
    "exp": {
        "name": "train_test",
        "dir": "ONCE/saving/MIND-small/BERT-NAML/bert-token-train_test",
        "log": "ONCE/saving/MIND-small/BERT-NAML/bert-token-train_test/exp.log",
        "mode": "train_test",
        "load": {
            "save_dir": null,
            "epochs": null,
            "model_only": true,
            "strict": true,
            "wait": false
        },
        "store": {
            "metric": "NDCG@10",
            "maximize": true,
            "top": 1,
            "early_stop": 2
        },
        "policy": {
            "epoch_start": 0,
            "epoch": 50,
            "lr": 0.001,
            "item_lr": 0,
            "freeze_emb": true,
            "pin_memory": false,
            "epoch_batch": 0,
            "batch_size": 16,
            "accumulate_batch": 1,
            "device": "gpu",
            "n_warmup": 0,
            "check_interval": -2,
            "simple_dev": false
        },
        "metrics": [
            "Recall@5",
            "Recall@10",
            "Recall@20",
            "NDCG@5",
            "NDCG@10",
            "NDCG@20",
            "MRR"
        ]
    },
    "lr": 0.001,
    "batch_size": 16,
    "embed_hidden_size": 768,
    "page_size": 64,
    "lora": 0,
    "item_lr": 0,
    "warmup": 0,
    "fast_eval": true,
    "simple_dev": false,
    "acc_batch": 1,
    "lora_r": 32,
    "mind_large_submission": false,
    "hidden_size": 64,
    "epoch_batch": 0,
    "max_item_batch_size": 0,
    "patience": 2,
    "epoch_start": 0,
    "frozen": true,
    "load_path": null,
    "rand": {},
    "time": {},
    "seed": 2023
}

NAML log:

[00:00:00] |Worker| {
    "data": {
        "name": "MIND-small",
        "base_dir": "ONCE/MIND-small/data",
        "item": {
            "filter_cache": true,
            "depot": "ONCE/MIND-small/data/news",
            "order": [
                "title",
                "cat"
            ],
            "append": [
                "nid"
            ]
        },
        "user": {
            "filter_cache": true,
            "depots": {
                "train": {
                    "path": "ONCE/MIND-small/data/train"
                },
                "dev": {
                    "path": "ONCE/MIND-small/data/dev"
                },
                "test": {
                    "path": "ONCE/MIND-small/data/test"
                }
            },
            "filters": {
                "history": [
                    "x"
                ]
            },
            "union": [
                "ONCE/MIND-small/data/user"
            ],
            "candidate_col": "nid",
            "clicks_col": "history",
            "label_col": "click",
            "neg_col": "neg",
            "group_col": "imp",
            "user_col": "uid",
            "index_col": "index"
        }
    },
    "embed": {
        "name": "bert-token",
        "embeddings": [
            {
                "vocab_name": "bert",
                "vocab_type": "numpy",
                "path": "ONCE/MIND-small/data/bert.npy",
                "frozen": true
            }
        ]
    },
    "model": {
        "name": "NAML",
        "meta": {
            "item": "CNN",
            "user": "Ada",
            "predictor": "Dot"
        },
        "config": {
            "use_neg_sampling": true,
            "use_item_content": true,
            "use_fast_eval": 1,
            "max_item_content_batch_size": 0,
            "same_dim_transform": false,
            "embed_hidden_size": 768,
            "hidden_size": 64,
            "page_size": 512,
            "neg_count": 4,
            "item_config": {
                "dropout": 0.1,
                "kernel_size": 3
            },
            "user_config": {
                "inputer_config": {
                    "use_cls_token": false,
                    "use_sep_token": false
                }
            }
        }
    },
    "exp": {
        "name": "train_test",
        "dir": "ONCE/saving/MIND-small/NAML/bert-token-train_test",
        "log": "ONCE/saving/MIND-small/NAML/bert-token-train_test/exp.log",
        "mode": "train_test",
        "load": {
            "save_dir": null,
            "epochs": null,
            "model_only": true,
            "strict": true,
            "wait": false
        },
        "store": {
            "metric": "NDCG@10",
            "maximize": true,
            "top": 1,
            "early_stop": 2
        },
        "policy": {
            "epoch_start": 0,
            "epoch": 50,
            "lr": 0.001,
            "item_lr": 0,
            "freeze_emb": true,
            "pin_memory": false,
            "epoch_batch": 0,
            "batch_size": 16,
            "accumulate_batch": 1,
            "device": "gpu",
            "n_warmup": 0,
            "check_interval": -2,
            "simple_dev": false
        },
        "metrics": [
            "Recall@5",
            "Recall@10",
            "Recall@20",
            "NDCG@5",
            "NDCG@10",
            "NDCG@20",
            "MRR"
        ]
    },
    "lr": 0.001,
    "batch_size": 16,
    "embed_hidden_size": 768,
    "page_size": 64,
    "lora": 0,
    "item_lr": 0,
    "warmup": 0,
    "fast_eval": true,
    "simple_dev": false,
    "acc_batch": 1,
    "lora_r": 32,
    "mind_large_submission": false,
    "hidden_size": 64,
    "epoch_batch": 0,
    "max_item_batch_size": 0,
    "patience": 2,
    "epoch_start": 0,
    "frozen": true,
    "load_path": null,
    "rand": {},
    "time": {},
    "seed": 2023
}

Jyonn commented 3 weeks ago

Hi, very sorry for the late reply.

> Unfortunately, I was not able to reproduce the results of GENRE for a combination of CS and UP (without CG) with frozen BERT content encoder. I am using the tokenized data that you have shared.

I am a little confused about this setting. It seems we do not have such experiments in our paper. All GENRE-type experiments in Table 2 use the original news encoder (for example, CNN for NAML and attention for NRMS). Moreover, all experiments in Table 3 use the original datasets, rather than GENRE-generated ones.

> These configurations gave the following results (which are better than the ones reported in the paper):

To reproduce the NAML model, please use title-token and cat-token, rather than title-bert and cat-bert, in data.yaml.

Best,

Qijiong