- 将 ERNIE-3.5-8K 模型替换为 ERNIE-Speed-128K 模型 - 修改 getGPTVersion 方法以直接返回当前模型名称 - 更新 QQBot 版本号至 1.7.10
229 lines
9.8 KiB
Java
229 lines
9.8 KiB
Java
package com.yutou.qqbot.gpt;
|
||
|
||
import com.baidubce.qianfan.Qianfan;
|
||
import com.baidubce.qianfan.model.chat.ChatResponse;
|
||
import com.baidubce.qianfan.model.image.Image2TextResponse;
|
||
import com.baidubce.qianfan.model.image.Text2ImageResponse;
|
||
import com.yutou.qqbot.data.baidu.Message;
|
||
import com.yutou.qqbot.utlis.ConfigTools;
|
||
import com.yutou.qqbot.utlis.Log;
|
||
import com.yutou.qqbot.utlis.StringUtils;
|
||
import lombok.val;
|
||
|
||
import java.io.ByteArrayInputStream;
|
||
import java.io.File;
|
||
import java.nio.file.Files;
|
||
import java.nio.file.StandardCopyOption;
|
||
import java.util.ArrayList;
|
||
import java.util.Base64;
|
||
import java.util.Collections;
|
||
import java.util.List;
|
||
import java.util.concurrent.ConcurrentHashMap;
|
||
import java.util.concurrent.atomic.AtomicBoolean;
|
||
import java.util.concurrent.atomic.AtomicInteger;
|
||
|
||
public class BaiduGPTManager extends AbsGPTManager {
|
||
|
||
private static final String AppID = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_APPID, String.class);
|
||
private static final String ApiKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_API_KEY, String.class);
|
||
//ConfigTools.load操作可以确保获取到相关参数,所以无需关心
|
||
private static final String AccessKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_ACCESS_KEY, String.class);
|
||
private static final String SecretKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_SECRET_KEY, String.class);
|
||
private final ConcurrentHashMap<String, List<Message>> msgMap;
|
||
private final static String modelFor40 = "ERNIE-4.0-8K";
|
||
private final static String modelFor35 = "ERNIE-Speed-128K";
|
||
private String model = modelFor35;
|
||
// 新增锁映射表
|
||
private final ConcurrentHashMap<String, AtomicBoolean> userLocks = new ConcurrentHashMap<>();
|
||
private final Qianfan qianfan;
|
||
|
||
private BaiduGPTManager() {
|
||
msgMap = new ConcurrentHashMap<>();
|
||
qianfan = new Qianfan(AccessKey, SecretKey);
|
||
String savedVersion = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, String.class);
|
||
if (StringUtils.isEmpty(savedVersion) || (!"3.5".equals(savedVersion) && !"4.0".equals(savedVersion))) {
|
||
savedVersion = "3.5";
|
||
ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, savedVersion);
|
||
}
|
||
model = "3.5".equals(savedVersion) ? modelFor35 : modelFor40;
|
||
}
|
||
|
||
private static volatile BaiduGPTManager manager;
|
||
|
||
public static BaiduGPTManager getManager() {
|
||
if (manager == null) {
|
||
synchronized (BaiduGPTManager.class) {
|
||
if (manager == null) {
|
||
manager = new BaiduGPTManager();
|
||
}
|
||
}
|
||
}
|
||
return manager;
|
||
}
|
||
|
||
@Override
|
||
public int setMaxMessageCount(int count) {
|
||
MAX_MESSAGE.set(count);
|
||
return count;
|
||
}
|
||
|
||
public synchronized void setModelFor40() {
|
||
model = modelFor40;
|
||
ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, "4.0");
|
||
}
|
||
|
||
public synchronized void setModelFor35() {
|
||
model = modelFor35;
|
||
ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, "3.5");
|
||
}
|
||
|
||
/**
|
||
* 这里确实是需要清空所有数据
|
||
*/
|
||
@Override
|
||
public synchronized void clear() { // 添加同步
|
||
msgMap.clear();
|
||
for (AtomicBoolean value : userLocks.values()) {
|
||
value.set(false);
|
||
}
|
||
userLocks.forEachValue(1, atomicBoolean -> atomicBoolean.set(false));
|
||
userLocks.clear();
|
||
}
|
||
|
||
|
||
// 这个是官方的示例代码,表示连续对话
|
||
private static void exampleChat() {
|
||
Qianfan qianfan = new Qianfan();
|
||
ChatResponse response = qianfan.chatCompletion()
|
||
// 设置需要使用的模型,与endpoint同时只能设置一种
|
||
.model("ERNIE-Bot")
|
||
// 通过传入历史对话记录来实现多轮对话
|
||
.addMessage("user", "你好!你叫什么名字?")
|
||
.addMessage("assistant", "你好!我是文心一言,英文名是ERNIE Bot。")
|
||
// 传入本轮对话的用户输入
|
||
.addMessage("user", "刚刚我的问题是什么?")
|
||
.execute();
|
||
System.out.println("输出内容:" + response.getResult());
|
||
}
|
||
|
||
@Override
|
||
public Message sendMessage(String user, String message) {
|
||
// 获取或创建用户锁
|
||
AtomicBoolean lock = userLocks.computeIfAbsent(user, k -> new AtomicBoolean(false));
|
||
// 尝试加锁(如果已被锁定则立即返回提示)
|
||
if (!lock.compareAndSet(false, true)) {
|
||
return Message.create("您有请求正在处理中,请稍后再试", true);
|
||
}
|
||
try {
|
||
List<Message> list = msgMap.computeIfAbsent(user, k -> Collections.synchronizedList(new ArrayList<>()));
|
||
// 限制历史消息的最大数量
|
||
synchronized (list) {
|
||
if (list.size() >= MAX_MESSAGE.get()) {
|
||
int removeCount = list.size() - MAX_MESSAGE.get() + 1; // 腾出空间给新消息
|
||
list.subList(0, removeCount).clear();
|
||
}
|
||
list.add(Message.create(message));
|
||
}
|
||
val builder = qianfan.chatCompletion()
|
||
.model(model);
|
||
for (Message msg : list) {
|
||
builder.addMessage(msg.getRole(), msg.getContent());
|
||
}
|
||
ChatResponse chatResponse = builder.execute();
|
||
Message response = Message.create(chatResponse.getResult(), true);
|
||
synchronized (list) {
|
||
list.add(response);
|
||
if (list.size() > MAX_MESSAGE.get()) {
|
||
int overflow = list.size() - MAX_MESSAGE.get();
|
||
list.subList(0, overflow).clear();
|
||
}
|
||
}
|
||
// msgMap.put(user, list);
|
||
return response;
|
||
} catch (Exception e) {
|
||
Log.e(e, message);
|
||
return Message.create("请求失败,请重试", true);
|
||
} finally {
|
||
lock.set(false);
|
||
userLocks.remove(user, lock);
|
||
}
|
||
|
||
}
|
||
|
||
/**
|
||
* 将文本转换为图像文件
|
||
* 该方法使用预训练的AI模型将给定的文本转换为图像,并将其保存为文件
|
||
*
|
||
* @param user 用户标识符,用于为生成的图像文件命名
|
||
* @param text 要转换为图像的文本
|
||
* @return 返回生成的图像文件对象,如果转换过程中发生错误,则返回null
|
||
*/
|
||
@Override
|
||
public File textToImage(String user, String text) {
|
||
// 使用QianFan的text2Image方法将文本转换为图像数据
|
||
Text2ImageResponse response = qianfan.text2Image()
|
||
.prompt(text)
|
||
.execute();
|
||
// 获取转换后的图像数据,以Base64编码的图像字符串形式
|
||
val b64Image = response.getData().get(0).getB64Image();
|
||
// 将Base64编码的图像数据转换为图像文件
|
||
// 创建一个临时目录下的图像文件,文件名包含用户标识符和当前时间戳,以确保唯一性
|
||
val imageFile = new File("tmp" + File.separator + user + "_" + System.currentTimeMillis() + ".png");
|
||
try (val inputStream = new ByteArrayInputStream(Base64.getDecoder().decode(b64Image))) {
|
||
// 将解码后的图像数据复制到图像文件中,替换现有文件
|
||
Files.copy(inputStream, imageFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
|
||
return imageFile;
|
||
} catch (Exception e) {
|
||
// 如果在图像文件生成过程中发生错误,记录错误信息
|
||
Log.e(e);
|
||
}
|
||
// 如果发生错误,返回null
|
||
return null;
|
||
}
|
||
|
||
/**
|
||
* 将图片转换为文本描述
|
||
*
|
||
* @param user 使用该功能的用户标识
|
||
* @param file 要转换的图片文件
|
||
* @return 转换后的文本描述,如果转换失败则返回null
|
||
*/
|
||
@Override
|
||
public String imageToText(String user, File file) {
|
||
// 将file文件转换成base64的代码
|
||
try {
|
||
// 读取文件内容并转换为Base64编码
|
||
val base64 = Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
|
||
|
||
// 调用图像转文本的API
|
||
Image2TextResponse response = qianfan.image2Text()
|
||
.image(base64)
|
||
.prompt("请描述这张图片中的主要内容和细节,以及它们之间的关系\n")
|
||
.execute();
|
||
String translationPrompt = "将以下英文内容严格翻译为简体中文,不要解释、不要添加额外内容,保留专业术语和名称(如Star Wars保持英文):\n" + response.getResult();
|
||
// 获取API返回的结果
|
||
return sendMessage("bot",translationPrompt).getContent();
|
||
} catch (Exception e) {
|
||
// 异常处理:记录错误日志
|
||
Log.e(e);
|
||
}
|
||
// 如果发生异常,返回null
|
||
return null;
|
||
}
|
||
|
||
@Override
|
||
public String getGPTVersion() {
|
||
return model;
|
||
}
|
||
|
||
public static void main(String[] args) throws Exception {
|
||
// BaiduGPTManager.getManager().textToImage("user","画一个猫娘,用二次元动画画风,她是粉色头发,坐在地上");
|
||
// BaiduGPTManager.getManager().imageToText("user",new File("test.png"));
|
||
// Message message = BaiduGPTManager.getManager().sendMessage("user", "现在假设小猪等于1,小猴等于2");
|
||
// System.out.println(message.getContent());
|
||
// message = BaiduGPTManager.getManager().sendMessage("user", "那么小猪加上小猴等于多少?");
|
||
// System.out.println(message.getContent());
|
||
System.out.println(BaiduGPTManager.getManager().sendMessage("user", "分析这个网页链接的页面内容,而非链接本身:https://www.bilibili.com/video/BV1TTf5YrESz/").getContent());
|
||
}
|
||
}
|