nlpodyssey / cybertron

Cybertron: the home planet of the Transformers in Go
BSD 2-Clause "Simplified" License
291 stars 26 forks source link

cybertron's moka-ai/m3e-base model is inconsistent with Python #26

Open akaler727 opened 1 year ago

akaler727 commented 1 year ago

The results of inference with cybertron's moka-ai/m3e-base model are not consistent with Python. Am I missing something? Also, I changed the type of `model_max_length` from int to float32, because int reports an error when loading; could that be the cause? Go code:

import (
    "context"
    "fmt"

    "github.com/rs/zerolog"
    "github.com/rs/zerolog/log"

    . "github.com/nlpodyssey/cybertron/examples"
    "github.com/nlpodyssey/cybertron/pkg/models/bert"
    "github.com/nlpodyssey/cybertron/pkg/tasks"
    "github.com/nlpodyssey/cybertron/pkg/tasks/textencoding"
)

// main configures logging and environment, then runs the example.
// Any failure is logged and the process exits with a non-zero status.
func main() {
	zerolog.SetGlobalLevel(zerolog.DebugLevel)
	LoadDotenv()

	// Exit via log.Fatal only after run has returned, so the deferred
	// tasks.Finalize inside run has already executed (Fatal calls os.Exit,
	// which skips deferred functions).
	if err := run(); err != nil {
		log.Fatal().Err(err).Send()
	}
}

// run loads the text-encoding model named by CYBERTRON_MODEL from
// CYBERTRON_MODELS_DIR, encodes "hello" with mean pooling, and prints the
// resulting vector together with its size.
func run() error {
	modelsDir := HasEnvVar("CYBERTRON_MODELS_DIR")
	modelName := HasEnvVar("CYBERTRON_MODEL") // e.g. moka-ai/m3e-base

	m, err := tasks.Load[textencoding.Interface](&tasks.Config{ModelsDir: modelsDir, ModelName: modelName})
	if err != nil {
		return fmt.Errorf("loading model %q: %w", modelName, err)
	}
	defer tasks.Finalize(m)

	const text = "hello"
	// Encode takes the pooling strategy as a plain int, hence the conversion.
	poolingStrategy := int(bert.MeanPooling)
	result, err := m.Encode(context.Background(), text, poolingStrategy)
	if err != nil {
		return fmt.Errorf("encoding %q: %w", text, err)
	}
	fmt.Printf("%#v - %d\n\n", result.Vector.Data(), result.Vector.Size())
	return nil
}

Go results (a single mean-pooled embedding vector; its size, 768, is printed at the end):

