aksnzhy / xlearn

High performance, easy-to-use, and scalable machine learning (ML) package, including linear model (LR), factorization machines (FM), and field-aware factorization machines (FFM) for Python and CLI interface.
https://xlearn-doc.readthedocs.io/en/latest/index.html
Apache License 2.0
3.09k stars 519 forks source link

FM 分类训练--no-norm 预测正常,去掉--no-norm 预测得到接近1的值 #297

Open BHMliang opened 5 years ago

BHMliang commented 5 years ago

通过命令行训练模型,加上--no-norm 得到模型文件。用自定义代码进行预测,与使用命令行加上--no-norm预测的结果一致; 通过命令行训练模型,不加--no-norm 得到模型文件。用自定义代码(归一化,norm = 1.0 / feaIds.size(); sqrtNorm = Math.sqrt(norm))进行预测,与使用命令行不加--no-norm预测的结果不一致;自定义代码预测的值比较大; 自定义预测代码:

        double sumWeight = 0.0;
        // bias
        sumWeight += bias;
        // norm
        double norm = 1.0 / feaIds.size();

        //norm = 1.0;

        // sqrt norm
        double sqrtNorm = Math.sqrt(norm);

        // linear term
        sumWeight += feaIds.stream().mapToDouble(feaId -> NumberUtils.doubleValue(feature2Weight.get(feaId)) * sqrtNorm).sum();
       //latent factor
       for (int i = 0; i < latentSize; i++) {
            //sum1
            double sum1 = 0.0;
            //sum2
            double sum2 = 0.0;
            for (long feaId : feaIds) {
                List<Double> embedding = getEmbedding(feaId);
                if(CollectionUtils.isEmpty(embedding)){
                    continue;
                }
                double d = embedding.get(i) * sqrtNorm;
                sum1 += d;
                sum2 += d * d;
            }
            sumWeight += (0.5 * (sum1 * sum1 - sum2));
        }
        for (int i = 0; i < latentSize; i++) {
            //sum1
            double sum1 = 0.0;
            //sum2
            double sum2 = 0.0;
            for (long feaId : feaIds) {
                List<Double> embedding = getEmbedding(feaId);
                if(CollectionUtils.isEmpty(embedding)){
                    continue;
                }
                double d = embedding.get(i) * sqrtNorm;
                sum1 += d;
                sum2 += d * d;
            }
         sumWeight += (0.5 * (sum1 * sum1 - sum2));

@aksnzhy 请问我的代码有什么问题吗

BHMliang commented 5 years ago

double d = embedding.get(i) sqrtNorm; 改为double d = embedding.get(i) norm;之后结果一致,但不明白为什么要*norm。 根据原公式: image

vx 翻译过来应该是embedding.get(i) sqrtNorm 1。 不明白这里为什么却是embedding.get(i) norm * 1

dongjiewhu commented 5 years ago

@aksnzhy 和上面一样的疑问,看了一遍xlearn的代码,在predict和CaclGrad的时候,在embedding这里都是乘的norm,而不是sqrtNorm

BUCTdarkness commented 5 years ago

@aksnzhy 我这边也是同样的问题,为什么在embedding这里都是乘的norm,而不是sqrtNorm呢