Status: closed by Dylan-JD about 9 months ago.
from mindnlp.models import MSBertModel
import mindspore
from mindnlp.transformers import BertTokenizer
from mindspore import Tensor
import json
import numpy as np
from gaussdb_vector.gauss_db_vector import GaussDBVector
from gaussdb_vector.gauss_table import FieldSchema, IndexSchema, TableSchema
from gaussdb_vector.gauss_types import DataType, MetricType, IndexType, IndexParamType, IndexParams
import sys
# NOTE(review): this path is appended *after* the gaussdb_vector imports above,
# so it cannot influence how those imports are resolved. It also points inside
# the package directory rather than at its parent, which is presumably what
# `import gaussdb_vector.xxx` would need — confirm whether it is still required.
sys.path.append('/home/workspace/dongjian/gauss_test/code/gauss_python_program/gaussdb_vector')

# Local checkpoint directory shared by the tokenizer and the model.
MODEL_DIR = '/home/workspace/dongjian/gauss_test/code/m3e_model'

tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
mindspore.set_context(mode=mindspore.GRAPH_MODE)
model = MSBertModel.from_pretrained(MODEL_DIR)
# Keep the model in the dtype its weights were loaded in (float32).
# `model.to_float(mindspore.float16)` casts the cell *inputs* to float16 while
# the Dense weights stay float32, which produced the reported failure:
#   TypeError: For primitive[Dense], the input type must be same.
#   name:[w]:Ref[Tensor[Float32]]. name:[x]:Tensor[Float16].
# If float16 inference is wanted, the parameters must be cast as well, not
# only the inputs.
def predict(text, label=None, max_length=512):
    """Return a mean-pooled embedding vector for *text*.

    Args:
        text: Input string to embed.
        label: Unused; kept so existing callers that pass it keep working.
        max_length: Maximum number of token ids (as returned by
            ``tokenizer.encode``) accepted. Longer inputs are rejected.

    Returns:
        The embedding as a list of floats, or ``None`` when the encoded text
        has more than ``max_length`` token ids.
    """
    # Tokenize once and reuse the ids for both the length check and the input.
    token_ids = tokenizer.encode(text)
    # Only strictly longer inputs are rejected. BERT-style checkpoints
    # typically have 512 position embeddings, so exactly 512 ids still fit.
    # TODO confirm max_position_embeddings in the m3e model config.
    if len(token_ids) > max_length:
        return None
    # Add a leading batch axis so the model sees a (1, seq_len) batch. This is
    # what the outputs[0][0] indexing below relies on.
    input_ids = Tensor(np.array([token_ids]))
    outputs = model(input_ids)
    # Assumes outputs[0] is the last hidden state shaped (batch, seq_len, hidden).
    # TODO confirm against MSBertModel. Taking batch item 0 and averaging over
    # the token axis gives one hidden-size vector.
    token_states = outputs[0][0].asnumpy()
    return np.mean(token_states, axis=0).tolist()
# Smoke test: embed a short Chinese greeting ("hello") and print the result.
sample_embedding = predict('你好')
print(sample_embedding)
If this is your first time, please read our contributor guidelines: https://github.com/mindspore-lab/mindcv/blob/main/CONTRIBUTING.md
Describe the bug / 问题描述 (Mandatory / 必填): running inference (执行推断时报错) fails with: `TypeError: For primitive[Dense], the input type must be same. name:[w]:Ref[Tensor[Float32]]. name:[x]:Tensor[Float16].`
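- Hardware Environment (Ascend/GPU/CPU) / 硬件环境: Ascend 910B3
- Execute Mode (PyNative/Graph) / 执行模式: Graph

To Reproduce / 重现步骤 (Mandatory / 必填)
Steps to reproduce the behavior (使用bert类型的model): with a BERT-type model, first encode the text with the tokenizer, then convert the token ids to a `Tensor`, and finally call `model(Tensor)` to obtain the outputs (logits).

Expected behavior / 预期结果 (Mandatory / 必填)
An embedding vector is returned (输出一个embedding序列).

Screenshots / 日志 / 截图 (Mandatory / 必填)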
Additional context / 备注 (Optional / 选填) Add any other context about the problem here.