refactor(gpt): 重构 GPT 相关代码并优化功能

- 新增 AbsGPTManager 抽象类,定义 GPT 管理器的通用接口
- 重命名 BaiduGPTManager 类,使其位于 com.yutou.qqbot.gpt 包中
- 更新相关引用和依赖
- 优化部分代码结构,提高可维护性
This commit is contained in:
2025-02-04 18:15:15 +08:00
parent 09305ae824
commit e7fae929a1
4 changed files with 25 additions and 10 deletions

View File

@@ -3,10 +3,9 @@ package com.yutou.qqbot.utlis;
import com.yutou.napcat.QQDatabase;
import com.yutou.napcat.model.GroupBean;
import com.yutou.qqbot.Annotations.UseModel;
import com.yutou.qqbot.QQBotManager;
import com.yutou.qqbot.QQNumberManager;
import com.yutou.qqbot.gpt.BaiduGPTManager;
import com.yutou.qqbot.models.Model;
import lombok.val;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;

View File

@@ -1,219 +0,0 @@
package com.yutou.qqbot.utlis;
import com.baidubce.qianfan.Qianfan;
import com.baidubce.qianfan.model.chat.ChatResponse;
import com.baidubce.qianfan.model.image.Image2TextResponse;
import com.baidubce.qianfan.model.image.Text2ImageResponse;
import com.yutou.qqbot.data.baidu.Message;
import lombok.val;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
public class BaiduGPTManager {
private static final AtomicInteger MAX_MESSAGE = new AtomicInteger(20);
private static final String AppID = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_APPID, String.class);
private static final String ApiKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_API_KEY, String.class);
//ConfigTools.load操作可以确保获取到相关参数所以无需关心
private static final String AccessKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_ACCESS_KEY, String.class);
private static final String SecretKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_SECRET_KEY, String.class);
private final ConcurrentHashMap<String, List<Message>> msgMap;
private final static String modelFor40 = "ERNIE-4.0-8K";
private final static String modelFor35 = "ERNIE-3.5-8K";
private String model = modelFor35;
// 新增锁映射表
private final ConcurrentHashMap<String, AtomicBoolean> userLocks = new ConcurrentHashMap<>();
private final Qianfan qianfan;
private BaiduGPTManager() {
msgMap = new ConcurrentHashMap<>();
qianfan = new Qianfan(AccessKey, SecretKey);
String savedVersion = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, String.class);
if (StringUtils.isEmpty(savedVersion) || (!"3.5".equals(savedVersion) && !"4.0".equals(savedVersion))) {
savedVersion = "3.5";
ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, savedVersion);
}
model = "3.5".equals(savedVersion) ? modelFor35 : modelFor40;
}
private static volatile BaiduGPTManager manager;
public static BaiduGPTManager getManager() {
if (manager == null) {
synchronized (BaiduGPTManager.class) {
if (manager == null) {
manager = new BaiduGPTManager();
}
}
}
return manager;
}
public int setMaxMessageCount(int count) {
MAX_MESSAGE.set(count);
return count;
}
public synchronized void setModelFor40() {
model = modelFor40;
ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, "4.0");
}
public synchronized void setModelFor35() {
model = modelFor35;
ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, "3.5");
}
/**
* 这里确实是需要清空所有数据
*/
public synchronized void clear() { // 添加同步
msgMap.clear();
for (AtomicBoolean value : userLocks.values()) {
value.set(false);
}
userLocks.forEachValue(1, atomicBoolean -> atomicBoolean.set(false));
userLocks.clear();
}
// 这个是官方的示例代码,表示连续对话
private static void exampleChat() {
Qianfan qianfan = new Qianfan();
ChatResponse response = qianfan.chatCompletion()
// 设置需要使用的模型与endpoint同时只能设置一种
.model("ERNIE-Bot")
// 通过传入历史对话记录来实现多轮对话
.addMessage("user", "你好!你叫什么名字?")
.addMessage("assistant", "你好我是文心一言英文名是ERNIE Bot。")
// 传入本轮对话的用户输入
.addMessage("user", "刚刚我的问题是什么?")
.execute();
System.out.println("输出内容:" + response.getResult());
}
public Message sendMessage(String user, String message) {
// 获取或创建用户锁
AtomicBoolean lock = userLocks.computeIfAbsent(user, k -> new AtomicBoolean(false));
// 尝试加锁(如果已被锁定则立即返回提示)
if (!lock.compareAndSet(false, true)) {
return Message.create("您有请求正在处理中,请稍后再试", true);
}
try {
List<Message> list = msgMap.computeIfAbsent(user, k -> Collections.synchronizedList(new ArrayList<>()));
// 限制历史消息的最大数量
synchronized (list) {
if (list.size() >= MAX_MESSAGE.get()) {
int removeCount = list.size() - MAX_MESSAGE.get() + 1; // 腾出空间给新消息
list.subList(0, removeCount).clear();
}
list.add(Message.create(message));
}
val builder = qianfan.chatCompletion()
.model(model);
for (Message msg : list) {
builder.addMessage(msg.getRole(), msg.getContent());
}
ChatResponse chatResponse = builder.execute();
Message response = Message.create(chatResponse.getResult(), true);
synchronized (list) {
list.add(response);
if (list.size() > MAX_MESSAGE.get()) {
int overflow = list.size() - MAX_MESSAGE.get();
list.subList(0, overflow).clear();
}
}
// msgMap.put(user, list);
return response;
} catch (Exception e) {
Log.e(e, message);
return Message.create("请求失败,请重试", true);
} finally {
lock.set(false);
userLocks.remove(user, lock);
}
}
/**
* 将文本转换为图像文件
* 该方法使用预训练的AI模型将给定的文本转换为图像并将其保存为文件
*
* @param user 用户标识符,用于为生成的图像文件命名
* @param text 要转换为图像的文本
* @return 返回生成的图像文件对象如果转换过程中发生错误则返回null
*/
public File textToImage(String user, String text) {
// 使用QianFan的text2Image方法将文本转换为图像数据
Text2ImageResponse response = qianfan.text2Image()
.prompt(text)
.execute();
// 获取转换后的图像数据以Base64编码的图像字符串形式
val b64Image = response.getData().get(0).getB64Image();
// 将Base64编码的图像数据转换为图像文件
// 创建一个临时目录下的图像文件,文件名包含用户标识符和当前时间戳,以确保唯一性
val imageFile = new File("tmp" + File.separator + user + "_" + System.currentTimeMillis() + ".png");
try (val inputStream = new ByteArrayInputStream(Base64.getDecoder().decode(b64Image))) {
// 将解码后的图像数据复制到图像文件中,替换现有文件
Files.copy(inputStream, imageFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
return imageFile;
} catch (Exception e) {
// 如果在图像文件生成过程中发生错误,记录错误信息
Log.e(e);
}
// 如果发生错误返回null
return null;
}
/**
* 将图片转换为文本描述
*
* @param user 使用该功能的用户标识
* @param file 要转换的图片文件
* @return 转换后的文本描述如果转换失败则返回null
*/
public String imageToText(String user, File file) {
// 将file文件转换成base64的代码
try {
// 读取文件内容并转换为Base64编码
val base64 = Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
// 调用图像转文本的API
Image2TextResponse response = qianfan.image2Text()
.image(base64)
.prompt("请描述这张图片中的主要内容和细节,以及它们之间的关系\n")
.execute();
String translationPrompt = "将以下英文内容严格翻译为简体中文不要解释、不要添加额外内容保留专业术语和名称如Star Wars保持英文\n" + response.getResult();
// 获取API返回的结果
return sendMessage("bot",translationPrompt).getContent();
} catch (Exception e) {
// 异常处理:记录错误日志
Log.e(e);
}
// 如果发生异常返回null
return null;
}
public String getGPTVersion() {
return (model.equals(modelFor35) ? "3.5" : "4.0");
}
public static void main(String[] args) throws Exception {
// BaiduGPTManager.getManager().textToImage("user","画一个猫娘,用二次元动画画风,她是粉色头发,坐在地上");
// BaiduGPTManager.getManager().imageToText("user",new File("test.png"));
// Message message = BaiduGPTManager.getManager().sendMessage("user", "现在假设小猪等于1,小猴等于2");
// System.out.println(message.getContent());
// message = BaiduGPTManager.getManager().sendMessage("user", "那么小猪加上小猴等于多少?");
// System.out.println(message.getContent());
System.out.println(BaiduGPTManager.getManager().sendMessage("user", "分析这个网页链接的页面内容,而非链接本身:https://www.bilibili.com/video/BV1TTf5YrESz/").getContent());
}
}