thunlp / OpenKE

An Open-Source Package for Knowledge Embedding (KE)

DistMult and SGD: loss not decreasing #133

Closed AndRossi closed 5 years ago

AndRossi commented 5 years ago

I am using the OpenKE framework to run experiments with several popular Knowledge Graph Embedding models, and I'm having some trouble with DistMult.

I am trying to train embeddings of dimension 100 for 1000 epochs with a learning rate of 0.001, using SGD, on the FB15K dataset. This configuration is identical to your TransE example.

The problem is that the loss does not seem to decrease at all: it stays stuck around 69.31472, with only tiny fluctuations. The test results after 1000 epochs are, of course, terrible:

no type constraint results:
metric:           MRR        MR           hit@10     hit@3      hit@1
l(raw):           0.000719   7469.423828  0.000609   0.000237   0.000085
r(raw):           0.000765   7470.809570  0.000846   0.000254   0.000152
averaged(raw):    0.000742   7470.116699  0.000728   0.000245   0.000119

l(filter):        0.000728   7360.574707  0.000626   0.000237   0.000085
r(filter):        0.000774   7401.953613  0.000863   0.000254   0.000152
averaged(filter): 0.000751   7381.264160  0.000745   0.000245   0.000119

type constraint results:
metric:           MRR        MR           hit@10     hit@3      hit@1
l(raw):           0.054080   559.501587   0.106939   0.051785   0.021042
r(raw):           0.086261   471.667175   0.157725   0.087979   0.043896
averaged(raw):    0.070170   515.584351   0.132332   0.069882   0.032469

l(filter):        0.090418   450.604736   0.159486   0.090518   0.050295
r(filter):        0.105329   402.801453   0.184151   0.110291   0.058828
averaged(filter): 0.097873   426.703094   0.171819   0.100405   0.054561

triple classification accuracy is 0.500076
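The plateau value is itself informative: 69.31472 matches 100 × ln 2 to the printed precision. The arithmetic is sketched below under assumptions not verified against the OpenKE source: DistMult minimises the batch mean of softplus(-y · score), the printed epoch loss is the sum of the 100 per-batch losses (nbatches is 100 in the configuration), and the regularisation term is negligible. Under those assumptions, this is exactly the loss you get when every score is 0.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


NBATCHES = 100  # con.set_nbatches(100) in the configuration
# ASSUMPTION: every score is 0, so softplus(-y * 0) = softplus(0) for y = +1 and y = -1,
# and the per-batch mean loss is just softplus(0).
per_batch = softplus(0.0)
epoch_loss = NBATCHES * per_batch  # ASSUMPTION: the epoch loss sums the per-batch losses
print(f"per-batch loss {per_batch:.7f}, epoch loss {epoch_loss:.5f}")
# prints: per-batch loss 0.6931472, epoch loss 69.31472
```

So a loss that never leaves 100 ln 2 suggests the scores stay essentially at their initial near-zero values for the whole run.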

Here is my current configuration file:

import config
import models
import tensorflow as tf
import numpy as np

con = config.Config()
#Input training files from benchmarks/FB15K/ folder.
con.set_in_path("./benchmarks/FB15K/")

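con.set_work_threads(4)    # worker threads
con.set_train_times(1000)  # number of training epochs
con.set_nbatches(100)      # mini-batches per epoch
con.set_alpha(0.001)       # learning rate
con.set_margin(1.0)        # margin for margin-based losses such as TransE's
con.set_bern(0)            # 0: uniform negative sampling, 1: "bern" sampling
con.set_dimension(100)     # embedding dimension
con.set_ent_neg_rate(1)    # negatives per positive, made by corrupting an entity
con.set_rel_neg_rate(0)    # negatives per positive, made by corrupting the relation
con.set_opt_method("SGD")  # optimizer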

#Models will be exported via tf.Saver() automatically.
con.set_export_files("./res/model.vec.tf", 0)
#Model parameters will be exported to json files automatically.
con.set_out_files("./res/embedding.vec.json")
#Initialize experimental settings.
con.init()
#Set the knowledge embedding model
con.set_model(models.DistMult)
#Train the model.
con.run()
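Why SGD in particular would stall here can be estimated from the DistMult gradient. The score is the trilinear product <h, r, t>, so the gradient with respect to h is r ∘ t: a product of two small initial embedding entries, further divided by the batch size if the loss is averaged. The numpy sketch below estimates one SGD step under assumptions that are not taken from OpenKE itself: Glorot/Xavier-scale initialisation (the uniform variant has the same standard deviation), a mean-reduced softplus loss, no regularisation, and one corrupted triple per positive as set_ent_neg_rate(1) suggests. The helper glorot_rows and all constants are hypothetical.

```python
import numpy as np

# Hypothetical reconstruction of the configuration above.
# FB15K: 14,951 entities, 1,345 relations, 483,142 training triples.
N_ENT, N_REL, DIM = 14951, 1345, 100
NBATCHES, LR = 100, 0.001
BATCH = 2 * (483142 // NBATCHES)  # positives plus one corrupted triple each

rng = np.random.default_rng(0)


def glorot_rows(n_rows, n_cols, batch):
    # ASSUMPTION: Glorot/Xavier-scale init of an (n_rows x n_cols) table;
    # only the `batch` rows a mini-batch would look up are sampled.
    std = np.sqrt(2.0 / (n_rows + n_cols))
    return rng.normal(0.0, std, size=(batch, n_cols))


h = glorot_rows(N_ENT, DIM, BATCH)
r = glorot_rows(N_REL, DIM, BATCH)
t = glorot_rows(N_ENT, DIM, BATCH)
y = np.where(np.arange(BATCH) % 2 == 0, 1.0, -1.0)  # +1 positive, -1 negative

score = np.sum(h * r * t, axis=1)              # DistMult score <h, r, t>
loss = np.mean(np.logaddexp(0.0, -y * score))  # batch mean of softplus(-y * score)

# Gradient of the mean loss w.r.t. each looked-up head vector:
# d loss / d h_b = -(1 / B) * y_b * sigmoid(-y_b * s_b) * (r_b * t_b)
sig = 1.0 / (1.0 + np.exp(y * score))  # sigmoid(-y * score)
grad_h = -(y * sig)[:, None] * r * t / BATCH

print(f"batch loss     {loss:.6f} (ln 2 = {np.log(2):.6f})")
print(f"mean |grad h|  {np.abs(grad_h).mean():.1e}")
print(f"mean SGD step  {LR * np.abs(grad_h).mean():.1e}")
print(f"embedding std  {np.sqrt(2.0 / (N_ENT + DIM)):.1e}")
```

Under these assumptions each SGD step moves an embedding entry by roughly 1e-11 while the entries themselves are around 1e-2, so even 100,000 steps (1000 epochs × 100 batches) barely change anything; an entity that appears several times in a batch only accumulates a few such terms. For comparison, the per-component gradient of an L1 TransE distance |h + r - t| is ±1 whatever the embedding scale, which may be one reason the same SGD settings work for TransE.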

And this is the testing script:

import os
import config
import models

os.environ['CUDA_VISIBLE_DEVICES'] = '7'
# (1) Set import files and OpenKE will automatically load models via tf.Saver().
con = config.Config()
con.set_in_path("./benchmarks/FB15K/")
con.set_test_link_prediction(True)
con.set_test_triple_classification(True)
con.set_log_on(False)

con.set_work_threads(8)
con.set_dimension(100)
con.set_import_files("./res/model.vec.tf")
con.init()
con.set_model(models.DistMult)
con.test()
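con.test() reports MRR, MR and hit@k over raw and filtered ranks. As a reading aid for the tables in this thread, here is a minimal sketch of how these metrics are conventionally computed from ranks. It is not OpenKE's evaluation code: it assumes a higher score means more plausible and counts only strictly better candidates (so ties favour the true entity), and rank_of and summarize are hypothetical helpers.

```python
def rank_of(true_idx, scores, known_true=()):
    """1-based rank of the true candidate; higher score = more plausible.

    Raw rank counts every candidate scoring strictly above the true one.
    Filtered rank additionally ignores candidates listed in `known_true`
    (other entities that also complete a true triple for this query).
    """
    target = scores[true_idx]
    skip = set(known_true) - {true_idx}
    better = sum(1 for i, s in enumerate(scores) if i not in skip and s > target)
    return better + 1


def summarize(ranks, ks=(10, 3, 1)):
    """MRR, MR and hit@k over a list of ranks."""
    n = len(ranks)
    metrics = {"MRR": sum(1.0 / r for r in ranks) / n, "MR": sum(ranks) / n}
    for k in ks:
        metrics[f"hit@{k}"] = sum(r <= k for r in ranks) / n
    return metrics


# Toy query: 5 candidate tails, the true tail is 2, tail 0 is another known fact.
scores = [0.9, 0.1, 0.7, 0.8, 0.2]
raw = rank_of(2, scores)                   # candidates 0 and 3 beat it: rank 3
filt = rank_of(2, scores, known_true={0})  # candidate 0 is filtered out: rank 2
print(raw, filt)
print(summarize([1, 3, 12]))               # three hypothetical test ranks
```

Filtering can only lower a rank, which is why every filtered row in the tables is at least as good as its raw counterpart.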

Any ideas on what could be causing this problem?

AndRossi commented 5 years ago

I forgot to mention another odd behaviour I noticed.

I tried several configurations of TransE and DistMult. With Adagrad or Adam instead of SGD, DistMult seems to train fine. Here are the test results for each optimizer:

DISTMULT WITH ADAGRAD OPTIMIZER

no type constraint results:

metric:           MRR        MR           hit@10     hit@3      hit@1
l(raw):           0.120585   565.615967   0.225542   0.128185   0.066666
r(raw):           0.179348   466.003021   0.306106   0.191837   0.115319
avg(raw):         0.149967   515.809509   0.265824   0.160011   0.090992

l(filter):        0.207128   446.112518   0.317415   0.224052   0.148178
r(filter):        0.266619   390.748138   0.389616   0.288754   0.201774
avg(filter):      0.236874   418.430328   0.353515   0.256403   0.174976

type constraint results:

metric:           MRR        MR           hit@10     hit@3      hit@1
l(raw):           0.133175   363.183777   0.248921   0.141643   0.073098
r(raw):           0.191665   281.449280   0.328909   0.205380   0.121599
avg(raw):         0.162420   322.316528   0.288915   0.173512   0.097349

l(filter):        0.245597   243.721909   0.371417   0.264648   0.178785
r(filter):        0.287590   206.212265   0.424218   0.311540   0.215893
avg(filter):      0.266594   224.967087   0.397818   0.288094   0.197339

triple classification accuracy is 0.703910

DISTMULT WITH ADAM OPTIMIZER

no type constraint results:
metric:           MRR        MR           hit@10     hit@3      hit@1
l(raw):           0.199270   337.434601   0.434037   0.220142   0.096105
r(raw):           0.212262   284.205658   0.455469   0.233533   0.104214
avg(raw):         0.205766   310.820129   0.444753   0.226837   0.100159

l(filter):        0.460150   144.686172   0.716748   0.554062   0.318481
r(filter):        0.466817   164.259415   0.725449   0.557922   0.325727
avg(filter):      0.463484   154.472794   0.721098   0.555992   0.322104

type constraint results:
metric:           MRR        MR           hit@10     hit@3      hit@1
l(raw):           0.242102   231.191345   0.491205   0.272418   0.130335
r(raw):           0.291126   144.388855   0.559496   0.324745   0.169897
avg(raw):         0.266614   187.790100   0.525351   0.298581   0.150116

l(filter):        0.564306   38.424725    0.809568   0.662322   0.425776
r(filter):        0.594326   24.442619    0.844915   0.694266   0.453827
avg(filter):      0.579316   31.433672    0.827242   0.678294   0.439801

triple classification accuracy is 0.897284

As you can see, there is a huge difference in performance between the two (filtered MRR without type constraints: 0.463 with Adam versus 0.237 with Adagrad). I find this quite odd: I expected the Adam version to be better, but not twice as good as the Adagrad one.
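One factor that could contribute to the Adam/Adagrad gap is how the two optimizers scale their steps over time. The sketch below assumes both runs used the same learning rate of 0.001 (the thread does not say) and a near-zero initial Adagrad accumulator (1e-20 here; this is an assumption about OpenKE's optimizer setup, not something verified). It feeds a constant gradient of 1e-6 to textbook SGD, Adagrad and bias-corrected Adam (Kingma & Ba) update rules and records the size of each step. The function update_sizes is hypothetical.

```python
import math


def update_sizes(g, lr=0.001, steps=10_000, adagrad_init=1e-20,
                 beta1=0.9, beta2=0.999, eps=1e-8):
    """Per-step update magnitudes for a constant scalar gradient g."""
    accum = adagrad_init  # ASSUMPTION: near-zero initial Adagrad accumulator
    m = v = 0.0
    sgd, adagrad, adam = [], [], []
    for t in range(1, steps + 1):
        # SGD: the step is proportional to the raw gradient.
        sgd.append(lr * abs(g))
        # Adagrad: the sum of squared gradients only grows, so steps shrink.
        accum += g * g
        adagrad.append(lr * abs(g) / math.sqrt(accum))
        # Adam: bias-corrected moving averages keep the step close to lr.
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        adam.append(lr * abs(m_hat) / (math.sqrt(v_hat) + eps))
    return sgd, adagrad, adam


sgd, adagrad, adam = update_sizes(g=1e-6)
for t in (1, 100, 10_000):
    print(f"step {t:>6}: SGD {sgd[t - 1]:.2e}  "
          f"Adagrad {adagrad[t - 1]:.2e}  Adam {adam[t - 1]:.2e}")
```

With a constant gradient, Adagrad's step shrinks roughly as lr/√t because its accumulator never forgets, while Adam's moving averages keep the step near lr, and SGD's step stays proportional to the tiny raw gradient. Real gradients are not constant, so this only illustrates the trend.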

Furthermore, without type constraints both versions of DistMult perform worse than TransE trained with plain SGD, which in turn performs remarkably well (filtered MRR 0.495, versus 0.463 for DistMult with Adam and 0.237 with Adagrad). This is really strange: DistMult is often reported to perform better than TransE overall.

TRANSE WITH SGD OPTIMIZER

no type constraint results:

metric:           MRR        MR           hit@10     hit@3      hit@1
l(raw):           0.219245   269.490387   0.461394   0.252323   0.107413
r(raw):           0.274543   169.209732   0.543228   0.317076   0.148465
averaged(raw):    0.246894   219.350067   0.502311   0.284700   0.127939

l(filter):        0.469272   84.929123    0.727006   0.554993   0.328858
r(filter):        0.520368   54.371773    0.782364   0.613499   0.374617
averaged(filter): 0.494820   69.650452    0.754685   0.584246   0.351738

type constraint results:

metric:           MRR        MR           hit@10     hit@3      hit@1
l(raw):           0.236391   221.997665   0.476511   0.268795   0.125899
r(raw):           0.294595   136.958542   0.558954   0.334022   0.171878
averaged(raw):    0.265493   179.478104   0.517733   0.301408   0.148889

l(filter):        0.515220   37.436390    0.760018   0.596401   0.381896
r(filter):        0.558091   22.120584    0.807824   0.645799   0.419986
averaged(filter): 0.536655   29.778488    0.783921   0.621100   0.400941

triple classification accuracy is 0.871478
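When comparing the two models, it may help to keep their scoring functions side by side. The numpy sketch below uses the usual textbook forms (assumed here, not read from OpenKE's implementation): DistMult scores a triple with the trilinear product <h, r, t>, which is unchanged when head and tail are swapped, while TransE uses a negative L1 distance -|h + r - t| that generally changes. The random vectors are purely illustrative, not trained embeddings.

```python
import numpy as np


def distmult_score(h, r, t):
    # DistMult: trilinear product sum_i h_i * r_i * t_i (higher = more plausible).
    return float(np.sum(h * r * t))


def transe_score(h, r, t):
    # TransE with an L1 distance: -||h + r - t||_1 (higher = more plausible).
    return -float(np.sum(np.abs(h + r - t)))


rng = np.random.default_rng(42)
h, r, t = rng.normal(size=(3, 100))  # illustrative random vectors

# Swapping head and tail leaves the DistMult score unchanged (up to rounding)...
print(distmult_score(h, r, t), distmult_score(t, r, h))
# ...but generally changes the TransE score.
print(transe_score(h, r, t), transe_score(t, r, h))
```

Because DistMult cannot distinguish (h, r, t) from (t, r, h), it must score both directions of an asymmetric relation identically.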

Chrixtar commented 5 years ago

I am having the same issue, but with Adagrad as well as SGD. @AndRossi, could you tell me how you fixed this?

AndRossi commented 5 years ago

Hi @Chrixtar. Unfortunately I haven't fixed it :( After days of testing, I resigned myself to using Adam for DistMult and SGD for TransE. It's still odd, though.