百度文心一言 Java 流式输出示例:Spring Boot + SSE demo

AIGC 0

参考:GitHub - mmciel/wenxin-api-java: 百度文心一言Java库,支持问答和对话,支持流式输出和同步输出。提供SpringBoot调用样例。提供拓展能力。

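1、依赖(除下面的百度 SDK 外,示例代码还用到了 OkHttp 的 okhttp-sse(EventSource / EventSourceListener)、Hutool(JSONUtil)以及 Lombok(@Slf4j 等),需要一并引入)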

<dependency>
<groupId>com.baidu.aip</groupId>
<artifactId>java-sdk</artifactId>
<version>4.16.18</version>
</dependency>

2、配置 API Key 和 Secret Key

3、主要使用的接口

4、返回的 JSON 格式

5、WenxinEventSourceListener 事件监听器

与其他接口不同,这里需要用 CompletionsResponse.Data 把返回内容封装一下,否则前端页面还需要兼容非 JSON 格式的数据。

@Slf4jpublic class WenxinEventSourceListener extends EventSourceListener {    private long tokens;    private SseEmitter sseEmitter;    public WenxinEventSourceListener(SseEmitter sseEmitter) {        this.sseEmitter = sseEmitter;    }    @Override    public void onOpen(EventSource eventSource, Response response) {        log.info("建立sse连接...");    }    @SneakyThrows    @Override    @JsonIgnoreProperties(ignoreUnknown = true)    public void onEvent(EventSource eventSource, String id, String type, String data) {        ChatResponse bean = JSONUtil.parseObj(data).toBean(ChatResponse.class);        log.info("返回数据:{}", data);        if (bean.getIs_end()) {            log.info("返回数据结束了");            sseEmitter.send(SseEmitter.event()                    .id("[TOKENS]")                    .data("<br/><br/>tokens:" + tokens())                    .reconnectTime(3000));            sseEmitter.send(SseEmitter.event()                    .id("[DONE]")                    .data("[DONE]")                    .reconnectTime(3000));            // 传输完成后自动关闭sse            sseEmitter.complete();            return;        }        log.info("OpenAI返回数据:{}", data);        tokens += 1;        if (data.equals("[DONE]")) {            log.info("OpenAI返回数据结束了");            sseEmitter.send(SseEmitter.event()                    .id("[TOKENS]")                    .data("<br/><br/>tokens:" + tokens())                    .reconnectTime(3000));            sseEmitter.send(SseEmitter.event()                    .id("[DONE]")                    .data("[DONE]")                    .reconnectTime(3000));            // 传输完成后自动关闭sse            sseEmitter.complete();            return;        }        CompletionsResponse completionResponse = new CompletionsResponse();        CompletionsResponse.Data dataResult = new CompletionsResponse.Data();        dataResult.setText(bean.getResult());        completionResponse.setData(dataResult);        try {            sseEmitter.send(SseEmitter.event()                    
.id(bean.getId())                    .data(completionResponse.getData())                    .reconnectTime(3000));        } catch (Exception e) {            log.error("sse信息推送失败!");            eventSource.cancel();            e.printStackTrace();        }    }    @Override    public void onClosed(EventSource eventSource) {        log.info("关闭sse连接...");    }    @SneakyThrows    @Override    public void onFailure(EventSource eventSource, Throwable t, Response response) {        if(Objects.isNull(response)){            log.error("sse连接异常:{}", t);            eventSource.cancel();            return;        }        ResponseBody body = response.body();        if (Objects.nonNull(body)) {            // 错误处理 {"error_code":110,"error_msg":"Access token invalid or no longer valid"},异常:{}            log.error("sse连接异常data:{},异常:{}", body.string(), t);        } else {            log.error("sse连接异常data:{},异常:{}", response, t);        }        eventSource.cancel();    }    /**     * tokens     * @return     */    public long tokens() {        return tokens;    }}

6、WenXinClient:流式调用主要看 streamChat 方法。之前从千帆平台上找到的流式示例返回的 type 是 JSON,所以之前自己手写的 demo 总是报异常。

 /** Serializes chat request bodies; an ObjectMapper is reusable across calls once configured. */
 private static final ObjectMapper CHAT_BODY_MAPPER = new ObjectMapper();

 /**
  * Starts a streaming (SSE) chat request against the endpoint selected by {@code modelE}.
  *
  * <p>{@code chatBody.stream} is forced to {@code true}; the request is handed to an OkHttp
  * EventSource factory and chunks are delivered to {@code eventSourceListener}.
  *
  * @param chatBody            request payload (mutated: stream flag set)
  * @param eventSourceListener receives the streamed events; must not be null
  * @param modelE              model whose API host is used
  * @throws WenXinException if the listener is null, or if the request could not be built
  *                         (previously this failure was swallowed and the caller never knew)
  */
 public void streamChat(ChatBody chatBody, EventSourceListener eventSourceListener, ModelE modelE) {
     if (Objects.isNull(eventSourceListener)) {
         throw new WenXinException("参数异常:EventSourceListener不能为空");
     }
     chatBody.setStream(true);
     try {
         EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
         Request request = new Request.Builder().url(assembleUrl(modelE))
                 .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()),
                         CHAT_BODY_MAPPER.writeValueAsString(chatBody)))
                 .build();
         factory.newEventSource(request, eventSourceListener);
     } catch (Exception e) {
         // Log with the stack trace, then surface the failure: swallowing it left the listener
         // (and therefore the browser's SSE stream) waiting for events that would never come.
         log.error("请求参数解析异常:", e);
         throw new WenXinException("请求参数解析异常:" + e.getMessage());
     }
 }

 /**
  * Builds the request URL for {@code modelE} with a freshly obtained access token.
  *
  * <p>Side effect: stores the token returned by {@code WenXinConfig.refreshAccessToken()} in the
  * {@code accessToken} field.
  */
 private String assembleUrl(ModelE modelE) {
     accessToken = WenXinConfig.refreshAccessToken();
     return modelE.getApiHost() + "?access_token=" + accessToken;
 }

7、定义 SSE 的接口及实现类

/**
 * SSE session management plus chat dispatch for a single client identified by a uid.
 */
public interface SseService {

    /**
     * Opens a server-sent-events stream for a client.
     *
     * @param uid client identifier
     * @return the emitter that Spring MVC keeps open toward the browser
     */
    SseEmitter createSse(String uid);

    /**
     * Closes the stream previously opened for a client.
     *
     * @param uid client identifier
     */
    void closeSse(String uid);

    /**
     * Submits a chat message from the client to the server.
     *
     * @param uid         client identifier whose stream should receive the answer
     * @param chatRequest the message payload
     * @return a synchronous acknowledgement for the request
     */
    ChatResponse sseChat(String uid, ChatRequest chatRequest);
}
public class WenXinSseServiceImpl implements SseService {    @Value("${chat.accessKeyId}")    private String accessKeyId;    @Value("${chat.accessKeySecret}")    private String accessKeySecret;    @Value("${chat.agentKey}")    private String agentKey;    @Value("${chat.appId}")    private String appId;    @Autowired    WenXinClient wenXinClient;    @Override    public SseEmitter createSse(String uid) {        //默认30秒超时,设置为0L则永不超时        SseEmitter sseEmitter = new SseEmitter(0l);        //完成后回调        sseEmitter.onCompletion(() -> {            log.info("[{}]结束连接...................", uid);            LocalCache.CACHE.remove(uid);        });        //超时回调        sseEmitter.onTimeout(() -> {            log.info("[{}]连接超时...................", uid);        });        //异常回调        sseEmitter.onError(                throwable -> {                    try {                        log.info("[{}]连接异常,{}", uid, throwable.toString());                        sseEmitter.send(SseEmitter.event()                                .id(uid)                                .name("发生异常!")                                .data(Message.builder().content("发生异常请重试!").build())                                .reconnectTime(3000));                        LocalCache.CACHE.put(uid, sseEmitter);                    } catch (IOException e) {                        e.printStackTrace();                    }                }        );        try {            sseEmitter.send(SseEmitter.event().reconnectTime(5000));        } catch (IOException e) {            e.printStackTrace();        }        LocalCache.CACHE.put(uid, sseEmitter);        log.info("[{}]创建sse连接成功!", uid);        return sseEmitter;    }    @Override    public void closeSse(String uid) {        SseEmitter sse = (SseEmitter) LocalCache.CACHE.get(uid);        if (sse != null) {            sse.complete();            //移除            LocalCache.CACHE.remove(uid);        }    }    @Override    public ChatResponse sseChat(String uid, ChatRequest 
chatRequest) {        if (StringUtils.isBlank(chatRequest.getMsg())) {            log.error("参数异常,msg为null", uid);            throw new BaseException("参数异常,msg不能为空~");        }        SseEmitter sseEmitter = (SseEmitter) LocalCache.CACHE.get(uid);        if (sseEmitter == null) {            log.info("聊天消息推送失败uid:[{}],没有创建连接,请重试。", uid);            throw new BaseException("聊天消息推送失败uid:[{}],没有创建连接,请重试。~");        }        WenxinEventSourceListener openAIEventSourceListener = new WenxinEventSourceListener(sseEmitter);        List<MessageItem> messages = new ArrayList<>();        messages.add(MessageItem.builder().role(MessageItem.Role.USER).content(chatRequest.getMsg()).build());        wenXinClient.streamChat(messages, openAIEventSourceListener, ModelE.ERNIE_Bot);        LocalCache.CACHE.put("msg" + uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);        ChatResponse response = new ChatResponse();        response.setQuestionTokens(1);        return response;    }}

8、主要的 Controller 接口

/**
 * Opens the SSE stream for the calling client.
 *
 * @param headers request headers; the client uid is derived from them via {@code getUid}
 * @return the emitter Spring MVC keeps open toward the browser
 */
@CrossOrigin
@GetMapping("/createSse")
public SseEmitter createConnect(@RequestHeader Map<String, String> headers) {
    return sseService.createSse(getUid(headers));
}

/**
 * Accepts a chat question; the answer itself is delivered over the caller's SSE stream.
 *
 * @param chatRequest the question payload
 * @param headers     request headers; the client uid is derived from them via {@code getUid}
 * @param response    not used by this handler
 * @return the service's synchronous acknowledgement
 */
@CrossOrigin
@PostMapping("/chat")
@ResponseBody
public ChatResponse sseChat(@RequestBody ChatRequest chatRequest, @RequestHeader Map<String, String> headers, HttpServletResponse response) {
    return sseService.sseChat(getUid(headers), chatRequest);
}

/**
 * Closes the calling client's SSE stream.
 *
 * @param headers request headers; the client uid is derived from them via {@code getUid}
 */
@CrossOrigin
@GetMapping("/closeSse")
public void closeConnect(@RequestHeader Map<String, String> headers) {
    sseService.closeSse(getUid(headers));
}

9、主要的页面代码

<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>智能问答</title>
  <link rel="stylesheet" href="styles.css"> <!-- external stylesheet -->
  <script src="HZRecorder.js"></script>
  <script src="https://cdn.bootcdn.net/ajax/libs/jquery/3.6.0/jquery.min.js"></script>
  <script src="js/markdown.min.js"></script>
  <script src="js/eventsource.min.js"></script>
  <script>
    /**
     * Renders accumulated markdown into the answer box with the given id.
     * The box is created in the /chat success callback, which can run after the first SSE chunk
     * arrives; in that case rendering is skipped here and done once the box exists.
     * NOTE(review): marked() output goes into innerHTML unsanitized; consider a sanitizer.
     */
    function setText(text, uuid_str) {
      const content = document.getElementById(uuid_str);
      if (!content) {
        return;
      }
      content.innerHTML = marked(text);
    }

    /** Escapes a value so it can be concatenated into an HTML string safely. */
    function escapeHtml(value) {
      return String(value)
        .replace(/&/g, "&amp;")
        .replace(/</g, "&lt;")
        .replace(/>/g, "&gt;")
        .replace(/"/g, "&quot;")
        .replace(/'/g, "&#39;");
    }

    /** Generates a random version-4 style UUID string (Math.random based, not cryptographic). */
    function uuid() {
      var s = [];
      var hexDigits = "0123456789abcdef";
      for (var i = 0; i < 36; i++) {
        s[i] = hexDigits.substr(Math.floor(Math.random() * 0x10), 1);
      }
      s[14] = "4"; // version nibble 0100
      // Variant bits 10xx. s[19] is a hex *character*, so parse it before masking;
      // e.g. "a" & 0x3 would be NaN & 3 === 0.
      s[19] = hexDigits.substr((parseInt(s[19], 16) & 0x3) | 0x8, 1);
      s[8] = s[13] = s[18] = s[23] = "-";
      var uuid = s.join("");
      console.log(uuid);
      return uuid;
    }

    window.onload = function () {
      /*let disconnectBtn = document.getElementById("disconnectSSE");*/
      let messageElement = document.getElementById("messageInput");
      let chat = document.getElementById("chat-messages");
      let sse;
      let uid = window.localStorage.getItem("uid");
      if (uid == null || uid == "" || uid == "null") {
        uid = uuid();
      }
      let text = "";
      let uuid_str;
      // Persist the uid in local storage.
      window.localStorage.setItem("uid", uid);

      // Send button click handler.
      document.getElementById('sendTextButton').addEventListener('click', async function () {
        try {
          const userInput = messageElement.value.trim();
          if (userInput) {
            await sseOneTurn(userInput);
            // Clear the input element (userInput is only the trimmed string).
            messageElement.value = '';
          } else {
            alert('请输入文字消息!');
          }
        } catch (error) {
          alert('发送消息时发生错误: ' + error.message);
        }
      });

      // Enter key handler (uses the event argument instead of the deprecated window.event).
      messageElement.onkeydown = function (event) {
        if (event.key === "Enter" || event.keyCode === 13) {
          const value = messageElement.value.trim();
          if (!value) {
            return;
          }
          sseOneTurn(value);
          messageElement.value = '';
        }
      };

      function sseOneTurn(InputText) {
        uuid_str = uuid();
        // Create the SSE stream; the uid header ties it to the /chat request below.
        const eventSource = new EventSourcePolyfill("/createSse", {
          headers: {
            uid: uid,
          },
        });
        eventSource.onopen = (event) => {
          console.log("开始输出后端返回值");
          sse = event.target;
        };
        eventSource.onmessage = (event) => {
          if (event.lastEventId == "[TOKENS]") {
            text = text + event.data;
            setText(text, uuid_str);
            text = "";
            return;
          }
          if (event.data == "[DONE]") {
            text = "";
            // Close this turn's stream even if onopen never assigned `sse`.
            eventSource.close();
            return;
          }
          let json_data = JSON.parse(event.data);
          console.log(json_data);
          if (json_data.text == null || json_data.text == "null") {
            return;
          }
          text = text + json_data.text;
          setText(text, uuid_str);
        };
        eventSource.onerror = (event) => {
          console.log("onerror", event);
          alert("服务异常请重试并联系开发者!");
          // readyState belongs to the EventSource (event.target), not to the event object.
          if (event.target.readyState === EventSource.CLOSED) {
            console.log("connection is closed");
          } else {
            console.log("Error occured", event);
          }
          event.target.close();
        };
        eventSource.addEventListener("customEventName", (event) => {
          console.log("Message id is " + event.lastEventId);
        });
        $.ajax({
          type: "post",
          url: "/chat",
          data: JSON.stringify({
            msg: InputText,
          }),
          contentType: "application/json;charset=UTF-8",
          dataType: "json",
          headers: {
            uid: uid,
          },
          beforeSend: function (request) {},
          success: function (result) {
            // Question box; user text is escaped before going into innerHTML.
            chat.innerHTML +=
              '<tr><td style="height: 30px;">' +
              escapeHtml(InputText) +
              "<br/><br/> tokens:" +
              escapeHtml(result.question_tokens) +
              "</td></tr>";
            InputText = null;
            // Answer box.
            chat.innerHTML +=
              '<tr><td><article id="' +
              uuid_str +
              '" class="markdown-body"></article></td></tr>';
            // Render anything that streamed in before the answer box existed.
            if (text) {
              setText(text, uuid_str);
            }
          },
          complete: function () {},
          error: function () {
            console.info("发送问题失败!");
          },
        });
      }

      /*disconnectBtn.onclick = function () {
        if (sse) {
          sse.close();
        }
      };*/
    };
  </script>
</head>
<body>
<div class="chat-container">
  <div class="chat-header">
    <h1>智能问答</h1>
  </div>
  <div class="chat-messages" id="chat-messages">
    <!-- chat messages are rendered here -->
  </div>
  <form class="message-form" onsubmit="return false;">
    <input type="text" id="messageInput" placeholder="输入消息..." autocomplete="off">
    <button type="button" id="sendTextButton">发送文字</button>
    <button type="button" id="recordAndUploadButton">按住录音</button>
    <progress id="uploadProgress" value="0" max="100" style="display:none;"></progress>
  </form>
</div>
</body>
</html>

最后的呈现效果如下:

也许您对下面的内容还感兴趣: