package com.yutou.qqbot.utlis; import com.alibaba.fastjson2.JSONObject; import com.yutou.qqbot.data.baidu.Message; import com.yutou.qqbot.data.baidu.ResponseMessage; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; public class BaiduGPTManager { private static int MAX_MESSAGE = 5; private static BaiduGPTManager manager; private static final String url_3_5 = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"; //4.0 private static final String url_4_0 = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"; private static String url = url_3_5; private static final String AppID = ConfigTools.load(ConfigTools.CONFIG,ConfigTools.BAIDU_GPT_APPID, String.class); private static final String ApiKey =ConfigTools.load(ConfigTools.CONFIG,ConfigTools.BAIDU_GPT_API_KEY, String.class); private static final String SecretKey =ConfigTools.load(ConfigTools.CONFIG,ConfigTools.BAIDU_GPT_SECRET_KEY, String.class); private final Map> msgMap; private BaiduGPTManager() { msgMap = new HashMap<>(); } public static BaiduGPTManager getManager() { if (manager == null) { manager = new BaiduGPTManager(); } return manager; } public int setMaxMessageCount(int count) { MAX_MESSAGE = count; return MAX_MESSAGE; } public void setModelFor40() { url = url_4_0; ConfigTools.save(ConfigTools.CONFIG,ConfigTools.BAIDU_GPT_VERSION,"4.0"); } public void setModelFor35() { url = url_3_5; ConfigTools.save(ConfigTools.CONFIG,ConfigTools.BAIDU_GPT_VERSION,"3.5"); } public void clear() { msgMap.clear(); } private String getToken() { String _url = String.format("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s" , ApiKey , SecretKey ); String get = HttpTools.get(_url); JSONObject response = JSONObject.parseObject(get); return response.getString("access_token"); } public ResponseMessage sendMessage(String user, String message) { List messages = msgMap.getOrDefault(user, new ArrayList<>()); if (messages.size() > MAX_MESSAGE * 2) { messages.remove(0); messages.remove(1); } messages.add(Message.create(message)); JSONObject json = new JSONObject(); json.put("messages", messages); System.out.println("json = " + json); Map map = new HashMap<>(); map.put("Content-Type", "application/json"); map.put("Content-Length", String.valueOf(json.toJSONString().getBytes(StandardCharsets.UTF_8).length)); String post = HttpTools.http_post(url + "?access_token=" + getToken() , json.toJSONString().getBytes(StandardCharsets.UTF_8), 0, map); System.out.println("post = " + post); if (StringUtils.isEmpty(post)) { clear(); return sendMessage(user, message); } ResponseMessage response = JSONObject.parseObject(post, ResponseMessage.class); messages.add(Message.create(response.getResult(), true)); msgMap.put(user, messages); System.out.println("\n\n"); return response; } public String getGPTVersion() { return (url.equals(url_3_5) ? "3.5" : "4.0"); } public static void main(String[] args) throws Exception { ResponseMessage message = BaiduGPTManager.getManager().sendMessage("test", "你是那个版本的大模型?"); System.out.println(message.getResult()); } }