Oneflow-Inc / models

Models and examples built with OneFlow
Apache License 2.0
94 stars 37 forks source link

DeepFM 40M #366

Open MARD1NO opened 2 years ago

MARD1NO commented 2 years ago

本分支的主要修改：dropout=0.05，Adam eps=1e-7，并调整了 lr_scheduler。

fp16: 
================ Test Evaluation ================
Rank[0], Epoch 7, Step 75000, AUC 0.800648, LogLoss 0.125599, Eval_time 15.19 s, Metrics_time 4.50 s, Eval_samples 89192448, GPU_Memory 15908 MiB, Host_Memory 15763 MiB, 2022-07-26 11:50:46

fp32: 
================ Test Evaluation ================
Rank[0], Epoch 7, Step 75000, AUC 0.802384, LogLoss 0.125783, Eval_time 16.35 s, Metrics_time 4.75 s, Eval_samples 89192448, GPU_Memory 15944 MiB, Host_Memory 15736 MiB, 2022-07-26 12:56:33

对应的 HugeCTR 脚本（未开启混合精度，开启了 TF32 计算）：

import hugectr
from mpi4py import MPI
# Root of the Parquet copy of the dataset. Only the commented-out Parquet
# reader further down uses it; the active reader reads raw .bin files instead.
data_dir = "/RAID0/liujuncheng/criteo1t_parquet_40M_long"
# data_dir = "/RAID0/xiexuan/criteo1t_parquet_C39_int32"

# Solver settings for this experiment. Mixed precision is left disabled
# (options commented out) while TF32 compute is enabled. The warmup/decay
# arguments configure HugeCTR's built-in LR schedule; this branch tuned them.
solver = hugectr.CreateSolver(batchsize_eval = 55296,  # real value
                              batchsize = 55296,  # 55296 or 69120
                              lr = 0.0025,  # aligned with the compared OneFlow run
                              warmup_steps = 3000,
                              decay_start = 10000,
                              decay_steps = 60000,
                              decay_power = 2.0,
                              end_lr = 1e-8,
                              enable_tf32_compute = True,
                              # use_mixed_precision = True,
                              # scaler = 1024,
                              vvgpu = [[0,1,2,3]],  # 4 GPUs: devices 0-3
                              repeat_dataset = True,
                              use_algorithm_search=False,
                              # 32-bit keys (the raw dataset is *_int32);
                              # switch to i64_input_key = True for int64 keys.
                              i64_input_key = False)

# slot_size = [39873037, 38856, 17238, 7420, 20263, 3, 7102, 1540, 63, 38457571, 2928921, 400819, 10, 2208, 11910, 152, 4, 976, 14, 39976965, 25417819, 39639213, 583186, 12928, 108, 36, 62774, 8001, 2901, 74279, 7513, 3369, 1392, 21627, 7919, 21, 276, 1231237, 9643]
# reader = hugectr.DataReaderParams(data_reader_type = hugectr.DataReaderType_t.Parquet,
#                                   source = [f"{data_dir}/train/_file_list.txt"],
#                                   eval_source = f"{data_dir}/test/_file_list.txt",
#                                   slot_size_array = slot_size, 
#                                   check_type = hugectr.Check_t.Non)

# Vocabulary size of each of the 39 sparse slots, in slot order. The same
# list is handed to the data reader and to the sparse embedding layer.
slot_size = [62774, 8001, 2901, 74279, 7513, 3369, 1392, 21627, 7919, 21,
             276, 1231236, 9643, 39873199, 38853, 17240, 7421, 20263, 3, 7103,
             1540, 63, 38457188, 2929249, 400771, 10, 2209, 11910, 152, 4,
             976, 14, 39976779, 25414584, 39639858, 583095, 12929, 108, 36]

# Raw-binary train/test files (the Parquet variant above is commented out).
train_files = ["/RAID0/xiexuan/criteo1t_hugectr_raw_C39_int32/train.bin"]
eval_file = "/RAID0/xiexuan/criteo1t_hugectr_raw_C39_int32/test.bin"

