mirror of
https://github.com/wavetermdev/waveterm.git
synced 2025-11-28 13:10:24 +08:00
257 lines
6.3 KiB
Go
257 lines
6.3 KiB
Go
// Copyright 2025, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package openaichat
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/launchdarkly/eventsource"
|
|
"github.com/wavetermdev/waveterm/pkg/aiusechat/chatstore"
|
|
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
|
|
"github.com/wavetermdev/waveterm/pkg/web/sse"
|
|
)
|
|
|
|
// RunChatStep executes a chat step using the chat completions API
|
|
func RunChatStep(
|
|
ctx context.Context,
|
|
sseHandler *sse.SSEHandlerCh,
|
|
chatOpts uctypes.WaveChatOpts,
|
|
cont *uctypes.WaveContinueResponse,
|
|
) (*uctypes.WaveStopReason, []*StoredChatMessage, *uctypes.RateLimitInfo, error) {
|
|
if sseHandler == nil {
|
|
return nil, nil, nil, errors.New("sse handler is nil")
|
|
}
|
|
|
|
chat := chatstore.DefaultChatStore.Get(chatOpts.ChatId)
|
|
if chat == nil {
|
|
return nil, nil, nil, fmt.Errorf("chat not found: %s", chatOpts.ChatId)
|
|
}
|
|
|
|
if chatOpts.Config.TimeoutMs > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, time.Duration(chatOpts.Config.TimeoutMs)*time.Millisecond)
|
|
defer cancel()
|
|
}
|
|
|
|
// Convert stored messages to chat completions format
|
|
var messages []ChatRequestMessage
|
|
|
|
// Add system prompt if provided
|
|
if len(chatOpts.SystemPrompt) > 0 {
|
|
messages = append(messages, ChatRequestMessage{
|
|
Role: "system",
|
|
Content: strings.Join(chatOpts.SystemPrompt, "\n"),
|
|
})
|
|
}
|
|
|
|
// Convert native messages
|
|
for _, genMsg := range chat.NativeMessages {
|
|
chatMsg, ok := genMsg.(*StoredChatMessage)
|
|
if !ok {
|
|
return nil, nil, nil, fmt.Errorf("expected StoredChatMessage, got %T", genMsg)
|
|
}
|
|
messages = append(messages, *chatMsg.Message.clean())
|
|
}
|
|
|
|
req, err := buildChatHTTPRequest(ctx, messages, chatOpts)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, nil, nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
return nil, nil, nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(bodyBytes))
|
|
}
|
|
|
|
// Setup SSE if this is a new request (not a continuation)
|
|
if cont == nil {
|
|
if err := sseHandler.SetupSSE(); err != nil {
|
|
return nil, nil, nil, fmt.Errorf("failed to setup SSE: %w", err)
|
|
}
|
|
}
|
|
|
|
// Stream processing
|
|
stopReason, assistantMsg, err := processChatStream(ctx, resp.Body, sseHandler, chatOpts, cont)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
return stopReason, []*StoredChatMessage{assistantMsg}, nil, nil
|
|
}
|
|
|
|
func processChatStream(
|
|
ctx context.Context,
|
|
body io.Reader,
|
|
sseHandler *sse.SSEHandlerCh,
|
|
chatOpts uctypes.WaveChatOpts,
|
|
cont *uctypes.WaveContinueResponse,
|
|
) (*uctypes.WaveStopReason, *StoredChatMessage, error) {
|
|
decoder := eventsource.NewDecoder(body)
|
|
var textBuilder strings.Builder
|
|
msgID := uuid.New().String()
|
|
textID := uuid.New().String()
|
|
var finishReason string
|
|
textStarted := false
|
|
var toolCallsInProgress []ToolCall
|
|
|
|
if cont == nil {
|
|
_ = sseHandler.AiMsgStart(msgID)
|
|
}
|
|
_ = sseHandler.AiMsgStartStep()
|
|
|
|
for {
|
|
if err := ctx.Err(); err != nil {
|
|
_ = sseHandler.AiMsgError("request cancelled")
|
|
return &uctypes.WaveStopReason{
|
|
Kind: uctypes.StopKindCanceled,
|
|
ErrorType: "cancelled",
|
|
ErrorText: "request cancelled",
|
|
}, nil, err
|
|
}
|
|
|
|
event, err := decoder.Decode()
|
|
if err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
break
|
|
}
|
|
_ = sseHandler.AiMsgError(err.Error())
|
|
return &uctypes.WaveStopReason{
|
|
Kind: uctypes.StopKindError,
|
|
ErrorType: "stream",
|
|
ErrorText: err.Error(),
|
|
}, nil, fmt.Errorf("stream decode error: %w", err)
|
|
}
|
|
|
|
data := event.Data()
|
|
if data == "[DONE]" {
|
|
break
|
|
}
|
|
|
|
var chunk StreamChunk
|
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
|
log.Printf("openaichat: failed to parse chunk: %v\n", err)
|
|
continue
|
|
}
|
|
|
|
if len(chunk.Choices) == 0 {
|
|
continue
|
|
}
|
|
|
|
choice := chunk.Choices[0]
|
|
if choice.Delta.Content != "" {
|
|
if !textStarted {
|
|
_ = sseHandler.AiMsgTextStart(textID)
|
|
textStarted = true
|
|
}
|
|
textBuilder.WriteString(choice.Delta.Content)
|
|
_ = sseHandler.AiMsgTextDelta(textID, choice.Delta.Content)
|
|
}
|
|
|
|
if len(choice.Delta.ToolCalls) > 0 {
|
|
for _, tcDelta := range choice.Delta.ToolCalls {
|
|
idx := tcDelta.Index
|
|
for len(toolCallsInProgress) <= idx {
|
|
toolCallsInProgress = append(toolCallsInProgress, ToolCall{})
|
|
}
|
|
|
|
tc := &toolCallsInProgress[idx]
|
|
if tcDelta.ID != "" {
|
|
tc.ID = tcDelta.ID
|
|
}
|
|
if tcDelta.Type != "" {
|
|
tc.Type = tcDelta.Type
|
|
}
|
|
if tcDelta.Function != nil {
|
|
if tcDelta.Function.Name != "" {
|
|
tc.Function.Name = tcDelta.Function.Name
|
|
}
|
|
if tcDelta.Function.Arguments != "" {
|
|
tc.Function.Arguments += tcDelta.Function.Arguments
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if choice.FinishReason != nil && *choice.FinishReason != "" {
|
|
finishReason = *choice.FinishReason
|
|
}
|
|
}
|
|
|
|
stopKind := uctypes.StopKindDone
|
|
if finishReason == "length" {
|
|
stopKind = uctypes.StopKindMaxTokens
|
|
} else if finishReason == "tool_calls" {
|
|
stopKind = uctypes.StopKindToolUse
|
|
}
|
|
|
|
var validToolCalls []ToolCall
|
|
for _, tc := range toolCallsInProgress {
|
|
if tc.ID != "" && tc.Function.Name != "" {
|
|
validToolCalls = append(validToolCalls, tc)
|
|
}
|
|
}
|
|
|
|
var waveToolCalls []uctypes.WaveToolCall
|
|
if len(validToolCalls) > 0 {
|
|
for _, tc := range validToolCalls {
|
|
var inputJSON any
|
|
if tc.Function.Arguments != "" {
|
|
if err := json.Unmarshal([]byte(tc.Function.Arguments), &inputJSON); err != nil {
|
|
log.Printf("openaichat: failed to parse tool call arguments: %v\n", err)
|
|
continue
|
|
}
|
|
}
|
|
waveToolCalls = append(waveToolCalls, uctypes.WaveToolCall{
|
|
ID: tc.ID,
|
|
Name: tc.Function.Name,
|
|
Input: inputJSON,
|
|
})
|
|
}
|
|
}
|
|
|
|
stopReason := &uctypes.WaveStopReason{
|
|
Kind: stopKind,
|
|
RawReason: finishReason,
|
|
ToolCalls: waveToolCalls,
|
|
}
|
|
|
|
assistantMsg := &StoredChatMessage{
|
|
MessageId: msgID,
|
|
Message: ChatRequestMessage{
|
|
Role: "assistant",
|
|
},
|
|
}
|
|
|
|
if len(validToolCalls) > 0 {
|
|
assistantMsg.Message.ToolCalls = validToolCalls
|
|
} else {
|
|
assistantMsg.Message.Content = textBuilder.String()
|
|
}
|
|
|
|
if textStarted {
|
|
_ = sseHandler.AiMsgTextEnd(textID)
|
|
}
|
|
_ = sseHandler.AiMsgFinishStep()
|
|
if stopKind != uctypes.StopKindToolUse {
|
|
_ = sseHandler.AiMsgFinish(finishReason, nil)
|
|
}
|
|
|
|
return stopReason, assistantMsg, nil
|
|
}
|