Got a connection, launched process /private/var/folders/my/mmkg8gln76j857660h9q908h0000gn/T/___1go_build_github_com_nlpodyssey_cybertron (pid = 92969).
{"level":"debug","file":"moka-ai/moka-ai/m3e-base/config.json","time":"2023-07-27T14:22:35+08:00","message":"model file already exists, skipping download"}
{"level":"debug","file":"moka-ai/moka-ai/m3e-base/pytorch_model.bin","time":"2023-07-27T14:22:35+08:00","message":"model file already exists, skipping download"}
{"level":"debug","file":"moka-ai/moka-ai/m3e-base/vocab.txt","time":"2023-07-27T14:22:35+08:00","message":"model file already exists, skipping download"}
{"level":"debug","file":"moka-ai/moka-ai/m3e-base/tokenizer_config.json","time":"2023-07-27T14:22:35+08:00","message":"model file already exists, skipping download"}
{"level":"info","model":"moka-ai/moka-ai/m3e-base/spago_model.bin","time":"2023-07-27T14:22:35+08:00","message":"model file already exists, skipping conversion"}
float.floatSlice[float32]{0.5586561, 0.3481341, 1.1915668, -0.66031563, -0.14736839, -1.5355552, 0.025558796, 0.101684175, -1.2309126, 0.20468175, 0.9455595, -0.04586415, 0.6733322, -0.4213071, -0.9842657, -0.14411977, 0.2860089, 0.8369727, -0.74724257, -0.8295647, -0.24167839, 0.14310831, -0.95775265, 0.654835, 0.0402015, 0.030705217, -0.2343363, 0.352139, 0.50944626, 0.37679872, 0.9200878, -0.10877298, 0.11220171, -0.07710601, -0.11649196, -0.22657523, 0.7939314, 0.07352792, 0.25212747, 0.50869745, 0.78861177, 0.78612345, -0.87185585, -0.42625034, -0.0033961039, 0.2769445, 0.41774124, 0.15084733, -1.3549421, -0.059750002, 0.5812156, 5.364834, -0.4415004, 0.54449576, -0.060659435, 0.5178813, -0.5485616, -0.4674707, 0.40060604, 0.19186114, -0.20614424, -1.1355268, 0.34049234, -0.33354956, -0.23492238, -0.08065402, -0.047902226, -0.37340134, 1.2953057, -0.9417114, 0.4743478, -1.4645922, -0.91245496, 0.3786436, 0.90022177, -0.4570604, 0.72824585, 0.123488575, 0.26364398, -0.54934716, 0.20124838, 0.39526105, -0.013703133, 0.1886307, -1.7255692, 0.17734289, -0.91223955, -0.23831634, 0.30638784, 0.39905104, -2.0812528, -0.23258786, -0.2173264, -0.3542292, 0.34415314, -0.58268064, 0.014427334, -0.8137113, -0.39925456, 0.8942062, 0.31036633, -0.21444407, -0.5592039, -0.6024176, 0.49992222, -0.20572749, -0.65747404, 0.7906306, 0.56768584, -0.100532874, -0.022146303, 0.1409393, -0.01063375, 0.044230442, -0.62187946, -0.6604363, -0.67741287, -0.12345415, 1.3856716, -0.51854134, 0.017210325, 0.19882512, -0.4741126, -1.406298, -0.010283401, -0.12648019, 0.17839025, -0.14563477, -0.466341, 0.6114912, 0.0837489, -0.37595928, -1.5684165, 0.25083137, -0.23843801, 0.13842767, -0.29676613, -0.44700682, -0.9568887, -1.5731511, 1.689507, 0.5037775, -0.4791989, -0.5949365, -0.61601853, 0.75016373, -0.007852355, 0.18144856, 0.46363875, -0.0430381, 0.3369313, -0.14247802, 0.22483832, 0.5628735, 0.36320806, 0.5729853, -0.8830575, 0.3194078, -1.3680843, 0.49093843, -0.20522587, 
-0.34615856, 0.11486925, -0.53484887, 0.7429905, 0.77017415, 0.22035861, -1.2011147, 0.27883685, 0.92228675, -0.6460571, -0.7828549, -0.541044, -0.39700398, -0.5601654, -0.018595917, -0.60338205, 0.7471642, -1.0193474, -0.21543849, -1.0544162, -0.361531, 0.5082523, 0.4297683, -0.037086025, 0.8365776, -0.6101785, 0.37019834, 0.4232986, -0.22428536, -0.8988843, -0.24274461, -0.72213906, -0.3117569, 1.2218258, 0.3386291, -0.8692045, -1.3972731, -1.2372184, 0.46838042, 0.42251644, -0.28002506, -0.32562447, 0.8917746, -0.09792821, 0.2055389, -0.6134357, 0.6462733, 0.524881, 0.31465966, -0.64137524, -0.8678176, -0.100821756, 0.3128143, -0.2529764, 0.22984982, -1.794968, 0.47062716, -0.30664465, -0.041966647, 0.792232, -0.67575026, 0.0015863578, -0.57833123, -1.4809616, -0.30945492, -1.013249, -0.5161098, 0.577516, -0.37767792, 0.32753712, -0.37001777, -0.2641474, -0.8292316, 0.30337378, 0.68331385, -0.40814948, -0.034036644, -0.26826414, -0.12866855, -0.049134076, 0.9968583, 0.4438274, -0.83019894, -0.38212675, 0.7135316, -0.28314376, -0.8654863, -0.4090188, 0.06296025, 0.8331766, 0.06844811, 0.492837, -1.1496916, -0.3197983, -0.4718892, -0.5521597, -0.90919656, 0.120810226, -0.61302817, -0.14986253, -0.20587806, 0.96503794, 0.2814034, 0.033698563, -0.6710733, 0.12410319, -1.5172327, -0.720875, 0.36440298, -0.39186117, 1.2283376, 0.6683216, -0.325405, 0.65716743, -0.2667477, 0.8952669, -0.0074681044, -1.0335073, -0.8435897, -0.06165985, 0.8835367, 0.54728925, -0.20958409, -0.61436784, 0.23365095, -0.05989808, 0.21799684, -0.33517414, -0.8736791, -0.7647462, 0.22665468, 0.45739207, 0.29327714, 4.955877, -1.0347219, -0.42653248, 0.8362527, 0.7743188, 0.06949124, -0.19598548, -1.0427347, -0.19868009, 1.1681774, -0.034899607, 0.6261763, 0.4761127, 0.61922497, -0.12878291, -0.4452449, -0.26318374, -0.29606512, 0.41165692, 0.21679279, -1.1060889, 0.65256184, -0.67687625, -0.90906614, 0.08326803, -0.29034087, 0.4753198, 0.3085594, 0.24984238, 0.84592247, 0.26823223, 0.28253132, 
0.67301196, 0.85992277, -0.41336548, 1.054586, 1.258374, -1.1692991, 0.66828126, 1.269233, 0.29028067, 1.297507, 0.11060475, -0.93796605, -0.30756664, 0.37655783, -0.37691545, 0.17787269, 0.23547147, -1.4353197, -0.26889348, -0.4716735, -0.31950897, -0.6160807, -0.1677597, -0.45514965, -1.1609468, 0.06970423, -0.5989529, -0.59813994, -0.6641214, 0.5345185, 1.0468268, 0.39875543, -0.02086733, 0.34882402, -0.17870352, -0.19745669, -1.3137001, 0.974972, 0.45530897, -0.6094151, 0.5464872, 0.47533727, 0.36187488, -0.4182399, 0.83761406, -0.27439693, 1.0363626, -0.64448214, 1.4618754, 0.9909825, -1.5559127, 0.35349584, 0.521738, -0.21153657, 0.15345253, -0.8915105, -0.3603847, -0.16213252, -0.2258909, -0.28113198, 0.15495992, -1.4883742, -0.27888572, -1.0159744, -0.76716894, 0.436343, 0.68445456, -0.42764467, -0.5621178, -0.6797613, -0.010237455, 0.43840218, 0.7121403, 0.7438488, 0.9068295, -0.5387485, -0.38599, 0.48751765, -0.24927513, 0.031315055, -0.42303988, -0.5412056, 0.9670886, -0.1814614, -0.61184525, 0.090777576, -0.40066075, -1.7312568, 0.22919196, 0.051517244, -0.71642894, -0.5013463, 1.3154533, -1.2769346, -0.08370313, -0.123125, 1.0336876, -0.5957928, 0.23143372, 0.21134731, 0.47409233, 0.084024325, 0.7122943, 0.0050302846, 0.9451868, 0.49944115, 0.9674804, -0.42512417, 0.96013325, 0.49117666, -1.0473185, 0.23189923, -0.14844075, -1.0158184, -0.564623, -0.051982604, 0.085207924, -0.0627882, -0.5974218, -0.11681984, 0.28066874, -0.9524521, 0.4369557, -0.49451694, 0.6792588, 0.008014753, 0.1434726, -0.9833708, 0.7839698, -0.15273994, 0.5508911, 0.4452134, 0.089669466, 0.71982974, -0.6315946, -0.27465117, -0.09060815, 0.93898857, 0.14761311, -0.11981739, 0.6109779, -0.41099262, 0.95354813, 0.59368265, 0.0070482045, 0.98043317, -0.12696317, -0.106457226, -0.9131175, -0.26857468, 0.5230087, 0.5626597, -0.99615306, -1.2188501, 0.5033919, -0.27135533, 0.019292366, -1.0229635, -0.81190383, 0.5747945, 0.47489303, -0.25216419, 0.061917096, -0.20595881, -1.202519, 
0.3710681, 1.5564482, -0.95155925, -0.5505824, -0.44009846, 0.13443932, 0.13497871, -0.66424406, 1.2706015, 1.1227267, -0.30020985, -0.037559893, 0.7721657, 0.33907586, -0.08006205, -0.8495388, 0.10940178, -0.59910554, 0.8866389, 0.48454726, -0.17196727, 0.3195081, 1.4303505, -0.09417012, -0.7957394, -0.45309362, 0.69021237, 0.21278183, 0.21021526, 1.7833284, -0.008728772, -0.07566177, 1.298654, 0.108984545, 0.3140452, -0.10412737, -0.029870307, 0.64668167, 2.178915, -0.5301237, 0.16194561, -0.07004644, -0.22161862, 0.47245592, 0.26612693, -0.40323877, -1.0126446, -0.17367098, -1.2763882, 0.78977966, 0.3725124, 0.89175224, -0.2812779, 1.1134355, -0.71118057, 0.8265137, -0.4853515, 0.7031471, 0.084302574, -0.47851586, 0.5777946, -0.19740188, 0.92293465, 0.52769816, -0.34881315, -2.408277, -0.9068, 0.6417775, -0.43638715, -0.25479567, 0.06316389, -1.6539805, -0.41115344, -0.2593305, 0.27070582, -0.16882461, 0.018293774, -0.20345013, -0.5166233, 0.69859886, 0.7007345, 0.5515851, 0.68067616, -0.015562555, -0.11025004, -0.18250214, -0.6656741, -0.2616398, 0.8520726, -0.22963691, 1.2325003, 0.46701306, -1.0965149, -0.73943067, -0.22813259, 0.31647295, -0.7997521, 0.06252813, 0.47044277, -0.156255, 0.56394464, 0.020972114, 1.1977443, -0.048263144, 0.70653, 0.4685727, -0.053483915, -0.6645753, 0.9429509, 0.6438494, 0.5978333, 0.28899002, 0.6807536, 0.6045586, -1.0483279, 1.8320946, -0.08943628, -0.023991987, -0.6563122, -1.1724004, 0.8837205, 0.810367, -0.23900923, 0.02903029, 0.65648925, -0.89517593, -0.15514776, 0.30705136, 0.26269007, 0.13684818, -0.54990864, -0.59346026, 0.3405019, -0.5146506, -0.42077392, 2.3867738, 0.18798193, 0.31524557, -0.61146, -0.014019772, -0.9659583, 1.0523849, -0.18784826, -0.15733469, 1.4298962, -0.37591442, 0.015373915, -0.83349156, 1.6636237, 1.0818708, -0.49549565, -1.0145818, -0.54088354, 0.88212216, 0.23834348, 1.5952225, 0.3768664, 0.36205453, -0.10613571, 0.25902507, -0.057457786, 0.6561526, -0.47616667, 0.55362934, 0.62362164, 
0.9500669, 0.1519331, -0.6059568, 0.020420095, 0.38898036, 0.014923382, 1.8089727, -1.700867, 0.075224034, -0.92983294, 0.23721413, 0.070676416, 0.85289526, 0.81218004, 0.32324332, 0.5002023, 0.2116686, -0.73058957, -0.22014074, -0.76425236, 0.79867566, 0.9991926, -0.10438898, -0.53062564, -0.69680965, 0.56213224, 0.25225624, -1.2770052, -0.23407626, -1.0992198, -0.6047719, 0.19346945, -0.3439815, -0.32097661, -0.306589, 1.1400976, -0.16471976, -0.15639892, -0.69532585, 2.7921805, -0.59517515, 0.5892587, 0.47785053, -0.18247125, 1.1331238, 0.87331665, -0.1656356, -0.7234974, -0.16103837, -0.24972527, -0.97285163, -0.22834465, -0.06689389, -0.3923758, -0.0989133, -1.1746864, 1.0151973, -0.3554772, 0.016932443, -0.34616363, 0.5761374, 1.0964034, 0.963796, 0.3298906, 1.3143567, -0.14396495, -0.075510725, -0.031782664, 0.15960157, 1.2680503, 0.18764167, 0.86476505, 0.46823892, 0.2136173, -0.2094295, -0.16769594, -0.8687974, -1.305096, -0.05746324, 0.37226224, -0.29579797, -0.25378853, 0.48217577, 0.25654075, -0.40801167, 0.04489517, 1.046751, -0.726053, -1.4279008, -1.3432583, 0.06791468, -0.051261686, 0.3293378, 0.8105246, -0.3564893, -0.2491971, 0.68325007, -0.14297698, 0.26161194, -0.4075473, 0.39791656, -1.1087991, -0.31646204, -0.13931182, -0.039432134, 0.0050914837, 0.9031293, -1.4465828, -0.9414681, -1.2794018, -1.1000504, -0.7025533} - 768

