NVIDIA-Merlin / HugeCTR

HugeCTR is a high efficiency GPU framework designed for Click-Through-Rate (CTR) estimating training
Apache License 2.0

[Question] How can I pre-calculate the GPU memory required for embedding cache size? #427

Open tuanavu opened 11 months ago

tuanavu commented 11 months ago

Details

My company currently operates a Recommender model trained with TensorFlow 2 (TF2) and served on CPU pods. We are exploring HugeCTR because of its promising GPU embedding cache capabilities and are considering switching our model to it. We have successfully retrained our existing TF2 model with SparseOperationsKit (SOK) and created the inference graph with HPS, as demonstrated in these notebooks: sok_to_hps_dlrm_demo.ipynb and demo_for_tf_trained_model.ipynb.

Result: We deployed the model and used Triton's perf_analyzer to test its performance with varying batch sizes. The results were as follows:

Testing Environment:

To maximize throughput, we plan to test the model across different instance types with varying GPU memory sizes. However, optimizing different parameters in config and selecting the best instance type for inference requires a clear understanding of how embedding cache size is calculated.

Details about the current model and embedding tables:

Our current model has various dense, sparse, and pre-trained sparse features. After exporting the TF+SOK model to HPS, we have a total of 42 embedding tables (the sparse_files entries in hps_config.json). Here are the stats:

====================================================HPS Create====================================================
[HCTR][19:30:44.749][INFO][RK0][main]: Creating HashMap CPU database backend...
[HCTR][19:30:44.749][DEBUG][RK0][main]: Created blank database backend in local memory!
[HCTR][19:30:44.749][INFO][RK0][main]: Volatile DB: initial cache rate = 1
[HCTR][19:30:44.749][INFO][RK0][main]: Volatile DB: cache missed embeddings = 0
[HCTR][19:30:44.749][DEBUG][RK0][main]: Created raw model loader in local memory!
[HCTR][19:30:44.765][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_dense_emb_monolith; cached 16343 / 16343 embeddings in volatile database (HashMapBackend); load: 16343 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.775][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sparse_emb_monolith; cached 12239 / 12239 embeddings in volatile database (HashMapBackend); load: 12239 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.789][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_101; cached 66672 / 66672 embeddings in volatile database (HashMapBackend); load: 66672 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.800][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_102; cached 61395 / 61395 embeddings in volatile database (HashMapBackend); load: 61395 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.812][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_11; cached 73572 / 73572 embeddings in volatile database (HashMapBackend); load: 73572 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.823][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_14; cached 66534 / 66534 embeddings in volatile database (HashMapBackend); load: 66534 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.835][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_15; cached 64777 / 64777 embeddings in volatile database (HashMapBackend); load: 64777 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.848][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_18; cached 59276 / 59276 embeddings in volatile database (HashMapBackend); load: 59276 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.859][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_3; cached 14670 / 14670 embeddings in volatile database (HashMapBackend); load: 14670 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.871][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_4; cached 19489 / 19489 embeddings in volatile database (HashMapBackend); load: 19489 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.881][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_5; cached 20859 / 20859 embeddings in volatile database (HashMapBackend); load: 20859 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.893][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_56; cached 52218 / 52218 embeddings in volatile database (HashMapBackend); load: 52218 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.904][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_6; cached 21863 / 21863 embeddings in volatile database (HashMapBackend); load: 21863 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.915][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_60; cached 11075 / 11075 embeddings in volatile database (HashMapBackend); load: 11075 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.926][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_67; cached 28075 / 28075 embeddings in volatile database (HashMapBackend); load: 28075 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.936][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_68; cached 26174 / 26174 embeddings in volatile database (HashMapBackend); load: 26174 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.947][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_69; cached 13388 / 13388 embeddings in volatile database (HashMapBackend); load: 13388 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.957][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_70; cached 13759 / 13759 embeddings in volatile database (HashMapBackend); load: 13759 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.968][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_71; cached 4157 / 4157 embeddings in volatile database (HashMapBackend); load: 4157 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.978][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_72; cached 4622 / 4622 embeddings in volatile database (HashMapBackend); load: 4622 / 18446744073709551615 (0.00%).
[HCTR][19:30:44.992][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_73; cached 71902 / 71902 embeddings in volatile database (HashMapBackend); load: 71902 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.005][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_74; cached 81772 / 81772 embeddings in volatile database (HashMapBackend); load: 81772 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.018][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_75; cached 79451 / 79451 embeddings in volatile database (HashMapBackend); load: 79451 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.031][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_76; cached 66525 / 66525 embeddings in volatile database (HashMapBackend); load: 66525 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.044][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_77; cached 70591 / 70591 embeddings in volatile database (HashMapBackend); load: 70591 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.055][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_78; cached 29983 / 29983 embeddings in volatile database (HashMapBackend); load: 29983 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.068][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_79; cached 68439 / 68439 embeddings in volatile database (HashMapBackend); load: 68439 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.079][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_80; cached 35543 / 35543 embeddings in volatile database (HashMapBackend); load: 35543 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.089][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_89; cached 2170 / 2170 embeddings in volatile database (HashMapBackend); load: 2170 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.099][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_90; cached 2100 / 2100 embeddings in volatile database (HashMapBackend); load: 2100 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.109][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_91; cached 1762 / 1762 embeddings in volatile database (HashMapBackend); load: 1762 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.119][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_92; cached 1269 / 1269 embeddings in volatile database (HashMapBackend); load: 1269 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.128][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_93; cached 154 / 154 embeddings in volatile database (HashMapBackend); load: 154 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.138][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_94; cached 233 / 233 embeddings in volatile database (HashMapBackend); load: 233 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.148][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_95; cached 2771 / 2771 embeddings in volatile database (HashMapBackend); load: 2771 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.158][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_96; cached 1695 / 1695 embeddings in volatile database (HashMapBackend); load: 1695 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.168][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_97; cached 1931 / 1931 embeddings in volatile database (HashMapBackend); load: 1931 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.181][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_98; cached 72896 / 72896 embeddings in volatile database (HashMapBackend); load: 72896 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.193][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_sf_99; cached 66650 / 66650 embeddings in volatile database (HashMapBackend); load: 66650 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.231][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_TopicVAEDigestAnswer180DayWDR-sf_2; cached 524289 / 524289 embeddings in volatile database (HashMapBackend); load: 524289 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.253][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_TopicVAETopics-sf_1; cached 262145 / 262145 embeddings in volatile database (HashMapBackend); load: 262145 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.274][INFO][RK0][main]: Table: hps_et.full_model_max_mono.hps_TopicVAETopics-sf_7; cached 262145 / 262145 embeddings in volatile database (HashMapBackend); load: 262145 / 18446744073709551615 (0.00%).
[HCTR][19:30:45.274][DEBUG][RK0][main]: Real-time subscribers created!
[HCTR][19:30:45.274][INFO][RK0][main]: Creating embedding cache in device 0.
[HCTR][19:30:45.275][INFO][RK0][main]: Model name: full_model_max_mono
[HCTR][19:30:45.275][INFO][RK0][main]: Max batch size: 1024
[HCTR][19:30:45.275][INFO][RK0][main]: Number of embedding tables: 42
[HCTR][19:30:45.275][INFO][RK0][main]: Use GPU embedding cache: True, cache size percentage: 1.000000
[HCTR][19:30:45.275][INFO][RK0][main]: Use static table: False
[HCTR][19:30:45.275][INFO][RK0][main]: Use I64 input key: True
[HCTR][19:30:45.275][INFO][RK0][main]: Configured cache hit rate threshold: 1.100000
[HCTR][19:30:45.275][INFO][RK0][main]: The size of thread pool: 16
[HCTR][19:30:45.275][INFO][RK0][main]: The size of worker memory pool: 1
[HCTR][19:30:45.275][INFO][RK0][main]: The size of refresh memory pool: 1
[HCTR][19:30:45.275][INFO][RK0][main]: The refresh percentage : 0.200000
[HCTR][19:30:45.378][DEBUG][RK0][main]: Created raw model loader in local memory!
{
    "supportlonglong": true,
    "volatile_db": {
        "type": "parallel_hash_map",
        "allocation_rate": 100000.0,
        "initial_cache_rate": 1.0
    },
    "persistent_db": {
        "type": "disabled"
    },
    "models": [
        {
            "model": "full_model_max_mono",
            "sparse_files": [
                list_of_sparse_files
            ],
            "num_of_worker_buffer_in_pool": 2,
            "instance_group": 4,
            "embedding_table_names": [
                "hps_dense_emb_monolith",
                "hps_sparse_emb_monolith",
                "hps_sf_101",
                "hps_sf_102",
                "hps_sf_11",
                "hps_sf_14",
                "hps_sf_15",
                "hps_sf_18",
                "hps_sf_3",
                "hps_sf_4",
                "hps_sf_5",
                "hps_sf_56",
                "hps_sf_6",
                "hps_sf_60",
                "hps_sf_67",
                "hps_sf_68",
                "hps_sf_69",
                "hps_sf_70",
                "hps_sf_71",
                "hps_sf_72",
                "hps_sf_73",
                "hps_sf_74",
                "hps_sf_75",
                "hps_sf_76",
                "hps_sf_77",
                "hps_sf_78",
                "hps_sf_79",
                "hps_sf_80",
                "hps_sf_89",
                "hps_sf_90",
                "hps_sf_91",
                "hps_sf_92",
                "hps_sf_93",
                "hps_sf_94",
                "hps_sf_95",
                "hps_sf_96",
                "hps_sf_97",
                "hps_sf_98",
                "hps_sf_99",
                "hps_TopicVAEDigestAnswer180DayWDR-sf_2",
                "hps_TopicVAETopics-sf_1",
                "hps_TopicVAETopics-sf_7"
            ],
            "embedding_vecsize_per_table": [
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                2,
                8,
                8,
                8
            ],
            "maxnum_catfeature_query_per_table_per_sample": [
                221,
                7,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                100,
                1,
                100,
                100
            ],
            "default_value_for_each_table": [
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0,
                1.0
            ],
            "deployed_device_list": [
                0
            ],
            "max_batch_size": 24000,
            "cache_refresh_percentage_per_iteration": 0.2,
            "hit_rate_threshold": 1.1,
            "gpucacheper": 0.8,
            "gpucache": true
        }
    ]
}

Questions

  1. Given the specific details of our HPS model and the provided context, can you guide us on how to estimate the GPU memory needed to store the embedding cache based on the different batch sizes with HugeCTR backend for inference scenarios? This information will assist us in determining the optimal configuration and instance type to maximize our model's throughput during inference.
  2. Assuming that the GPU memory is insufficient to store all embeddings, what would be the best configuration? I understand that I could reduce the GPU cache ratio and cache the entire embedding table in the CPU memory database (volatile_db). Could you confirm whether this is the correct approach?
  3. I also have a question regarding the allocation_rate configuration in the volatile_db above. I observed that I must reduce allocation_rate to 1e6; otherwise the default allocation (256 MiB) leads to an out-of-memory issue during hps.init. Could you explain why this happens and provide some insights into this matter?

bashimao commented 11 months ago

Regarding 2: Using the parallel_hash_map as your volatile_db is the suggested approach, if you cannot put the entire embedding table directly into the GPU.

Regarding 3: For performance reasons (avoiding frequent small allocations) and to limit long-term memory fragmentation, the hash_map backends allocate memory in chunks. The size of these chunks is 256 MiB. Since you have 42 tables, that means at least 42 x 256 MiB = 10752 MiB will be allocated. Given that your EC2 instance only has 16 GiB of memory, seeing that OOM (Out-Of-Memory) error is not too surprising. However, I noticed your tables are rather small. I think it should be fine, without loss of performance, to decrease the allocation rate to 128 MiB, 100 MiB, or even lower, such as 64 MiB.
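
The arithmetic above can be checked with a short sketch. It takes the reply's statement that each table allocates at least one chunk as given, and additionally assumes that allocation_rate is expressed in bytes; the function name is made up for illustration and is not part of HugeCTR.

```python
MIB = 1024 * 1024


def min_chunk_allocation_mib(num_tables: int, chunk_bytes: float) -> float:
    """Lower bound on host memory if every table allocates one chunk.

    Assumption: one chunk per table is the minimum, as stated in the reply;
    real usage can be higher once tables grow past a single chunk.
    """
    return num_tables * chunk_bytes / MIB


NUM_TABLES = 42  # "Number of embedding tables: 42" in the HPS log

# Chunk sizes mentioned in the reply, in MiB.
for chunk_mib in (256, 128, 100, 64):
    total = min_chunk_allocation_mib(NUM_TABLES, chunk_mib * MIB)
    print(f"{chunk_mib:>3} MiB chunks -> at least {total:.0f} MiB")

# The value from the posted config, assumed to be in bytes.
configured = min_chunk_allocation_mib(NUM_TABLES, 100000.0)
print(f"allocation_rate=100000.0 -> at least {configured:.2f} MiB")
```

With 256 MiB chunks this reproduces the 10752 MiB figure, which leaves little headroom on a 16 GiB host once the model and runtime are loaded.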

yingcanw commented 11 months ago

@tuanavu Regarding the 2nd question, I have some comments here. We have supported fp8 quantization in the static embedding cache since v23.08. When you enable the "fp8_quant": true and "embedding_cache_type": "static" items in the HPS JSON configuration file, HPS performs fp8 quantization on the embedding vectors while reading the embedding table, and performs fp32 dequantization on the embedding vectors that correspond to the queried embedding keys in the static embedding cache, so as to preserve the accuracy of the dense part of the prediction.
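
As a rough illustration of the configuration change described above, the sketch below sets the two keys on a model entry of an HPS config loaded as a Python dict. The key names come from the comment; placing them inside the per-model entry is an assumption, and the helper name is hypothetical.

```python
import json


def enable_static_fp8(config: dict, model_name: str) -> dict:
    """Return a copy of an HPS config with fp8 static caching switched on.

    Assumption: both keys live in the per-model entry under "models".
    """
    patched = json.loads(json.dumps(config))  # deep copy via JSON round-trip
    for model in patched["models"]:
        if model["model"] == model_name:
            model["embedding_cache_type"] = "static"
            model["fp8_quant"] = True
    return patched


# Trimmed-down stand-in for the hps_config.json posted in the question.
example = {"models": [{"model": "full_model_max_mono", "gpucache": True}]}
patched = enable_static_fp8(example, "full_model_max_mono")
print(json.dumps(patched, indent=4))
```

The original dict is left untouched, so the fp32 and fp8 variants can be deployed side by side when evaluating the precision loss.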

Since the embeddings are stored as fp8, the GPU memory footprint is greatly reduced. However, because business use cases differ, the precision loss caused by quantization/dequantization still needs to be evaluated in real production. For this reason, we currently offer only experimental support, limited to the static embedding cache, for POC verification. If quantization can bring greater benefits to your case, we will add quantization to the dynamic cache and to the upcoming lock-free optimized GPU cache.
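
To get a feel for the saving, here is a back-of-the-envelope sketch for the three pre-trained tables. Row counts are copied from the HPS log in the question; the vector size of 8 assumes that embedding_vecsize_per_table follows the order of embedding_table_names. It counts vector payload only, so keys, hash-table overhead, and any workspace buffers that scale with batch size are excluded, and the result is a lower bound rather than the actual cache footprint.

```python
MIB = 1024 * 1024

# (rows, vector size) per table; rows come from the HPS log.
TABLES = {
    "hps_TopicVAEDigestAnswer180DayWDR-sf_2": (524289, 8),
    "hps_TopicVAETopics-sf_1": (262145, 8),
    "hps_TopicVAETopics-sf_7": (262145, 8),
}


def vector_payload_bytes(tables: dict, bytes_per_element: int) -> int:
    """Sum rows * vec_size * element width over all tables (vectors only)."""
    return sum(
        rows * vec_size * bytes_per_element
        for rows, vec_size in tables.values()
    )


fp32_bytes = vector_payload_bytes(TABLES, 4)  # fp32: 4 bytes per element
fp8_bytes = vector_payload_bytes(TABLES, 1)   # fp8: 1 byte per element

print(f"fp32 payload: {fp32_bytes / MIB:.2f} MiB")
print(f"fp8 payload:  {fp8_bytes / MIB:.2f} MiB")
```

The remaining 39 tables can be added to TABLES in the same way. Because the keys are not quantized, the saving on the whole cache will be smaller than the 4x seen on the vector payload.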