package com.yutou.qqbot.gpt; import com.baidubce.qianfan.Qianfan; import com.baidubce.qianfan.model.chat.ChatResponse; import com.baidubce.qianfan.model.image.Image2TextResponse; import com.baidubce.qianfan.model.image.Text2ImageResponse; import com.yutou.qqbot.data.baidu.Message; import com.yutou.qqbot.utlis.ConfigTools; import com.yutou.qqbot.utlis.Log; import com.yutou.qqbot.utlis.StringUtils; import lombok.val; import java.io.ByteArrayInputStream; import java.io.File; import java.nio.file.Files; import java.nio.file.StandardCopyOption; import java.util.ArrayList; import java.util.Base64; import java.util.Collections; import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; public class BaiduGPTManager extends AbsGPTManager { private static final String AppID = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_APPID, String.class); private static final String ApiKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_API_KEY, String.class); //ConfigTools.load操作可以确保获取到相关参数,所以无需关心 private static final String AccessKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_ACCESS_KEY, String.class); private static final String SecretKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_SECRET_KEY, String.class); private final ConcurrentHashMap> msgMap; private final static String modelFor40 = "ERNIE-4.0-8K"; private final static String modelFor35 = "ERNIE-Speed-128K"; private String model = modelFor35; // 新增锁映射表 private final ConcurrentHashMap userLocks = new ConcurrentHashMap<>(); private final Qianfan qianfan; private BaiduGPTManager() { msgMap = new ConcurrentHashMap<>(); qianfan = new Qianfan(AccessKey, SecretKey); String savedVersion = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, String.class); if (StringUtils.isEmpty(savedVersion) || (!"3.5".equals(savedVersion) && !"4.0".equals(savedVersion))) { savedVersion = "3.5"; ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, savedVersion); } model = "3.5".equals(savedVersion) ? modelFor35 : modelFor40; } private static volatile BaiduGPTManager manager; public static BaiduGPTManager getManager() { if (manager == null) { synchronized (BaiduGPTManager.class) { if (manager == null) { manager = new BaiduGPTManager(); } } } return manager; } @Override public int setMaxMessageCount(int count) { MAX_MESSAGE.set(count); return count; } public synchronized void setModelFor40() { model = modelFor40; ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, "4.0"); } public synchronized void setModelFor35() { model = modelFor35; ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, "3.5"); } /** * 这里确实是需要清空所有数据 */ @Override public synchronized void clear() { // 添加同步 msgMap.clear(); for (AtomicBoolean value : userLocks.values()) { value.set(false); } userLocks.forEachValue(1, atomicBoolean -> atomicBoolean.set(false)); userLocks.clear(); } // 这个是官方的示例代码,表示连续对话 private static void exampleChat() { Qianfan qianfan = new Qianfan(); ChatResponse response = qianfan.chatCompletion() // 设置需要使用的模型,与endpoint同时只能设置一种 .model("ERNIE-Bot") // 通过传入历史对话记录来实现多轮对话 .addMessage("user", "你好!你叫什么名字?") .addMessage("assistant", "你好!我是文心一言,英文名是ERNIE Bot。") // 传入本轮对话的用户输入 .addMessage("user", "刚刚我的问题是什么?") .execute(); System.out.println("输出内容:" + response.getResult()); } @Override public Message sendMessage(String user, String message) { // 获取或创建用户锁 AtomicBoolean lock = userLocks.computeIfAbsent(user, k -> new AtomicBoolean(false)); // 尝试加锁(如果已被锁定则立即返回提示) if (!lock.compareAndSet(false, true)) { return Message.create("您有请求正在处理中,请稍后再试", true); } try { List list = msgMap.computeIfAbsent(user, k -> Collections.synchronizedList(new ArrayList<>())); // 限制历史消息的最大数量 synchronized (list) { if (list.size() >= MAX_MESSAGE.get()) { int removeCount = list.size() - MAX_MESSAGE.get() + 1; // 腾出空间给新消息 list.subList(0, removeCount).clear(); } list.add(Message.create(message)); } val builder = qianfan.chatCompletion() .model(model); for (Message msg : list) { builder.addMessage(msg.getRole(), msg.getContent()); } ChatResponse chatResponse = builder.execute(); Message response = Message.create(chatResponse.getResult(), true); synchronized (list) { list.add(response); if (list.size() > MAX_MESSAGE.get()) { int overflow = list.size() - MAX_MESSAGE.get(); list.subList(0, overflow).clear(); } } // msgMap.put(user, list); return response; } catch (Exception e) { Log.e(e, message); return Message.create("请求失败,请重试", true); } finally { lock.set(false); userLocks.remove(user, lock); } } /** * 将文本转换为图像文件 * 该方法使用预训练的AI模型将给定的文本转换为图像,并将其保存为文件 * * @param user 用户标识符,用于为生成的图像文件命名 * @param text 要转换为图像的文本 * @return 返回生成的图像文件对象,如果转换过程中发生错误,则返回null */ @Override public File textToImage(String user, String text) { // 使用QianFan的text2Image方法将文本转换为图像数据 Text2ImageResponse response = qianfan.text2Image() .prompt(text) .execute(); // 获取转换后的图像数据,以Base64编码的图像字符串形式 val b64Image = response.getData().get(0).getB64Image(); // 将Base64编码的图像数据转换为图像文件 // 创建一个临时目录下的图像文件,文件名包含用户标识符和当前时间戳,以确保唯一性 val imageFile = new File("tmp" + File.separator + user + "_" + System.currentTimeMillis() + ".png"); try (val inputStream = new ByteArrayInputStream(Base64.getDecoder().decode(b64Image))) { // 将解码后的图像数据复制到图像文件中,替换现有文件 Files.copy(inputStream, imageFile.toPath(), StandardCopyOption.REPLACE_EXISTING); return imageFile; } catch (Exception e) { // 如果在图像文件生成过程中发生错误,记录错误信息 Log.e(e); } // 如果发生错误,返回null return null; } /** * 将图片转换为文本描述 * * @param user 使用该功能的用户标识 * @param file 要转换的图片文件 * @return 转换后的文本描述,如果转换失败则返回null */ @Override public String imageToText(String user, File file) { // 将file文件转换成base64的代码 try { // 读取文件内容并转换为Base64编码 val base64 = Base64.getEncoder().encodeToString(Files.readAllBytes(file.toPath())); // 调用图像转文本的API Image2TextResponse response = qianfan.image2Text() .image(base64) .prompt("请描述这张图片中的主要内容和细节,以及它们之间的关系\n") .execute(); String translationPrompt = "将以下英文内容严格翻译为简体中文,不要解释、不要添加额外内容,保留专业术语和名称(如Star Wars保持英文):\n" + response.getResult(); // 获取API返回的结果 return sendMessage("bot",translationPrompt).getContent(); } catch (Exception e) { // 异常处理:记录错误日志 Log.e(e); } // 如果发生异常,返回null return null; } @Override public String getGPTVersion() { return model; } public static void main(String[] args) throws Exception { // BaiduGPTManager.getManager().textToImage("user","画一个猫娘,用二次元动画画风,她是粉色头发,坐在地上"); // BaiduGPTManager.getManager().imageToText("user",new File("test.png")); // Message message = BaiduGPTManager.getManager().sendMessage("user", "现在假设小猪等于1,小猴等于2"); // System.out.println(message.getContent()); // message = BaiduGPTManager.getManager().sendMessage("user", "那么小猪加上小猴等于多少?"); // System.out.println(message.getContent()); System.out.println(BaiduGPTManager.getManager().sendMessage("user", "分析这个网页链接的页面内容,而非链接本身:https://www.bilibili.com/video/BV1TTf5YrESz/").getContent()); } }