tensorflow / recommenders-addons

Additional utils and helpers to extend TensorFlow when building recommendation systems, contributed and maintained by SIG Recommenders.
Apache License 2.0

dynamic_embedding table: error when inserting data #463

Open lwmonster opened 1 month ago

lwmonster commented 1 month ago

System information

Describe the bug

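When using Dynamic Embedding, I created a Variable. This Variable contains several hashtables that store the embeddings of sparse features. I now want to update the embedding vectors in these hashtables with external data, but calling table.insert raises the following error: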

ValueError: Operation name: "hashTableGroup_0_mht_1of1_lookup_table_insert/TFRA>CuckooHashTableInsert"
op: "TFRA>CuckooHashTableInsert"
input: "group_0/hashTableGroup_0/hashTableGroup_0_mht_1of1"
input: "hashTableGroup_0_mht_1of1_lookup_table_insert/keys"
input: "hashTableGroup_0_mht_1of1_lookup_table_insert/values"
device: "/job:ps/replica:0/task:0"
attr {
  key: "Tin"
  value {
    type: DT_INT64
  }
}
attr {
  key: "Tout"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "_class"
  value {
    list {
      s: "loc:@group_0/hashTableGroup_0/hashTableGroup_0_mht_1of1"
    }
  }
}
 is not an element of this graph.

Code to reproduce the issue

My code is as follows:


    def _save_merged_ckpt(self,
                          cluster,
                          server,
                          model_config,
                          valid_data_dirs,
                          cluster_spec,
                          task_type,
                          task_index,
                          merged_variables,
                          merged_hashtable,
                          target_ckpt_dir):

        if task_index != 0:
            print('Worker {} is not chief, skip save merged ckpt'.format(task_index))
            return

        if tf.io.gfile.exists(target_ckpt_dir):
            tf.io.gfile.rmtree(target_ckpt_dir)
        tf.io.gfile.makedirs(target_ckpt_dir)

        # First, reset the model graph
        tf.reset_default_graph()
        with tf.Graph().as_default():
            eval_ds = DatasetManager.create_dataset(model_config.dataset_type,
                                                    model_config,
                                                    valid_data_dirs,
                                                    training=False,
                                                    is_eval=True)
            inputs = eval_ds.build(batch_size=1,
                                num_epochs=1,
                                num_worker=model_config.num_shard_of_export,
                                worker_idx=task_index)

            with tf.device(
                tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % task_index,
                                            cluster=cluster)):
                # dataset
                inference_iterator = inputs.make_one_shot_iterator()
                sample = inference_iterator.get_next()
                self.model_fn(sample, eval_ds, training=True)

                global_step_assign = tf.assign_add(self._global_step, 1)
                # Get all of the model's parameters
                all_vars = tf.trainable_variables()
                # Update the dense parameters
                assign_ops = []
                for var in all_vars:
                    print("{}'s shape is {}".format(var.name, var.shape))
                    if var.name not in merged_variables:
                        print('var {} not in merged_variables, SKIP it'.format(var.name))
                        continue
                    assign_ops.append(tf.assign(var, merged_variables[var.name]))

                # Update the hashtable parameters
                for table in self._hash_tables:
                    if table.name not in merged_hashtable:
                        print('table {} not in merged_hashtable, SKIP it'.format(table.name))
                        continue
                    keys, values = list(merged_hashtable[table.name].keys()), list(merged_hashtable[table.name].values())
                    values = np.array(values)
                    #values_tensor = tf.convert_to_tensor(values, dtype=tf.float32)
                    assign_ops.append(table.insert(keys, values))

            parallelism_conf = self.get_parallelism_conf(task_type, task_index)
            scaffold = self.get_scaffold()
            start = time.time()

            saver = tf.train.Saver()
            with tf.train.MonitoredTrainingSession(master=server.target,
                                                is_chief=(task_index == 0),
                                                scaffold=scaffold,
                                                checkpoint_dir=target_ckpt_dir,
                                                hooks=[],
                                                save_checkpoint_secs=None,
                                                save_checkpoint_steps=None,
                                                save_summaries_steps=None,
                                                save_summaries_secs=None,
                                                config=parallelism_conf) as sess:
                print('Worker {} restore successful....'.format(task_index))

                _, step = sess.run([assign_ops, global_step_assign])
                print('Assign merged_variables to checkpoint done')

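My assign_ops are in the graph, so why does it report that the CuckooHashTableInsert op is not in the graph? Could someone please point me in the right direction? PS: self._hash_tables here holds the internal tables of all the dynamic_embedding_variable variables.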


lwmonster commented 1 month ago

@MoFHeka @rhdong @jq HELP Please ~

lwmonster commented 1 month ago

The problem is solved, but I still don't know why...