refactor(gpt): 重构 GPT 相关代码并优化功能
- 新增 AbsGPTManager 抽象类,定义 GPT 管理器的通用接口 - 重命名 BaiduGPTManager 类,使其位于 com.yutou.qqbot.gpt 包中 - 更新相关引用和依赖 - 优化部分代码结构,提高可维护性
This commit is contained in:
@@ -3,10 +3,9 @@ package com.yutou.qqbot.utlis;
|
||||
import com.yutou.napcat.QQDatabase;
|
||||
import com.yutou.napcat.model.GroupBean;
|
||||
import com.yutou.qqbot.Annotations.UseModel;
|
||||
import com.yutou.qqbot.QQBotManager;
|
||||
import com.yutou.qqbot.QQNumberManager;
|
||||
import com.yutou.qqbot.gpt.BaiduGPTManager;
|
||||
import com.yutou.qqbot.models.Model;
|
||||
import lombok.val;
|
||||
import org.springframework.boot.ApplicationArguments;
|
||||
import org.springframework.boot.ApplicationRunner;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
package com.yutou.qqbot.utlis;
|
||||
|
||||
import com.baidubce.qianfan.Qianfan;
|
||||
import com.baidubce.qianfan.model.chat.ChatResponse;
|
||||
import com.baidubce.qianfan.model.image.Image2TextResponse;
|
||||
import com.baidubce.qianfan.model.image.Text2ImageResponse;
|
||||
import com.yutou.qqbot.data.baidu.Message;
|
||||
import lombok.val;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.File;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class BaiduGPTManager {
|
||||
private static final AtomicInteger MAX_MESSAGE = new AtomicInteger(20);
|
||||
private static final String AppID = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_APPID, String.class);
|
||||
private static final String ApiKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_API_KEY, String.class);
|
||||
//ConfigTools.load操作可以确保获取到相关参数,所以无需关心
|
||||
private static final String AccessKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_ACCESS_KEY, String.class);
|
||||
private static final String SecretKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_SECRET_KEY, String.class);
|
||||
private final ConcurrentHashMap<String, List<Message>> msgMap;
|
||||
private final static String modelFor40 = "ERNIE-4.0-8K";
|
||||
private final static String modelFor35 = "ERNIE-3.5-8K";
|
||||
private String model = modelFor35;
|
||||
// 新增锁映射表
|
||||
private final ConcurrentHashMap<String, AtomicBoolean> userLocks = new ConcurrentHashMap<>();
|
||||
private final Qianfan qianfan;
|
||||
|
||||
private BaiduGPTManager() {
|
||||
msgMap = new ConcurrentHashMap<>();
|
||||
qianfan = new Qianfan(AccessKey, SecretKey);
|
||||
String savedVersion = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, String.class);
|
||||
if (StringUtils.isEmpty(savedVersion) || (!"3.5".equals(savedVersion) && !"4.0".equals(savedVersion))) {
|
||||
savedVersion = "3.5";
|
||||
ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, savedVersion);
|
||||
}
|
||||
model = "3.5".equals(savedVersion) ? modelFor35 : modelFor40;
|
||||
}
|
||||
|
||||
private static volatile BaiduGPTManager manager;
|
||||
|
||||
public static BaiduGPTManager getManager() {
|
||||
if (manager == null) {
|
||||
synchronized (BaiduGPTManager.class) {
|
||||
if (manager == null) {
|
||||
manager = new BaiduGPTManager();
|
||||
}
|
||||
}
|
||||
}
|
||||
return manager;
|
||||
}
|
||||
|
||||
public int setMaxMessageCount(int count) {
|
||||
MAX_MESSAGE.set(count);
|
||||
return count;
|
||||
}
|
||||
|
||||
public synchronized void setModelFor40() {
|
||||
model = modelFor40;
|
||||
ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, "4.0");
|
||||
}
|
||||
|
||||
public synchronized void setModelFor35() {
|
||||
model = modelFor35;
|
||||
ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, "3.5");
|
||||
}
|
||||
|
||||
/**
|
||||
* 这里确实是需要清空所有数据
|
||||
*/
|
||||
public synchronized void clear() { // 添加同步
|
||||
msgMap.clear();
|
||||
for (AtomicBoolean value : userLocks.values()) {
|
||||
value.set(false);
|
||||
}
|
||||
userLocks.forEachValue(1, atomicBoolean -> atomicBoolean.set(false));
|
||||
userLocks.clear();
|
||||
}
|
||||
|
||||
|
||||
// 这个是官方的示例代码,表示连续对话
|
||||
private static void exampleChat() {
|
||||
Qianfan qianfan = new Qianfan();
|
||||
ChatResponse response = qianfan.chatCompletion()
|
||||
// 设置需要使用的模型,与endpoint同时只能设置一种
|
||||
.model("ERNIE-Bot")
|
||||
// 通过传入历史对话记录来实现多轮对话
|
||||
.addMessage("user", "你好!你叫什么名字?")
|
||||
.addMessage("assistant", "你好!我是文心一言,英文名是ERNIE Bot。")
|
||||
// 传入本轮对话的用户输入
|
||||
.addMessage("user", "刚刚我的问题是什么?")
|
||||
.execute();
|
||||
System.out.println("输出内容:" + response.getResult());
|
||||
}
|
||||
|
||||
public Message sendMessage(String user, String message) {
|
||||
// 获取或创建用户锁
|
||||
AtomicBoolean lock = userLocks.computeIfAbsent(user, k -> new AtomicBoolean(false));
|
||||
// 尝试加锁(如果已被锁定则立即返回提示)
|
||||
if (!lock.compareAndSet(false, true)) {
|
||||
return Message.create("您有请求正在处理中,请稍后再试", true);
|
||||
}
|
||||
try {
|
||||
List<Message> list = msgMap.computeIfAbsent(user, k -> Collections.synchronizedList(new ArrayList<>()));
|
||||
// 限制历史消息的最大数量
|
||||
synchronized (list) {
|
||||
if (list.size() >= MAX_MESSAGE.get()) {
|
||||
int removeCount = list.size() - MAX_MESSAGE.get() + 1; // 腾出空间给新消息
|
||||
list.subList(0, removeCount).clear();
|
||||
}
|
||||
list.add(Message.create(message));
|
||||
}
|
||||
val builder = qianfan.chatCompletion()
|
||||
.model(model);
|
||||
for (Message msg : list) {
|
||||
builder.addMessage(msg.getRole(), msg.getContent());
|
||||
}
|
||||
ChatResponse chatResponse = builder.execute();
|
||||
Message response = Message.create(chatResponse.getResult(), true);
|
||||
synchronized (list) {
|
||||
list.add(response);
|
||||
if (list.size() > MAX_MESSAGE.get()) {
|
||||
int overflow = list.size() - MAX_MESSAGE.get();
|
||||
list.subList(0, overflow).clear();
|
||||
}
|
||||
}
|
||||
// msgMap.put(user, list);
|
||||
return response;
|
||||
} catch (Exception e) {
|
||||
Log.e(e, message);
|
||||
return Message.create("请求失败,请重试", true);
|
||||
} finally {
|
||||
lock.set(false);
|
||||
userLocks.remove(user, lock);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 将文本转换为图像文件
|
||||
* 该方法使用预训练的AI模型将给定的文本转换为图像,并将其保存为文件
|
||||
*
|
||||
* @param user 用户标识符,用于为生成的图像文件命名
|
||||
* @param text 要转换为图像的文本
|
||||
* @return 返回生成的图像文件对象,如果转换过程中发生错误,则返回null
|
||||
*/
|
||||
public File textToImage(String user, String text) {
|
||||
// 使用QianFan的text2Image方法将文本转换为图像数据
|
||||
Text2ImageResponse response = qianfan.text2Image()
|
||||
.prompt(text)
|
||||
.execute();
|
||||
// 获取转换后的图像数据,以Base64编码的图像字符串形式
|
||||
val b64Image = response.getData().get(0).getB64Image();
|
||||
// 将Base64编码的图像数据转换为图像文件
|
||||
// 创建一个临时目录下的图像文件,文件名包含用户标识符和当前时间戳,以确保唯一性
|
||||
val imageFile = new File("tmp" + File.separator + user + "_" + System.currentTimeMillis() + ".png");
|
||||
try (val inputStream = new ByteArrayInputStream(Base64.getDecoder().decode(b64Image))) {
|
||||
// 将解码后的图像数据复制到图像文件中,替换现有文件
|
||||
Files.copy(inputStream, imageFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
|
||||
return imageFile;
|
||||
} catch (Exception e) {
|
||||
// 如果在图像文件生成过程中发生错误,记录错误信息
|
||||
Log.e(e);
|
||||
}
|
||||
// 如果发生错误,返回null
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 将图片转换为文本描述
|
||||
*
|
||||
* @param user 使用该功能的用户标识
|
||||
* @param file 要转换的图片文件
|
||||
* @return 转换后的文本描述,如果转换失败则返回null
|
||||
*/
|
||||
public String imageToText(String user, File file) {
|
||||
// 将file文件转换成base64的代码
|
||||
try {
|
||||
// 读取文件内容并转换为Base64编码
|
||||
val base64 = Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath()));
|
||||
|
||||
// 调用图像转文本的API
|
||||
Image2TextResponse response = qianfan.image2Text()
|
||||
.image(base64)
|
||||
.prompt("请描述这张图片中的主要内容和细节,以及它们之间的关系\n")
|
||||
.execute();
|
||||
String translationPrompt = "将以下英文内容严格翻译为简体中文,不要解释、不要添加额外内容,保留专业术语和名称(如Star Wars保持英文):\n" + response.getResult();
|
||||
// 获取API返回的结果
|
||||
return sendMessage("bot",translationPrompt).getContent();
|
||||
} catch (Exception e) {
|
||||
// 异常处理:记录错误日志
|
||||
Log.e(e);
|
||||
}
|
||||
// 如果发生异常,返回null
|
||||
return null;
|
||||
}
|
||||
|
||||
public String getGPTVersion() {
|
||||
return (model.equals(modelFor35) ? "3.5" : "4.0");
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
// BaiduGPTManager.getManager().textToImage("user","画一个猫娘,用二次元动画画风,她是粉色头发,坐在地上");
|
||||
// BaiduGPTManager.getManager().imageToText("user",new File("test.png"));
|
||||
// Message message = BaiduGPTManager.getManager().sendMessage("user", "现在假设小猪等于1,小猴等于2");
|
||||
// System.out.println(message.getContent());
|
||||
// message = BaiduGPTManager.getManager().sendMessage("user", "那么小猪加上小猴等于多少?");
|
||||
// System.out.println(message.getContent());
|
||||
System.out.println(BaiduGPTManager.getManager().sendMessage("user", "分析这个网页链接的页面内容,而非链接本身:https://www.bilibili.com/video/BV1TTf5YrESz/").getContent());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user