Open NacolZero opened 1 year ago
@Slf4j
@Service
@Component
public class StoreService {
@Autowired
private StoreDao storeDao;
@Autowired
private StoreServiceImpl storeService;
public Store create(Store store) {
Assert.notNull(store, "store can't be null");
store.setName(StrUtils.emptyToNull(store.getName()));
store.setPrice(store.getPrice() == null ? 0.0 : store.getPrice());
if (store.getPrice() < 100) {
throw new BizException("商品价格必须大于100");
}
if ("001".equals(store.getType())) {
if (store.getPrice() < 200) {
throw new BizException("商品类型为 '001',价格必须大于200");
}
}
if (storeDao.selectByName(store.getName()) != null) {
throw new BizException("商品名称不能重复");
}
storeDao.insert(store);
return storeService.selectById(store.getId());
}
public Store update(Store store) {
Assert.notNull(store, "store can't be null");
store.setName(StrUtils.emptyToNull(store.getName()));
store.setPrice(store.getPrice() == null ? 0.0 : store.getPrice());
if (store.getPrice() < 100) {
throw new BizException("商品价格必须大于100");
}
if ("001".equals(store.getType())) {
if (store.getPrice() < 200) {
throw new BizException("商品类型为 '001',价格必须大于200");
}
}
if (storeDao.selectByName(store.getName()) != null) {
throw new BizException("商品名称不能重复");
}
storeDao.updateById(store);
return storeService.selectById(store.getId());
}
public Store delete(String id) {
Store store = storeDao.selectById(id);
storeDao.deleteById(id);
return store;
}
public Store selectById(String id) {
Store store = storeDao.selectById(id);
return store;
}
public Store selectById(Store store) {
Store store1 = storeDao.selectById(store.getId());
return store1;
}
public List<Store> selectAll() {
return storeDao.selectAll();
}
}
我用的int4量化的版本,用chatglm.cpp 跑的,在macbook pro m2乞丐版上
你要不试试在你的prompt最后再加一个\n ?我发现只有一个\n结尾的时候经常出现重复生成同样代码的问题
Environment
- OS: macos Ventura 13.2.1 - Python: 3.11 - Transformers: 4.30.2 - PyTorch: 2.0.1 - CUDA Support: False
Current Behavior
使用 Mac M2Max 进行推理异常:
- 内存最高吃到94G;
- 要求 Java 语言,推理结果 Python;
- 简单prompt推理时长超长(几十秒到3分钟);
- 复杂prompt经常不会出结果(10分钟);
- 错误的、重复的推理结果。
代码 demo
model_path = "/xxxxxxxxx" model_id = 'ZhipuAI/codegeex2-6b' tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).half().to("mps") model = model.eval() # remember adding a language tag for better performance prompt = "// language: java\n// 使用Mybatis-plus 的分页查询用户\n" # prompt = "language: Python\n# write a bubble sort function\n" # prompt = "language: Java\n# write a bubble sort function\n" # prompt = "language: Java\n# 使用Mybatis-plus 的分页查询用户\n" #prompt = "# language: Java\n# 使用Mybatis-plus 写一个关于【商城Service】的业务代码,商城的 Service 命名为 StoreService.工具:1. 字符串处理使用hutool的StrUtils; 2. 抛异常使用hutool的Assert; 3. 业务异常使用 BizException; 实体类有字段 :String id;String name;Double price;String type;业务-【新增商品】,业务规则:1. 必填名称;2. 价格必须大于100;3. 如果商品类型为 '001',价格必须大于200;4. 商品名称不能重复。\n" # prompt = "# language:Java\n# 写一个冒泡排序函数" inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device) outputs = model.generate(inputs, max_length=888) response = tokenizer.decode(outputs[0]) print(response)
具体案例如下:
Case 1:官方 demo 的 prompt = “# language: Python\n# write a bubble sort function\n”
可推理出结果,时间大约10秒。
Case 2:官方 demo 改为 java prompt = “# language: Java\n# write a bubble sort function\n”
写出了Python, 时间大约10秒。
language: Java # write a bubble sort function def bubble_sort(arr): for i in range(len(arr) - 1): for j in range(len(arr) - 1): if arr[j] > arr[j + 1]: arr[j], arr[j + 1] = arr[j + 1], arr[j] return arr print(bubble_sort([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
Case 3:prompt = “# language:Java\n# 写一个冒泡排序函数”
推理出结果了,但是要求是 Java,写成了 Python,而且中文有大量冗余。
# language:Java # 写一个冒泡排序函数 # 冒泡排序: # 1.比较相邻的元素。如果第一个比第二个大,就交换他们两个。 # 2.对每一对相邻元素作同样的工作,从开始第一对到结尾的最后一对。这步做完后,最后的元素会是最大的数。 # 3.针对所有的元素重复以上的步骤,除了最后一个。 # 4.持续每次对越来越少的元素重复上面的步骤,直到没有任何一对数字需要比较。 # 冒泡排序的原理: # 1.比较相邻的元素。如果第一个比第二个大,就交换他们两个。 # 2.对每一对相邻元素作同样的工作,从开始第一对到结尾的最后一对。这步做完后,最后的元素会是最大的数。 # 3.针对所有的元素重复以上的步骤,除了最后一个。 # 4.持续每次对越来越少的元素重复上面的步骤,直到没有任何一对数字需要比较。 # 冒泡排序的代码实现: def bubble_sort(alist): for i in range(len(alist) - 1, 0, -1): for j in range(i): if alist[j] > alist[j + 1]: alist[j], alist[j + 1] = alist[j + 1], alist[j] return alist alist = [54, 26, 93, 17, 77, 31, 44, 55, 20] print(bubble_sort(alist))
Case4:prompt = "# language:Java\n# 冒泡排序“
内存吃到90G,3分钟不出结果
Case5:prompt = "使用Mybatis-plus 的分页查询用户"
- idea 插件:表现正常
- 本地执行:内存吃到90G, 70%卡死,30%可出结果并且结果正常。
Case6:带有大量上下文,有业务场景的长 prompt
prompt = "# language: Java\n# 使用Mybatis-plus 写一个关于【商城Service】的业务代码,商城的 Service 命名为 StoreService.工具:1. 字符串处理使用hutool的StrUtils; 2. 抛异常使用hutool的Assert; 3. 业务异常使用 BizException; 实体类有字段 :String id;String name;Double price;String type;业务-【新增商品】,业务规则:1. 必填名称;2. 价格必须大于100;3. 如果商品类型为 '001',价格必须大于200;4. 商品名称不能重复。\n"
- idea 的插件:表现良好,
- 网友(cuda 3090):基本秒出结果,执行结果正常。
- 本地执行时:内存吃到94G,最长执行了10分钟不出结果,只出过一次结果,显示如下:
Loading checkpoint shards: 100%|██████████| 7/7 [00:05<00:00, 1.35it/s] The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results. Setting `pad_token_id` to `eos_token_id`:2 for open-end generation. /Users/nacol/Projects/llm/CodeGeeX2/venv/lib/python3.11/site-packages/transformers/generation/utils.py:2419: UserWarning: MPS: no support for int64 min/max ops, casting it to int32 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/operations/ReduceOps.mm:1271.) if unfinished_sequences.max() == 0: # language: Java # 使用Mybatis-plus 写一个关于【商城Service】的业务代码,商城的 Service 命名为 StoreService.工具:1. 字符串处理使用hutool的StrUtils; 2. 抛异常使用hutool的Assert; 3. 业务异常使用 BizException; 实体类有字段 :String id;String name;Double price;String type;业务-【新增商品】,业务规则:1. 必填名称;2. 价格必须大于100;3. 如果商品类型为 '001',价格必须大于200;4. 商品名称不能重复。 import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONObject; import cn.hutool.json.JSONUtil; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.log.Log; import cn.hutool.log.LogFactory; import cn.hutool.
Prompt写得有问题,开源的 CodeGeeX2-6B 是一个基座代码模型,它的使用方式是偏补全的,要按照某种语言一般的编程习惯使用就可以了。比如prompt需要使用相应语言的注释符号,Python用"# [prompt]",Java则应该用“// [prompt]”。也可以加一些关键字来引导模型生成函数或类,比如Java用“// [prompt]\npublic class”
Prompt写得有问题,开源的 CodeGeeX2-6B 是一个基座代码模型,它的使用方式是偏补全的,要按照某种语言一般的编程习惯使用就可以了。比如prompt需要使用相应语言的注释符号,Python用"# [prompt]",Java则应该用“// [prompt]”。也可以加一些关键字来引导模型生成函数或类,比如Java用“// [prompt]\npublic class”
感谢你的指出, 我修正了 prompt 如下: prompt = "// language: Java\n// 使用Mybatis-plus 的分页查询用户\n\npublic class" 依然有以下问题:
- 推理速度极慢(1~3分钟);
- 内存耗用极高 40G;
- 输出结果有时候会异常,将我的prompt输出;
- 输出结果有时候正确。
你要不试试在你的prompt最后再加一个\n ?我发现只有一个\n结尾的时候经常出现重复生成同样代码的问题
加了,但是依然没解决生成速度、吃内存等问题。 应该是其它问题。
Prompt写得有问题,开源的 CodeGeeX2-6B 是一个基座代码模型,它的使用方式是偏补全的,要按照某种语言一般的编程习惯使用就可以了。比如prompt需要使用相应语言的注释符号,Python用"# [prompt]",Java则应该用“// [prompt]”。也可以加一些关键字来引导模型生成函数或类,比如Java用“// [prompt]\npublic class”
感谢你的指出, 我修正了 prompt 如下: prompt = "// language: Java\n// 使用Mybatis-plus 的分页查询用户\n\npublic class" 依然有以下问题:
- 推理速度极慢(1~3分钟);
- 内存耗用极高 40G;
- 输出结果有时候会异常,将我的prompt输出;
- 输出结果有时候正确。
Environment
Current Behavior
使用 Mac M2Max 进行推理异常:
代码 demo
具体案例如下:
Case 1:官方 demo 的 prompt = “# language: Python\n# write a bubble sort function\n”
可推理出结果,时间大约10秒。
Case 2:官方 demo 改为 java prompt = “# language: Java\n# write a bubble sort function\n”
写出了Python, 时间大约10秒。
Case 3:prompt = “# language:Java\n# 写一个冒泡排序函数”
推理出结果了,但是要求是 Java,写成了 Python,而且中文有大量冗余。
Case4:prompt = "# language:Java\n# 冒泡排序“
内存吃到90G,3分钟不出结果
Case5:prompt = "使用Mybatis-plus 的分页查询用户"
Case6:带有大量上下文,有业务场景的长 prompt
prompt = "# language: Java\n# 使用Mybatis-plus 写一个关于【商城Service】的业务代码,商城的 Service 命名为 StoreService.工具:1. 字符串处理使用hutool的StrUtils; 2. 抛异常使用hutool的Assert; 3. 业务异常使用 BizException; 实体类有字段 :String id;String name;Double price;String type;业务-【新增商品】,业务规则:1. 必填名称;2. 价格必须大于100;3. 如果商品类型为 '001',价格必须大于200;4. 商品名称不能重复。\n"
后续
2023.08.03 16:00
按照下面两位朋友的建议
已将prompt修改如下:
解决了:要求 java 输出 python 的问题;
未解决:推理出慢、推理卡死、高内存、错误的推理结果等问题
该prompt跑了5次,每次3分钟以上,内存持续上涨到40G手动停止,未的出推理结果。