Python code:

def test():
    """Print moka-ai/m3e-base embeddings of "hello" for comparison with cybertron.

    Prints two things:

    * the per-token hidden states (``last_hidden_state[0]``, one row per
      input token), as before;
    * an attention-masked mean over those token vectors. This single vector
      is what should be compared with the Go example, which encodes with
      ``bert.MeanPooling`` and prints one 768-element vector.

    NOTE(review): assumes cybertron's MeanPooling averages every attended
    token, special tokens included. TODO: confirm against cybertron's BERT
    pooling code.

    Relies on ``AutoTokenizer``, ``AutoModel`` and ``torch`` already being
    imported at module level, as in the original snippet.
    """
    tokenizer = AutoTokenizer.from_pretrained("moka-ai/m3e-base")
    model = AutoModel.from_pretrained("moka-ai/m3e-base")

    text = "hello"

    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # The model call must be inside the no_grad block. The original had it
    # dedented, which is an IndentationError.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # last_hidden_state is (batch, seq_len, hidden); [0] selects the only
    # sequence, leaving one hidden vector per token.
    token_embeddings = outputs.last_hidden_state[0]
    res = token_embeddings.numpy()
    print(res)

    # Masked mean over the token axis gives one sentence vector, matching the
    # shape of the Go output. The clamp guards against an all-zero mask.
    mask = attention_mask[0].unsqueeze(-1).to(token_embeddings.dtype)
    pooled = (token_embeddings * mask).sum(dim=0) / mask.sum(dim=0).clamp(min=1e-9)
    print(pooled.numpy())
    print('end')

Python results (per-token hidden states from `last_hidden_state[0]`, one row per token, not pooled):

[[ 0.55908877  0.37521887  1.0591375  ... -1.2979769  -1.0342392
  -0.7252413 ]
 [ 0.5580415   0.29322627  1.4551864  ... -1.2450337  -1.2312784
  -0.6592587 ]
 [ 0.55908877  0.37521887  1.0591375  ... -1.297977   -1.0342392
  -0.7252413 ]]
matteo-grella commented 1 year ago

Thanks, I’m checking. Back to you asap.