# Raw reader with checksum checking disabled; sample counts are given
# explicitly for the train and eval files.
reader = hugectr.DataReaderParams(
    data_reader_type=hugectr.DataReaderType_t.Raw,
    source=train_files,
    eval_source=eval_file,
    check_type=hugectr.Check_t.Non,
    num_samples=4195197692,
    eval_num_samples=89137319,
    slot_size_array=slot_size,
)

# Adam optimizer; eps=1e-7 is one of the settings changed on this branch.
adam_settings = dict(
    optimizer_type=hugectr.Optimizer_t.Adam,
    update_type=hugectr.Update_t.Local,  # note: this choice may affect performance
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-7,
)
optimizer = hugectr.CreateOptimizer(**adam_settings)

# Dropout probability used after every hidden ReLU of the DNN tower.
dropout_rate = 0.05

model = hugectr.Model(solver, reader, optimizer)

# Sparse input description. Per the HugeCTR Python API the positional
# arguments are presumably (top_name, nnz_per_slot, is_fixed_length, slot_num):
# tensor "data1", at most 2 keys per slot, variable-length slots, 39 slots.
# NOTE(review): confirm against the installed HugeCTR version.
sparse_inputs = [hugectr.DataReaderSparseParam("data1", 2, False, 39)]

# One label column and no dense features; only the sparse slots feed the model.
model.add(hugectr.Input(label_dim=1, label_name="labels",
                        dense_dim=0, dense_name="dense",
                        data_reader_sparse_param_array=sparse_inputs))

# Hash-based sparse embedding over "data1". The LocalizedSlot variant is used
# here; HugeCTR offers several embedding types, and the Distributed variant is
# kept commented out below. Each key gets a 17-wide vector (16 latent dims + 1
# first-order weight), which the Slice layer later splits apart.
embedding_settings = dict(
    embedding_type=hugectr.Embedding_t.LocalizedSlotSparseEmbeddingHash,
    workspace_size_per_gpu_in_mb=15000,  # chosen to be big enough
    embedding_vec_size=17,
    combiner="sum",
    sparse_embedding_name="sparse_embedding1",
    bottom_name="data1",
    slot_size_array=slot_size,  # real value
    optimizer=optimizer,
)
model.add(hugectr.SparseEmbedding(**embedding_settings))

# model.add(hugectr.SparseEmbedding(embedding_type = hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, #有三种可以选 
#                            workspace_size_per_gpu_in_mb = 15000,#bigger enough
#                            embedding_vec_size = 17,
#                            combiner = "sum",
#                            sparse_embedding_name = "sparse_embedding1",
#                            bottom_name = "data1",
#                            slot_size_array = slot_size,  # real value
#                            optimizer = optimizer))

# FM part of DeepFM: split the combined 17-wide slot embedding into the
# 16-dim latent vector and the 1-dim first-order (LR) weight, then build the
# first-order term `lr_out` and the second-order term `dot_sum`.
num_slots = len(slot_size)  # 39 sparse slots; must match slot_num of the input
fm_vec_size = 16            # latent dims; embedding_vec_size (17) = fm_vec_size + 1

# One row per (sample, slot), each holding the full 17-wide vector.
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["sparse_embedding1"],
                            top_names = ["reshape_sparse_embedding1"],
                            leading_dim=fm_vec_size + 1))

# Columns [0, 16) -> latent vectors, column [16, 17) -> LR weight.
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Slice,
                            bottom_names = ["reshape_sparse_embedding1"],
                            top_names = ["embedded_x", "lr_embedded_x"],
                            ranges=[(0, fm_vec_size), (fm_vec_size, fm_vec_size + 1)]))

# Back to one row per sample: num_slots * fm_vec_size latent values ...
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["embedded_x"],
                            top_names = ["reshaped_embedded_x"],
                            leading_dim=fm_vec_size * num_slots))

# ... and num_slots LR weights.
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,
                            bottom_names = ["lr_embedded_x"],
                            top_names = ["reshaped_lr_embedded_x"],
                            leading_dim=num_slots))

# lr_out: first-order term, the sum of the per-slot LR weights.
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReduceSum,
                            bottom_names = ["reshaped_lr_embedded_x"],
                            top_names = ["lr_out"],
                            axis=1))

