QQBot/src/main/java/com/yutou/qqbot/utlis/BaiduGPTManager.java

103 lines
3.7 KiB
Java

package com.yutou.qqbot.utlis;
import com.alibaba.fastjson2.JSONObject;
import com.yutou.qqbot.data.baidu.Message;
import com.yutou.qqbot.data.baidu.ResponseMessage;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Lazily created singleton that sends chat messages to the Baidu
 * "wenxinworkshop" chat-completions HTTP endpoints and keeps a bounded,
 * in-memory conversation history per user.
 *
 * <p>No synchronization is performed anywhere in this class; callers that use
 * it from several threads must coordinate access themselves.
 */
public class BaiduGPTManager {
    /** Maximum number of question/answer rounds kept per user before the oldest round is dropped. */
    private static int MAX_MESSAGE = 5;
    /** How many times {@link #sendMessage(String, String)} tries before giving up on empty responses. */
    private static final int MAX_ATTEMPTS = 3;
    private static BaiduGPTManager manager;
    // Endpoint reported as version "3.5" by getGPTVersion().
    private static final String url_3_5 = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions";
    // Endpoint reported as version "4.0" by getGPTVersion().
    private static final String url_4_0 = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro";
    /** Endpoint currently in use; starts as the 3.5 endpoint. */
    private static String url = url_3_5;
    private static final String AppID = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_APPID, String.class);
    private static final String ApiKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_API_KEY, String.class);
    private static final String SecretKey = ConfigTools.load(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_SECRET_KEY, String.class);
    /** Conversation history keyed by the {@code user} argument of {@link #sendMessage(String, String)}. */
    private final Map<String, List<Message>> msgMap;

    private BaiduGPTManager() {
        msgMap = new HashMap<>();
    }

    /**
     * Returns the shared manager, creating it on first use.
     *
     * @return the single {@code BaiduGPTManager} instance
     */
    public static BaiduGPTManager getManager() {
        if (manager == null) {
            manager = new BaiduGPTManager();
        }
        return manager;
    }

    /**
     * Sets how many question/answer rounds are kept per user.
     *
     * @param count new maximum number of rounds
     * @return the value now in effect
     */
    public int setMaxMessageCount(int count) {
        MAX_MESSAGE = count;
        return MAX_MESSAGE;
    }

    /** Switches to the 4.0 endpoint and saves {@code "4.0"} as the configured version. */
    public void setModelFor40() {
        url = url_4_0;
        ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, "4.0");
    }

    /** Switches to the 3.5 endpoint and saves {@code "3.5"} as the configured version. */
    public void setModelFor35() {
        url = url_3_5;
        ConfigTools.save(ConfigTools.CONFIG, ConfigTools.BAIDU_GPT_VERSION, "3.5");
    }

    /** Discards the stored conversation history of every user. */
    public void clear() {
        msgMap.clear();
    }

    /**
     * Requests an OAuth access token using the configured API key and secret key.
     *
     * @return the {@code access_token} field of the token response; {@code null}
     *         when the response is valid JSON but carries no such field
     * @throws IllegalStateException if the token endpoint returns nothing parseable
     */
    private String getToken() {
        String _url = String.format("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s"
                , ApiKey
                , SecretKey
        );
        String get = HttpTools.get(_url);
        JSONObject response = StringUtils.isEmpty(get) ? null : JSONObject.parseObject(get);
        if (response == null) {
            // Previously this surfaced as a NullPointerException on the next line.
            throw new IllegalStateException("Baidu token endpoint returned an empty response");
        }
        return response.getString("access_token");
    }

    /**
     * Sends {@code message} together with the user's stored history and records
     * the reply in that history.
     *
     * <p>When the HTTP call yields an empty body, this user's history is reset
     * and the request is retried, up to {@value #MAX_ATTEMPTS} attempts in total.
     * When the service answers without a {@code result}, the response is returned
     * as is and the history is left as it was before the call.
     *
     * @param user    key under which the conversation history is stored
     * @param message text of the new message
     * @return the parsed service response
     * @throws IllegalStateException if every attempt produced an empty response
     */
    public ResponseMessage sendMessage(String user, String message) {
        for (int attempt = 1; attempt <= MAX_ATTEMPTS; attempt++) {
            List<Message> messages = msgMap.getOrDefault(user, new ArrayList<>());
            trimHistory(messages);
            messages.add(Message.create(message));

            String post = postMessages(messages);
            System.out.println("post = " + post);
            if (StringUtils.isEmpty(post)) {
                // Only this user's messages were part of the request, so only
                // this user's history is reset before the next attempt.
                msgMap.remove(user);
                continue;
            }

            ResponseMessage response = JSONObject.parseObject(post, ResponseMessage.class);
            if (response == null || response.getResult() == null) {
                // No usable reply: roll back the pending question so the stored
                // history keeps its question/reply pairing.
                messages.remove(messages.size() - 1);
                return response;
            }
            // NOTE(review): the `true` flag presumably marks the message as the
            // model's reply - confirm against Message.create.
            messages.add(Message.create(response.getResult(), true));
            msgMap.put(user, messages);
            System.out.println("\n\n");
            return response;
        }
        throw new IllegalStateException(
                "Baidu chat endpoint returned an empty response " + MAX_ATTEMPTS + " times in a row");
    }

    /**
     * Drops the oldest question/reply pairs until at most {@code MAX_MESSAGE}
     * rounds remain. Index 0 is removed twice on purpose: after the first
     * removal the paired reply has shifted down to index 0.
     */
    private void trimHistory(List<Message> messages) {
        while (messages.size() > MAX_MESSAGE * 2 && messages.size() >= 2) {
            messages.remove(0);
            messages.remove(0);
        }
    }

    /**
     * Serializes {@code messages} as the {@code "messages"} field of a JSON body
     * and POSTs it to the current endpoint.
     *
     * @return the raw response body as returned by {@code HttpTools.http_post}
     */
    private String postMessages(List<Message> messages) {
        JSONObject json = new JSONObject();
        json.put("messages", messages);
        System.out.println("json = " + json);
        // Serialize once so the Content-Length header always matches the bytes sent.
        byte[] body = json.toJSONString().getBytes(StandardCharsets.UTF_8);
        Map<String, String> map = new HashMap<>();
        map.put("Content-Type", "application/json");
        map.put("Content-Length", String.valueOf(body.length));
        return HttpTools.http_post(url + "?access_token=" + getToken()
                , body, 0, map);
    }

    /**
     * Reports which endpoint is currently selected.
     *
     * @return {@code "3.5"} when the 3.5 endpoint is active, otherwise {@code "4.0"}
     */
    public String getGPTVersion() {
        return (url.equals(url_3_5) ? "3.5" : "4.0");
    }

    /** Manual smoke test: sends one question as user {@code "test"} and prints the reply. */
    public static void main(String[] args) throws Exception {
        ResponseMessage message = BaiduGPTManager.getManager().sendMessage("test", "你是那个版本的大模型?");
        System.out.println(message.getResult());
    }
}