// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

package openai

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/launchdarkly/eventsource"
	"github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil"
	"github.com/wavetermdev/waveterm/pkg/aiusechat/chatstore"
	"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
	"github.com/wavetermdev/waveterm/pkg/util/logutil"
	"github.com/wavetermdev/waveterm/pkg/util/utilfn"
	"github.com/wavetermdev/waveterm/pkg/web/sse"
)

// sanitizeHostnameInError removes the specific hostname from error messages
func sanitizeHostnameInError(err error, baseURL string) error {
	if err == nil {
		return nil
	}

	errStr := err.Error()
	parsedURL, parseErr := url.Parse(baseURL)
	if parseErr == nil && parsedURL.Host != "" {
		errStr = strings.ReplaceAll(errStr, baseURL, "AI service")
		errStr = strings.ReplaceAll(errStr, parsedURL.Host, "host")
	}

	return fmt.Errorf("%s", errStr)
}

// ---------- OpenAI wire types (subset) ----------

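// OpenAIChatMessage is the stored chat-history wrapper for a single entry.
// Exactly one of Message, FunctionCall, or FunctionCallOutput is expected to
// be set; MessageId and Usage are internal bookkeeping and are never sent to
// the API.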
type OpenAIChatMessage struct {
	MessageId          string                         `json:"messageid"` // internal field for idempotency (cannot send to openai)
	Message            *OpenAIMessage                 `json:"message,omitempty"`
	FunctionCall       *OpenAIFunctionCallInput       `json:"functioncall,omitempty"`
	FunctionCallOutput *OpenAIFunctionCallOutputInput `json:"functioncalloutput,omitempty"`
	Usage              *OpenAIUsage
}

type OpenAIMessage struct {
	Role    string                 `json:"role"`
	Content []OpenAIMessageContent `json:"content"`
}

type OpenAIFunctionCallInput struct {
	Type        string                        `json:"type"`                  // Required: The type of the function tool call. Always function_call
	CallId      string                        `json:"call_id"`               // Required: The unique ID of the function tool call generated by the model
	Name        string                        `json:"name"`                  // Required: The name of the function to run
	Arguments   string                        `json:"arguments"`             // Required: A JSON string of the arguments to pass to the function
	Status      string                        `json:"status,omitempty"`      // Optional: The status of the item. One of in_progress, completed, or incomplete
	ToolUseData *uctypes.UIMessageDataToolUse `json:"toolusedata,omitempty"` // Internal field for UI tool use data (must be cleaned before sending to API)
	// removed the "id" field (optional to send back in inputs)
}

type OpenAIFunctionCallOutputInput struct {
	Type   string      `json:"type"`    // Required: The type of the function tool call output. Always function_call_output
	CallId string      `json:"call_id"` // Required: The unique ID of the function tool call generated by the model
	Output interface{} `json:"output"`  // Required: Text, image, or file output of the function tool call
	// removed "status" field (not required for inputs)
	// removed the "id" field (optional to send back in inputs)
}

type OpenAIFunctionCallErrorOutput struct {
	Ok    string `json:"ok"`
	Error string `json:"error"`
}

type OpenAIMessageContent struct {
	Type       string `json:"type"` // "input_text", "output_text", "input_image", "input_file", "function_call"
	Text       string `json:"text,omitempty"`
	ImageUrl   string `json:"image_url,omitempty"`
	PreviewUrl string `json:"previewurl,omitempty"` // internal field for 128x128 webp data url (cannot send to API)
	Filename   string `json:"filename,omitempty"`
	FileData   string `json:"file_data,omitempty"`

	// for Tools (type will be "function_call")
	Arguments any    `json:"arguments,omitempty"`
	CallId    string `json:"call_id,omitempty"`
	Name      string `json:"name,omitempty"`
}

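// clean returns a copy of the content block with internal-only fields
// stripped (PreviewUrl always; Filename too for input_image blocks) so the
// block is safe to send to the API. The receiver is returned unchanged when
// there is nothing to strip.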
func (c *OpenAIMessageContent) clean() *OpenAIMessageContent {
	if c.PreviewUrl == "" && (c.Type != "input_image" || c.Filename == "") {
		return c
	}
	rtn := *c
	rtn.PreviewUrl = ""
	if c.Type == "input_image" {
		rtn.Filename = ""
	}
	return &rtn
}

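// cleanAndCopy returns a copy of the message with every content block
// cleaned for the API.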
func (m *OpenAIMessage) cleanAndCopy() *OpenAIMessage {
	rtn := &OpenAIMessage{Role: m.Role}
	rtn.Content = make([]OpenAIMessageContent, len(m.Content))
	for idx, block := range m.Content {
		cleaned := block.clean()
		rtn.Content[idx] = *cleaned
	}
	return rtn
}

func (f *OpenAIFunctionCallInput) clean() *OpenAIFunctionCallInput {
	if f.ToolUseData == nil {
		return f
	}
	rtn := *f
	rtn.ToolUseData = nil
	return &rtn
}

type openAIErrorResponse struct {
	Error openAIErrorType `json:"error"`
}

type openAIErrorType struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Code    string `json:"code"`
}

func (m *OpenAIChatMessage) GetMessageId() string {
	return m.MessageId
}

func (m *OpenAIChatMessage) GetRole() string {
	if m.Message != nil {
		return m.Message.Role
	}
	return ""
}

func (m *OpenAIChatMessage) GetUsage() *uctypes.AIUsage {
	if m.Usage == nil {
		return nil
	}
	return &uctypes.AIUsage{
		APIType:              uctypes.APIType_OpenAIResponses,
		Model:                m.Usage.Model,
		InputTokens:          m.Usage.InputTokens,
		OutputTokens:         m.Usage.OutputTokens,
		NativeWebSearchCount: m.Usage.NativeWebSearchCount,
	}
}

// ---------- OpenAI SSE Event Types ----------

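// The structs below mirror the subset of Responses API SSE events this client
// decodes. As an illustration (shape inferred from the fields declared below,
// not an authoritative transcript), a text delta event carries JSON like:
//
//	{"type":"response.output_text.delta","sequence_number":7,
//	 "item_id":"msg_123","output_index":0,"content_index":0,"delta":"Hi"}
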
type openaiResponseCreatedEvent struct {
	Type           string         `json:"type"`
	SequenceNumber int            `json:"sequence_number"`
	Response       openaiResponse `json:"response"`
}

type openaiResponseInProgressEvent struct {
	Type           string         `json:"type"`
	SequenceNumber int            `json:"sequence_number"`
	Response       openaiResponse `json:"response"`
}

type openaiResponseOutputItemAddedEvent struct {
	Type           string           `json:"type"`
	SequenceNumber int              `json:"sequence_number"`
	OutputIndex    int              `json:"output_index"`
	Item           openaiOutputItem `json:"item"`
}

type openaiResponseOutputItemDoneEvent struct {
	Type           string           `json:"type"`
	SequenceNumber int              `json:"sequence_number"`
	OutputIndex    int              `json:"output_index"`
	Item           openaiOutputItem `json:"item"`
}

type openaiResponseContentPartAddedEvent struct {
	Type           string               `json:"type"`
	SequenceNumber int                  `json:"sequence_number"`
	ItemId         string               `json:"item_id"`
	OutputIndex    int                  `json:"output_index"`
	ContentIndex   int                  `json:"content_index"`
	Part           OpenAIMessageContent `json:"part"`
}

type openaiResponseOutputTextDeltaEvent struct {
	Type           string   `json:"type"`
	SequenceNumber int      `json:"sequence_number"`
	ItemId         string   `json:"item_id"`
	OutputIndex    int      `json:"output_index"`
	ContentIndex   int      `json:"content_index"`
	Delta          string   `json:"delta"`
	Logprobs       []string `json:"logprobs"`
	Obfuscation    string   `json:"obfuscation"`
}

type openaiResponseOutputTextDoneEvent struct {
	Type           string   `json:"type"`
	SequenceNumber int      `json:"sequence_number"`
	ItemId         string   `json:"item_id"`
	OutputIndex    int      `json:"output_index"`
	ContentIndex   int      `json:"content_index"`
	Text           string   `json:"text"`
	Logprobs       []string `json:"logprobs"`
}

type openaiResponseContentPartDoneEvent struct {
	Type           string               `json:"type"`
	SequenceNumber int                  `json:"sequence_number"`
	ItemId         string               `json:"item_id"`
	OutputIndex    int                  `json:"output_index"`
	ContentIndex   int                  `json:"content_index"`
	Part           OpenAIMessageContent `json:"part"`
}

type openaiResponseCompletedEvent struct {
	Type           string         `json:"type"`
	SequenceNumber int            `json:"sequence_number"`
	Response       openaiResponse `json:"response"`
}

type openaiResponseFunctionCallArgumentsDeltaEvent struct {
	Type           string `json:"type"`
	SequenceNumber int    `json:"sequence_number"`
	ItemId         string `json:"item_id"`
	OutputIndex    int    `json:"output_index"`
	Delta          string `json:"delta"`
}

type openaiResponseFunctionCallArgumentsDoneEvent struct {
	Type           string `json:"type"`
	SequenceNumber int    `json:"sequence_number"`
	ItemId         string `json:"item_id"`
	OutputIndex    int    `json:"output_index"`
	Arguments      string `json:"arguments"`
}

type openaiResponseReasoningSummaryPartAddedEvent struct {
	Type           string                     `json:"type"`
	SequenceNumber int                        `json:"sequence_number"`
	ItemId         string                     `json:"item_id"`
	OutputIndex    int                        `json:"output_index"`
	SummaryIndex   int                        `json:"summary_index"`
	Part           openaiReasoningSummaryPart `json:"part"`
}

type openaiResponseReasoningSummaryPartDoneEvent struct {
	Type           string                     `json:"type"`
	SequenceNumber int                        `json:"sequence_number"`
	ItemId         string                     `json:"item_id"`
	OutputIndex    int                        `json:"output_index"`
	SummaryIndex   int                        `json:"summary_index"`
	Part           openaiReasoningSummaryPart `json:"part"`
}

type openaiReasoningSummaryPart struct {
	Type string `json:"type"`
	Text string `json:"text"`
}

type openaiResponseReasoningSummaryTextDeltaEvent struct {
	Type           string `json:"type"`
	SequenceNumber int    `json:"sequence_number"`
	ItemId         string `json:"item_id"`
	OutputIndex    int    `json:"output_index"`
	SummaryIndex   int    `json:"summary_index"`
	Delta          string `json:"delta"`
}

type openaiResponseReasoningSummaryTextDoneEvent struct {
	Type           string `json:"type"`
	SequenceNumber int    `json:"sequence_number"`
	ItemId         string `json:"item_id"`
	OutputIndex    int    `json:"output_index"`
	SummaryIndex   int    `json:"summary_index"`
	Text           string `json:"text"`
}

// ---------- OpenAI Response Structure Types ----------

type openaiResponse struct {
	Id                 string                 `json:"id"`
	Object             string                 `json:"object"`
	CreatedAt          int64                  `json:"created_at"`
	Status             string                 `json:"status"`
	Background         bool                   `json:"background"`
	Error              *openaiError           `json:"error"`
	IncompleteDetails  *openaiIncompleteInfo  `json:"incomplete_details"`
	Instructions       *string                `json:"instructions"`
	MaxOutputTokens    *int                   `json:"max_output_tokens"`
	MaxToolCalls       *int                   `json:"max_tool_calls"`
	Model              string                 `json:"model"`
	Output             []openaiOutputItem     `json:"output"`
	ParallelToolCalls  bool                   `json:"parallel_tool_calls"`
	PreviousResponseId *string                `json:"previous_response_id"`
	PromptCacheKey     *string                `json:"prompt_cache_key"`
	Reasoning          openaiReasoning        `json:"reasoning"`
	SafetyIdentifier   *string                `json:"safety_identifier"`
	ServiceTier        string                 `json:"service_tier"`
	Store              bool                   `json:"store"`
	Temperature        float64                `json:"temperature"`
	Text               openaiTextConfig       `json:"text"`
	ToolChoice         string                 `json:"tool_choice"`
	Tools              []OpenAIRequestTool    `json:"tools"`
	TopLogprobs        int                    `json:"top_logprobs"`
	TopP               float64                `json:"top_p"`
	Truncation         string                 `json:"truncation"`
	Usage              *OpenAIUsage           `json:"usage"`
	User               *string                `json:"user"`
	Metadata           map[string]interface{} `json:"metadata"`
}

type openaiOutputItem struct {
	Id      string                       `json:"id"`
	Type    string                       `json:"type"`
	Status  string                       `json:"status,omitempty"`
	Content []OpenAIMessageContent       `json:"content,omitempty"`
	Role    string                       `json:"role,omitempty"`
	Summary []openaiReasoningSummaryPart `json:"summary,omitempty"`

	// tools (type="function_call")
	Name      string `json:"name,omitempty"`
	CallId    string `json:"call_id,omitempty"`
	Arguments string `json:"arguments,omitempty"`
}

type openaiReasoning struct {
	Effort  string  `json:"effort"`
	Summary *string `json:"summary"`
}

type openaiTextConfig struct {
	Format    openaiTextFormat `json:"format"`
	Verbosity string           `json:"verbosity"`
}

type openaiTextFormat struct {
	Type string `json:"type"`
}

type OpenAIUsage struct {
	InputTokens          int                        `json:"input_tokens,omitempty"`
	OutputTokens         int                        `json:"output_tokens,omitempty"`
	TotalTokens          int                        `json:"total_tokens,omitempty"`
	InputTokensDetails   *openaiInputTokensDetails  `json:"input_tokens_details,omitempty"`
	OutputTokensDetails  *openaiOutputTokensDetails `json:"output_tokens_details,omitempty"`
	Model                string                     `json:"model,omitempty"`                // internal field (not from OpenAI API)
	NativeWebSearchCount int                        `json:"nativewebsearchcount,omitempty"` // internal field (not from OpenAI API)
}

type openaiInputTokensDetails struct {
	CachedTokens int `json:"cached_tokens"`
}

type openaiOutputTokensDetails struct {
	ReasoningTokens int `json:"reasoning_tokens"`
}

type openaiError struct {
	// Error details - can be expanded later
}

type openaiIncompleteInfo struct {
	Reason string `json:"reason"`
}

// ---------- OpenAI streaming state types ----------

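// Each streamed output item (text, reasoning, or function call) gets an
// openaiBlockState keyed by its OpenAI item_id. localID is a client-generated
// UUID used when relaying the block to the UI over SSE.
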
type openaiBlockKind int

const (
	openaiBlockText openaiBlockKind = iota
	openaiBlockReasoning
	openaiBlockToolUse
)

type openaiBlockState struct {
	kind         openaiBlockKind
	localID      string // For SSE streaming to UI
	toolCallID   string // For function calls
	toolName     string // For function calls
	summaryCount int    // For reasoning: number of summary parts seen
	partialJSON  []byte // For function calls: accumulated JSON arguments
}

type openaiStreamingState struct {
	blockMap       map[string]*openaiBlockState // Use item_id as key for UI streaming
	msgID          string
	model          string
	stepStarted    bool
	chatOpts       uctypes.WaveChatOpts
	webSearchCount int
}

// ---------- Public entrypoint ----------

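// UpdateToolUseData finds the stored function-call message matching callId
// and posts a copy back to the chat store with its ToolUseData replaced.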
func UpdateToolUseData(chatId string, callId string, newToolUseData uctypes.UIMessageDataToolUse) error {
	chat := chatstore.DefaultChatStore.Get(chatId)
	if chat == nil {
		return fmt.Errorf("chat not found: %s", chatId)
	}

	for _, genMsg := range chat.NativeMessages {
		chatMsg, ok := genMsg.(*OpenAIChatMessage)
		if !ok {
			continue
		}

		if chatMsg.FunctionCall != nil && chatMsg.FunctionCall.CallId == callId {
			updatedMsg := *chatMsg
			updatedFunctionCall := *chatMsg.FunctionCall
			updatedFunctionCall.ToolUseData = &newToolUseData
			updatedMsg.FunctionCall = &updatedFunctionCall

			aiOpts := &uctypes.AIOptsType{
				APIType:    chat.APIType,
				Model:      chat.Model,
				APIVersion: chat.APIVersion,
			}

			return chatstore.DefaultChatStore.PostMessage(chatId, aiOpts, &updatedMsg)
		}
	}

	return fmt.Errorf("function call with callId %s not found in chat %s", callId, chatId)
}

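// RunOpenAIChatStep performs one streaming step against the OpenAI Responses
// API: it replays the stored chat history as request inputs, issues the HTTP
// request, and relays the SSE stream to the UI. It returns the stop reason,
// any new messages to append to the chat, and rate-limit info parsed from the
// response headers. Callers are expected to loop while the stop reason is
// StopKindToolUse, posting function_call_output messages between steps.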
func RunOpenAIChatStep(
	ctx context.Context,
	sse *sse.SSEHandlerCh,
	chatOpts uctypes.WaveChatOpts,
	cont *uctypes.WaveContinueResponse,
) (*uctypes.WaveStopReason, []*OpenAIChatMessage, *uctypes.RateLimitInfo, error) {
	if sse == nil {
		return nil, nil, nil, errors.New("sse handler is nil")
	}

	// Get chat from store
	chat := chatstore.DefaultChatStore.Get(chatOpts.ChatId)
	if chat == nil {
		return nil, nil, nil, fmt.Errorf("chat not found: %s", chatOpts.ChatId)
	}

	// Validate that chatOpts.Config matches the chat's stored configuration
	if chat.APIType != chatOpts.Config.APIType {
		return nil, nil, nil, fmt.Errorf("API type mismatch: chat has %s, chatOpts has %s", chat.APIType, chatOpts.Config.APIType)
	}
	if !uctypes.AreModelsCompatible(chat.APIType, chat.Model, chatOpts.Config.Model) {
		return nil, nil, nil, fmt.Errorf("model mismatch: chat has %s, chatOpts has %s", chat.Model, chatOpts.Config.Model)
	}
	if chat.APIVersion != chatOpts.Config.APIVersion {
		return nil, nil, nil, fmt.Errorf("API version mismatch: chat has %s, chatOpts has %s", chat.APIVersion, chatOpts.Config.APIVersion)
	}

	// Context with timeout if provided.
	if chatOpts.Config.TimeoutMs > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, time.Duration(chatOpts.Config.TimeoutMs)*time.Millisecond)
		defer cancel()
	}

	// Validate continuation if provided
	if cont != nil {
		if !uctypes.AreModelsCompatible(chat.APIType, chatOpts.Config.Model, cont.Model) {
			return nil, nil, nil, fmt.Errorf("cannot continue with a different model, model:%q, cont-model:%q", chatOpts.Config.Model, cont.Model)
		}
	}

	// Convert GenAIMessages to input objects (OpenAIMessage or OpenAIFunctionCallInput)
	var inputs []any
	for _, genMsg := range chat.NativeMessages {
		// Cast to OpenAIChatMessage
		chatMsg, ok := genMsg.(*OpenAIChatMessage)
		if !ok {
			return nil, nil, nil, fmt.Errorf("expected OpenAIChatMessage, got %T", genMsg)
		}

		// Convert to appropriate input type based on what's populated
		if chatMsg.Message != nil {
			// Clean message to remove preview URLs
			cleanedMsg := chatMsg.Message.cleanAndCopy()
			inputs = append(inputs, *cleanedMsg)
		} else if chatMsg.FunctionCall != nil {
			cleanedFunctionCall := chatMsg.FunctionCall.clean()
			inputs = append(inputs, *cleanedFunctionCall)
		} else if chatMsg.FunctionCallOutput != nil {
			inputs = append(inputs, *chatMsg.FunctionCallOutput)
		}
	}

	req, err := buildOpenAIHTTPRequest(ctx, inputs, chatOpts, cont)
	if err != nil {
		return nil, nil, nil, err
	}

	httpClient := &http.Client{
		Timeout: 0, // rely on ctx; streaming can be long
	}
	// Proxy support
	if chatOpts.Config.ProxyURL != "" {
		pURL, perr := url.Parse(chatOpts.Config.ProxyURL)
		if perr != nil {
			return nil, nil, nil, fmt.Errorf("invalid proxy URL: %w", perr)
		}
		httpClient.Transport = &http.Transport{
			Proxy: http.ProxyURL(pURL),
		}
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, nil, nil, sanitizeHostnameInError(err, chatOpts.Config.BaseURL)
	}
	defer resp.Body.Close()

	// Parse rate limit info from header if present (do this before error check)
	rateLimitInfo := uctypes.ParseRateLimitHeader(resp.Header.Get("X-Wave-RateLimit"))

	ct := resp.Header.Get("Content-Type")
	if resp.StatusCode != http.StatusOK || !strings.HasPrefix(ct, "text/event-stream") {
		// Handle 429 rate limit with special logic
		if resp.StatusCode == http.StatusTooManyRequests && rateLimitInfo != nil {
			if rateLimitInfo.PReq == 0 && rateLimitInfo.Req > 0 {
				// Premium requests exhausted, but regular requests available
				stopReason := &uctypes.WaveStopReason{
					Kind: uctypes.StopKindPremiumRateLimit,
				}
				return stopReason, nil, rateLimitInfo, nil
			}
			if rateLimitInfo.Req == 0 {
				// All requests exhausted
				stopReason := &uctypes.WaveStopReason{
					Kind: uctypes.StopKindRateLimit,
				}
				return stopReason, nil, rateLimitInfo, nil
			}
		}
		return nil, nil, rateLimitInfo, parseOpenAIHTTPError(resp)
	}

	// At this point we have a valid SSE stream, so set up SSE handling.
	// From here on, errors must be returned through the SSE stream.
	if cont == nil {
		sse.SetupSSE()
	}

	// Use eventsource decoder for proper SSE parsing
	decoder := eventsource.NewDecoder(resp.Body)

	stopReason, rtnMessages := handleOpenAIStreamingResp(ctx, sse, decoder, cont, chatOpts)
	return stopReason, rtnMessages, rateLimitInfo, nil
}

// parseOpenAIHTTPError parses OpenAI API HTTP error responses
func parseOpenAIHTTPError(resp *http.Response) error {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("openai %s: failed to read error response: %v", resp.Status, err)
	}

	logutil.DevPrintf("openai full error: %s\n", body)

	// Try to parse as OpenAI error format first
	var errorResp openAIErrorResponse
	if err := json.Unmarshal(body, &errorResp); err == nil && errorResp.Error.Message != "" {
		return fmt.Errorf("openai %s: %s", resp.Status, errorResp.Error.Message)
	}

	// Try to parse as proxy error format
	var proxyErr uctypes.ProxyErrorResponse
	if err := json.Unmarshal(body, &proxyErr); err == nil && !proxyErr.Success && proxyErr.Error != "" {
		return fmt.Errorf("openai %s: %s", resp.Status, proxyErr.Error)
	}

	return fmt.Errorf("openai %s: %s", resp.Status, utilfn.TruncateString(string(body), 120))
}

// handleOpenAIStreamingResp handles the OpenAI SSE streaming response
func handleOpenAIStreamingResp(ctx context.Context, sse *sse.SSEHandlerCh, decoder *eventsource.Decoder, cont *uctypes.WaveContinueResponse, chatOpts uctypes.WaveChatOpts) (*uctypes.WaveStopReason, []*OpenAIChatMessage) {
	// Per-response state
	state := &openaiStreamingState{
		blockMap: map[string]*openaiBlockState{},
		chatOpts: chatOpts,
	}

	var rtnStopReason *uctypes.WaveStopReason
	var rtnMessages []*OpenAIChatMessage

	// Ensure step is closed on error/cancellation
	defer func() {
		if !state.stepStarted {
			return
		}
		_ = sse.AiMsgFinishStep()
		if rtnStopReason == nil || rtnStopReason.Kind != uctypes.StopKindToolUse {
			_ = sse.AiMsgFinish("", nil)
		}
	}()

	// SSE event processing loop
	for {
		// Check for context cancellation
		if err := ctx.Err(); err != nil {
			_ = sse.AiMsgError("request cancelled")
			return &uctypes.WaveStopReason{
				Kind:      uctypes.StopKindCanceled,
				ErrorType: "cancelled",
				ErrorText: "request cancelled",
			}, rtnMessages
		}

		event, err := decoder.Decode()
		if err != nil {
			if errors.Is(err, io.EOF) {
				// EOF without proper completion - protocol error
				_ = sse.AiMsgError("stream ended unexpectedly without completion")
				return &uctypes.WaveStopReason{
					Kind:      uctypes.StopKindError,
					ErrorType: "protocol",
					ErrorText: "stream ended unexpectedly without completion",
				}, rtnMessages
			}
			// transport error mid-stream
			_ = sse.AiMsgError(err.Error())
			return &uctypes.WaveStopReason{
				Kind:      uctypes.StopKindError,
				ErrorType: "stream",
				ErrorText: err.Error(),
			}, rtnMessages
		}

		if finalStopReason, finalMessages := handleOpenAIEvent(event, sse, state, cont); finalStopReason != nil {
			// Either error or response.completed triggered return
			rtnStopReason = finalStopReason
			if finalMessages != nil {
				rtnMessages = finalMessages
			}
			return finalStopReason, rtnMessages
		}
	}

	// unreachable
}

// handleOpenAIEvent processes one SSE event block. It may emit SSE parts
// and/or return a StopReason and final messages when the stream is complete.
//
// Return tuple:
//   - final: a *StopReason to return immediately (e.g., after response.completed or an error)
//   - messages: the final chat messages when the response is completed
func handleOpenAIEvent(
	event eventsource.Event,
	sse *sse.SSEHandlerCh,
	state *openaiStreamingState,
	cont *uctypes.WaveContinueResponse,
) (final *uctypes.WaveStopReason, messages []*OpenAIChatMessage) {
	eventName := event.Event()
	data := event.Data()

	switch eventName {
	case "response.created":
		var ev openaiResponseCreatedEvent
		if err := json.Unmarshal([]byte(data), &ev); err != nil {
			_ = sse.AiMsgError(err.Error())
			return &uctypes.WaveStopReason{Kind: uctypes.StopKindError, ErrorType: "decode", ErrorText: err.Error()}, nil
		}
		state.msgID = ev.Response.Id
		state.model = ev.Response.Model
		if cont == nil {
			_ = sse.AiMsgStart(state.msgID)
		}
		return nil, nil

	case "response.in_progress":
		// Start the step on in_progress
		if !state.stepStarted {
			_ = sse.AiMsgStartStep()
			state.stepStarted = true
		}
		return nil, nil

	case "response.output_item.added":
		var ev openaiResponseOutputItemAddedEvent
		if err := json.Unmarshal([]byte(data), &ev); err != nil {
			_ = sse.AiMsgError(err.Error())
			return &uctypes.WaveStopReason{Kind: uctypes.StopKindError, ErrorType: "decode", ErrorText: err.Error()}, nil
		}

		switch ev.Item.Type {
		case "reasoning":
			// Create reasoning block - emit start immediately
			id := uuid.New().String()
			state.blockMap[ev.Item.Id] = &openaiBlockState{
				kind:         openaiBlockReasoning,
				localID:      id,
				summaryCount: 0,
			}
			_ = sse.AiMsgReasoningStart(id)
		case "message":
			// Message item - content parts will be handled in streaming events
		case "function_call":
			// Track function call info and notify frontend
			id := uuid.New().String()
			state.blockMap[ev.Item.Id] = &openaiBlockState{
				kind:       openaiBlockToolUse,
				localID:    id,
				toolCallID: ev.Item.CallId,
				toolName:   ev.Item.Name,
			}
			// no longer send tool inputs to FE
			// _ = sse.AiMsgToolInputStart(ev.Item.CallId, ev.Item.Name)
		}
		return nil, nil

	case "response.output_item.done":
		var ev openaiResponseOutputItemDoneEvent
		if err := json.Unmarshal([]byte(data), &ev); err != nil {
			_ = sse.AiMsgError(err.Error())
			return &uctypes.WaveStopReason{Kind: uctypes.StopKindError, ErrorType: "decode", ErrorText: err.Error()}, nil
		}

		if st := state.blockMap[ev.Item.Id]; st != nil {
			switch st.kind {
			case openaiBlockReasoning:
				_ = sse.AiMsgReasoningEnd(st.localID)
			case openaiBlockToolUse:
				// Tool input completion notification was already sent in function_call_arguments.done
				// This just marks the end of the tool item itself
			}
		}
		return nil, nil

	case "response.content_part.added":
		var ev openaiResponseContentPartAddedEvent
		if err := json.Unmarshal([]byte(data), &ev); err != nil {
			_ = sse.AiMsgError(err.Error())
			return &uctypes.WaveStopReason{Kind: uctypes.StopKindError, ErrorType: "decode", ErrorText: err.Error()}, nil
		}

		switch ev.Part.Type {
		case "output_text":
			// Handle text content for UI streaming only
			id := uuid.New().String()
			state.blockMap[ev.ItemId] = &openaiBlockState{
				kind:    openaiBlockText,
				localID: id,
			}
			_ = sse.AiMsgTextStart(id)
		}
		return nil, nil

	case "response.output_text.delta":
		var ev openaiResponseOutputTextDeltaEvent
		if err := json.Unmarshal([]byte(data), &ev); err != nil {
			_ = sse.AiMsgError(err.Error())
			return &uctypes.WaveStopReason{Kind: uctypes.StopKindError, ErrorType: "decode", ErrorText: err.Error()}, nil
		}

		if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockText {
			_ = sse.AiMsgTextDelta(st.localID, ev.Delta)
		}
		return nil, nil

	case "response.output_text.done":
		return nil, nil

	case "response.content_part.done":
		var ev openaiResponseContentPartDoneEvent
		if err := json.Unmarshal([]byte(data), &ev); err != nil {
			_ = sse.AiMsgError(err.Error())
			return &uctypes.WaveStopReason{Kind: uctypes.StopKindError, ErrorType: "decode", ErrorText: err.Error()}, nil
		}

		if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockText {
			_ = sse.AiMsgTextEnd(st.localID)
		}
		return nil, nil

	case "response.completed", "response.failed", "response.incomplete":
		var ev openaiResponseCompletedEvent
		if err := json.Unmarshal([]byte(data), &ev); err != nil {
			_ = sse.AiMsgError(err.Error())
			return &uctypes.WaveStopReason{Kind: uctypes.StopKindError, ErrorType: "decode", ErrorText: err.Error()}, nil
		}

		// Handle error case
		if ev.Response.Error != nil {
			errorMsg := "OpenAI API error"
			_ = sse.AiMsgError(errorMsg)
			return &uctypes.WaveStopReason{
				Kind:      uctypes.StopKindError,
				ErrorType: "api",
				ErrorText: errorMsg,
			}, nil
		}

		// Handle incomplete case
		if ev.Response.IncompleteDetails != nil {
			reason := ev.Response.IncompleteDetails.Reason
			var stopKind uctypes.StopReasonKind
			var errorMsg string

			switch reason {
			case "max_output_tokens":
				stopKind = uctypes.StopKindMaxTokens
				errorMsg = "Maximum output tokens reached"
			case "max_prompt_tokens":
				stopKind = uctypes.StopKindError
				errorMsg = "Maximum prompt tokens reached"
			case "content_filter":
				stopKind = uctypes.StopKindContent
				errorMsg = "Content filtered"
			default:
				stopKind = uctypes.StopKindError
				errorMsg = fmt.Sprintf("Response incomplete: %s", reason)
			}

			// Extract partial message if available
			finalMessages, _ := extractMessageAndToolsFromResponse(ev.Response, state)

			_ = sse.AiMsgError(errorMsg)
			return &uctypes.WaveStopReason{
				Kind:      stopKind,
				RawReason: reason,
				ErrorText: errorMsg,
			}, finalMessages
		}

		// Extract the final message and tool calls from the response output
		finalMessages, toolCalls := extractMessageAndToolsFromResponse(ev.Response, state)

		stopKind := uctypes.StopKindDone
		if len(toolCalls) > 0 {
			stopKind = uctypes.StopKindToolUse
		}

		return &uctypes.WaveStopReason{
			Kind:      stopKind,
			RawReason: ev.Response.Status,
			ToolCalls: toolCalls,
		}, finalMessages

	case "response.function_call_arguments.delta":
		var ev openaiResponseFunctionCallArgumentsDeltaEvent
		if err := json.Unmarshal([]byte(data), &ev); err != nil {
			_ = sse.AiMsgError(err.Error())
			return &uctypes.WaveStopReason{Kind: uctypes.StopKindError, ErrorType: "decode", ErrorText: err.Error()}, nil
		}
		if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockToolUse {
			st.partialJSON = append(st.partialJSON, []byte(ev.Delta)...)
			aiutil.SendToolProgress(st.toolCallID, st.toolName, st.partialJSON, state.chatOpts, sse, true)
		}
		return nil, nil

	case "response.function_call_arguments.done":
		var ev openaiResponseFunctionCallArgumentsDoneEvent
		if err := json.Unmarshal([]byte(data), &ev); err != nil {
			_ = sse.AiMsgError(err.Error())
			return &uctypes.WaveStopReason{Kind: uctypes.StopKindError, ErrorType: "decode", ErrorText: err.Error()}, nil
		}

		// Get the function call info from the block state
		if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockToolUse {
			aiutil.SendToolProgress(st.toolCallID, st.toolName, []byte(ev.Arguments), state.chatOpts, sse, false)
		}
		return nil, nil

	case "response.web_search_call.in_progress":
		return nil, nil

	case "response.web_search_call.searching":
		return nil, nil

	case "response.web_search_call.completed":
		state.webSearchCount++
		return nil, nil

	case "response.output_text.annotation.added":
		return nil, nil

	case "response.reasoning_summary_part.added":
		var ev openaiResponseReasoningSummaryPartAddedEvent
		if err := json.Unmarshal([]byte(data), &ev); err != nil {
			_ = sse.AiMsgError(err.Error())
			return &uctypes.WaveStopReason{Kind: uctypes.StopKindError, ErrorType: "decode", ErrorText: err.Error()}, nil
		}

		if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockReasoning {
			if st.summaryCount > 0 {
				// Not the first summary part, emit separator
				_ = sse.AiMsgReasoningDelta(st.localID, "\n\n")
			}
			st.summaryCount++
		}
		return nil, nil

	case "response.reasoning_summary_part.done":
		return nil, nil

	case "response.reasoning_summary_text.delta":
		var ev openaiResponseReasoningSummaryTextDeltaEvent
		if err := json.Unmarshal([]byte(data), &ev); err != nil {
			_ = sse.AiMsgError(err.Error())
			return &uctypes.WaveStopReason{Kind: uctypes.StopKindError, ErrorType: "decode", ErrorText: err.Error()}, nil
		}

		if st := state.blockMap[ev.ItemId]; st != nil && st.kind == openaiBlockReasoning {
			_ = sse.AiMsgReasoningDelta(st.localID, ev.Delta)
		}
		return nil, nil

	case "response.reasoning_summary_text.done":
		return nil, nil

	default:
		logutil.DevPrintf("OpenAI: unknown event: %s, data: %s", eventName, data)
		return nil, nil
	}
}

// extractMessageAndToolsFromResponse extracts the final OpenAI message and tool calls from the completed response
func extractMessageAndToolsFromResponse(resp openaiResponse, state *openaiStreamingState) ([]*OpenAIChatMessage, []uctypes.WaveToolCall) {
	var messageContent []OpenAIMessageContent
	var toolCalls []uctypes.WaveToolCall
	var messages []*OpenAIChatMessage

	// Process all output items in the response
	for _, outputItem := range resp.Output {
		switch outputItem.Type {
		case "message":
			if outputItem.Role == "assistant" {
				// Copy ALL content parts from the output item
				for _, contentPart := range outputItem.Content {
					messageContent = append(messageContent, OpenAIMessageContent{
						Type: contentPart.Type,
						Text: contentPart.Text,
					})
				}
			}
		case "function_call":
			// Extract tool call information
			toolCall := uctypes.WaveToolCall{
				ID:   outputItem.CallId,
				Name: outputItem.Name,
			}

			// Parse arguments JSON string if present
			var parsedArguments any
			if outputItem.Arguments != "" {
				if err := json.Unmarshal([]byte(outputItem.Arguments), &parsedArguments); err == nil {
					toolCall.Input = parsedArguments
				}
			}

			toolCalls = append(toolCalls, toolCall)

			// Create separate FunctionCall message
			functionCallMsg := &OpenAIChatMessage{
				MessageId: uuid.New().String(),
				FunctionCall: &OpenAIFunctionCallInput{
					Type:      "function_call",
					CallId:    outputItem.CallId,
					Name:      outputItem.Name,
					Arguments: outputItem.Arguments,
				},
			}
			messages = append(messages, functionCallMsg)
		}
	}

	// Create OpenAIChatMessage with assistant message (first in slice)
	usage := resp.Usage
	if usage != nil {
		usage.Model = resp.Model
		if state.webSearchCount > 0 {
			usage.NativeWebSearchCount = state.webSearchCount
		}
	}
	assistantMessage := &OpenAIChatMessage{
		MessageId: uuid.New().String(),
		Message: &OpenAIMessage{
			Role:    "assistant",
			Content: messageContent,
		},
		Usage: usage,
	}

	// Return assistant message first, followed by function call messages
	allMessages := []*OpenAIChatMessage{assistantMessage}
	allMessages = append(allMessages, messages...)

	return allMessages, toolCalls
}