# dot_sum: FM second-order (pairwise) term.
# HugeCTR's FmOrder2 takes out_dim as the embedding vector size. It treats its
# input as (batch, slot_num * out_dim) and returns one value per latent dim.
# With out_dim=1 it would see 624 one-dimensional "slots" and mix in
# intra-slot, cross-dimension products. So compute per-dim FM with
# out_dim = fm_vec_size, then sum over the latent dims to get one value per
# sample.
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.FmOrder2,
                            bottom_names = ["reshaped_embedded_x"],
                            top_names = ["fm_order2"],
                            out_dim=fm_vec_size))
model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.ReduceSum,
                            bottom_names = ["fm_order2"],
                            top_names = ["dot_sum"],
                            axis=1))

# Deep part of DeepFM: five hidden layers, each InnerProduct(1000) -> ReLU ->
# Dropout, named fc{i} / relu{i} / dropout{i} for i = 1..5. A final
# single-unit InnerProduct "fc6" produces the DNN logit.
prev_top = "reshaped_embedded_x"
for k in range(1, 6):
    model.add(hugectr.DenseLayer(layer_type=hugectr.Layer_t.InnerProduct,
                                 bottom_names=[prev_top],
                                 top_names=[f"fc{k}"],
                                 num_output=1000))
    model.add(hugectr.DenseLayer(layer_type=hugectr.Layer_t.ReLU,
                                 bottom_names=[f"fc{k}"],
                                 top_names=[f"relu{k}"]))
    model.add(hugectr.DenseLayer(layer_type=hugectr.Layer_t.Dropout,
                                 bottom_names=[f"relu{k}"],
                                 top_names=[f"dropout{k}"],
                                 dropout_rate=dropout_rate))
    prev_top = f"dropout{k}"

# Output layer of the tower; deliberately no activation after it.
model.add(hugectr.DenseLayer(layer_type=hugectr.Layer_t.InnerProduct,
                             bottom_names=[prev_top],
                             top_names=["fc6"],
                             num_output=1))

# Final logit = DNN output + first-order term + second-order FM term.
# No sigmoid is applied here; the BCE loss presumably applies it internally.
logit_parts = ["fc6", "lr_out", "dot_sum"]
model.add(hugectr.DenseLayer(layer_type=hugectr.Layer_t.Add,
                             bottom_names=logit_parts,
                             top_names=["add"]))
model.add(hugectr.DenseLayer(layer_type=hugectr.Layer_t.BinaryCrossEntropyLoss,
                             bottom_names=["add", "labels"],
                             top_names=["loss"]))

model.compile()
model.summary()

# 75k training iterations, logging every 1000 and evaluating every 4999.
# The snapshot interval exceeds max_iter, so presumably no snapshots are taken.
# For a quick smoke test, use max_iter=6000 instead.
fit_args = dict(max_iter=75000, display=1000, eval_interval=4999,
                snapshot=1000000, snapshot_prefix="deepfm")
model.fit(**fit_args)

如果在 OneFlow 中改用 `flow.narrow` 切分 embedding（替代上面 HugeCTR 的 Slice 层），写法如下：

        # Single lookup for all sparse slots. Assumes the result is
        # (batch, num_slots, embedding_vec_size + 1), e.g. (55296, 39, 17);
        # TODO confirm.
        multi_embedded_x = self.embedding_layer(inputs)
        # print("multi_embedded_x is: ", multi_embedded_x.shape)
        # Along dim 2, the first embedding_vec_size columns are the latent vectors
        # (the counterpart of HugeCTR's Slice range (0, 16)) ...
        embedded_x = flow.narrow(multi_embedded_x, 2, 0, self.embedding_vec_size)
        # ... and the one trailing column is the first-order (LR) weight.
        lr_embedded_x = flow.narrow(multi_embedded_x, 2, self.embedding_vec_size, 1) # e.g. oneflow.Size([55296, 39, 1])
        # FM: first-order term summed over the slot dim (dim 1) ...
        lr_out = flow.sum(lr_embedded_x, dim=1, keepdim=False)
        # ... plus the pairwise second-order term from `interaction` (defined elsewhere).
        dot_sum = interaction(embedded_x, self.use_fuse_interaction)
        fm_pred = lr_out + dot_sum

        # DNN: flatten slots x latent dims into one feature vector per sample.
        dnn_pred = self.dnn_layer(embedded_x.flatten(start_dim=1))