103 lines
3.7 KiB
Java
103 lines
3.7 KiB
Java
package com.yutou.qqbot.utlis;
|
|
|
|
import com.alibaba.fastjson2.JSONObject;
|
|
import com.yutou.qqbot.data.baidu.Message;
|
|
import com.yutou.qqbot.data.baidu.ResponseMessage;
|
|
|
|
import java.nio.charset.StandardCharsets;
|
|
import java.util.ArrayList;
|
|
import java.util.HashMap;
|
|
import java.util.List;
|
|
import java.util.Map;
|
|
|
|
public class BaiduGPTManager {
|
|
private static int MAX_MESSAGE = 5;
|
|
private static BaiduGPTManager manager;
|
|
private static final String url_3_5 = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions";
|
|
//4.0
|
|
private static final String url_4_0 = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro";
|
|
private static String url = url_3_5;
|
|
private static final String AppID = ConfigTools.load(ConfigTools.CONFIG,ConfigTools.BAIDU_GPT_APPID, String.class);
|
|
private static final String ApiKey =ConfigTools.load(ConfigTools.CONFIG,ConfigTools.BAIDU_GPT_API_KEY, String.class);
|
|
private static final String SecretKey =ConfigTools.load(ConfigTools.CONFIG,ConfigTools.BAIDU_GPT_SECRET_KEY, String.class);
|
|
private final Map<String, List<Message>> msgMap;
|
|
|
|
private BaiduGPTManager() {
|
|
msgMap = new HashMap<>();
|
|
}
|
|
|
|
public static BaiduGPTManager getManager() {
|
|
if (manager == null) {
|
|
manager = new BaiduGPTManager();
|
|
}
|
|
return manager;
|
|
}
|
|
|
|
public int setMaxMessageCount(int count) {
|
|
MAX_MESSAGE = count;
|
|
return MAX_MESSAGE;
|
|
}
|
|
|
|
public void setModelFor40() {
|
|
url = url_4_0;
|
|
ConfigTools.save(ConfigTools.CONFIG,ConfigTools.BAIDU_GPT_VERSION,"4.0");
|
|
}
|
|
|
|
public void setModelFor35() {
|
|
url = url_3_5;
|
|
ConfigTools.save(ConfigTools.CONFIG,ConfigTools.BAIDU_GPT_VERSION,"3.5");
|
|
}
|
|
|
|
public void clear() {
|
|
msgMap.clear();
|
|
}
|
|
|
|
private String getToken() {
|
|
String _url = String.format("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s"
|
|
, ApiKey
|
|
, SecretKey
|
|
);
|
|
String get = HttpTools.get(_url);
|
|
JSONObject response = JSONObject.parseObject(get);
|
|
return response.getString("access_token");
|
|
}
|
|
|
|
public ResponseMessage sendMessage(String user, String message) {
|
|
List<Message> messages = msgMap.getOrDefault(user, new ArrayList<>());
|
|
if (messages.size() > MAX_MESSAGE * 2) {
|
|
messages.remove(0);
|
|
messages.remove(1);
|
|
}
|
|
messages.add(Message.create(message));
|
|
|
|
JSONObject json = new JSONObject();
|
|
json.put("messages", messages);
|
|
System.out.println("json = " + json);
|
|
|
|
Map<String, String> map = new HashMap<>();
|
|
map.put("Content-Type", "application/json");
|
|
map.put("Content-Length", String.valueOf(json.toJSONString().getBytes(StandardCharsets.UTF_8).length));
|
|
String post = HttpTools.http_post(url + "?access_token=" + getToken()
|
|
, json.toJSONString().getBytes(StandardCharsets.UTF_8), 0, map);
|
|
System.out.println("post = " + post);
|
|
if (StringUtils.isEmpty(post)) {
|
|
clear();
|
|
return sendMessage(user, message);
|
|
}
|
|
ResponseMessage response = JSONObject.parseObject(post, ResponseMessage.class);
|
|
messages.add(Message.create(response.getResult(), true));
|
|
msgMap.put(user, messages);
|
|
System.out.println("\n\n");
|
|
return response;
|
|
}
|
|
|
|
public String getGPTVersion() {
|
|
return (url.equals(url_3_5) ? "3.5" : "4.0");
|
|
}
|
|
|
|
public static void main(String[] args) throws Exception {
|
|
ResponseMessage message = BaiduGPTManager.getManager().sendMessage("test", "你是那个版本的大模型?");
|
|
System.out.println(message.getResult());
|
|
}
|
|
}
|