RUCAIBox / RecBole

A unified, comprehensive and efficient recommendation library
https://recbole.io/
MIT License
3.27k stars 590 forks source link

请问实现自定义evaluation metric和sampler后在哪引用? #1965

Closed xhran2010 closed 5 months ago

xhran2010 commented 5 months ago

https://recbole.io/docs/developer_guide/customize_metrics.html# 步骤实现后,具体是在哪个参数传入自己新定义的metrics对象呢?

xhran2010 commented 5 months ago

以及Customize sampler实现后在哪里调用呢?

Yilu114 commented 5 months ago

在RecBole中,如果您已经按照[官方文档] 的步骤实现了自定义的评估指标,那么您可以在calculate_metric方法中使用这个新的指标。

首先,您需要在recbole.evaluator.metrics文件中创建一个新的类,并在__init__()方法中定义参数。然后,您需要设置指标的属性,包括metric_need(指标需要的输入),metric_type(指标所需的分数是否按用户分组),以及smaller(较小的指标值是否表示更好的性能)。

最后,您需要实现calculate_metric(self, dataobject)方法¹。所有的计算过程都在这个函数中定义。dataobject是一个包含所有结果的打包数据对象,我们可以将其视为一个字典,并通过rec_items = dataobject.get('rec.items')从中获取数据¹。返回的值应该是一个字典,键是指标名称,值是最终结果¹。请注意,指标名称应为小写。

以下是一个示例代码:

from recbole.evaluator.base_metric import AbstractMetric
from recbole.utils import EvaluatorType

class MyMetric(AbstractMetric):
    metric_type = EvaluatorType.RANKING
    metric_need = ['rec.items', 'data.num_items']
    smaller = True

    def __init__(self, config):
        pass

    def calculate_metric(self, dataobject):
        rec_items = dataobject.get('rec.items')
        # 在这里添加您的指标逻辑
        return result_dict

如果您想使用自定义的Sampler,您需要创建一个新的Sampler类,并继承自recbole.sampler.sampler.AbstractSampler。以下是一些关键步骤:

  1. 创建新的Sampler类:首先,您需要在recbole.sampler文件中创建一个新的类,并从AbstractSampler继承¹。例如,如果您想创建一个名为MySampler的新Sampler,您可以这样做:
    
    from recbole.sampler import AbstractSampler

class MySampler(AbstractSampler): pass

2. **实现`__init__()`方法**:然后,您需要实现`__init__()`方法,该方法用于初始化Sampler,包括加载数据集信息、Sampler参数等。

3. **实现`sample_by_key_ids()`方法**:此方法用于根据给定的key_ids进行采样¹。它接收两个参数:`key_ids`和`num`,分别表示输入的key_ids和每个key_id需要采样的value_ids的数量¹。该方法应返回一个torch.tensor,其中包含采样的value_ids¹。

4. **实现其他必要的方法**:根据您的需求,您可能还需要实现其他方法,如`get_used_ids()`、`sampling()`等¹。

以下是一个示例代码:
```python
from recbole.sampler import AbstractSampler
import torch

class MySampler(AbstractSampler):
    def __init__(self, distribution, alpha):
        super().__init__(distribution, alpha)

    def sample_by_key_ids(self, key_ids, num):
        # 在这里添加您的采样逻辑
        return sampled_value_ids  # 返回一个torch.tensor
xhran2010 commented 5 months ago

谢谢~