waveterm/pkg/aiusechat/usechat.go
Mike Sawka cd6389d1ec
Create Interface for Backend AI Providers (#2572)
Created an interface in aiusechat for the backend providers. Use that
interface throughout the usechat code. Isolate the backend
implementations to only the new file usechat-backend.go.
2025-11-19 11:38:56 -08:00

859 lines
30 KiB
Go

// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package aiusechat
import (
"context"
_ "embed"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/aiusechat/chatstore"
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
"github.com/wavetermdev/waveterm/pkg/telemetry"
"github.com/wavetermdev/waveterm/pkg/telemetry/telemetrydata"
"github.com/wavetermdev/waveterm/pkg/util/ds"
"github.com/wavetermdev/waveterm/pkg/util/logutil"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/waveappstore"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/web/sse"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wstore"
)
// API backends that Wave AI can be configured to talk to.
const (
	APIType_Anthropic = "anthropic"
	APIType_OpenAI    = "openai"
)

// Defaults applied when building Wave AI request options.
const (
	// DefaultAPI selects which backend getWaveAISettings configures.
	DefaultAPI = APIType_OpenAI
	// DefaultMaxTokens is the output-token budget for regular chats (4096).
	DefaultMaxTokens = 4 * 1024
	// BuilderMaxTokens is the larger output-token budget used in builder mode (24576).
	BuilderMaxTokens = 24 * 1024
)
// globalRateLimitInfo is the most recent rate-limit state reported by a backend. It
// starts out marked Unknown until a backend reports real values. Guarded by rateLimitLock.
var globalRateLimitInfo = &uctypes.RateLimitInfo{Unknown: true}

// rateLimitLock serializes access to globalRateLimitInfo.
var rateLimitLock sync.Mutex

// activeToolMap marks tool calls that are currently being processed (key: tool call id).
var activeToolMap = ds.MakeSyncMap[bool]()

// activeChats marks chats whose RunAIChat loop is in progress (key: chat id); it is
// used to reject a second concurrent run for the same chat.
var activeChats = ds.MakeSyncMap[bool]()
// SystemPromptText is the system prompt used for chats whose API type is not OpenAI
// (e.g. Anthropic). The fragments are joined with single spaces, so each one should
// be a complete sentence.
var SystemPromptText = strings.Join([]string{
	`You are Wave AI, an intelligent assistant embedded within Wave Terminal, a modern terminal application with graphical widgets.`,
	`You appear as a pull-out panel on the left side of a tab, with the tab's widgets laid out on the right.`,
	`Widget context is provided as informational only.`,
	`Do NOT assume any API access or ability to interact with the widgets except via tools provided (note that some widgets may expose NO tools, so their context is informational only).`,
}, " ")
// SystemPromptText_OpenAI is the system prompt sent with OpenAI-backed chats that are
// not in builder mode (builder chats get an empty base prompt; see
// WaveAIPostMessageHandler). The fragments are joined with single spaces, so each
// fragment should be a complete sentence. The comment lines inside the slice only
// group fragments for readers; they are not part of the prompt text.
var SystemPromptText_OpenAI = strings.Join([]string{
`You are Wave AI, an assistant embedded in Wave Terminal (a terminal with graphical widgets).`,
`You appear as a pull-out panel on the left; widgets are on the right.`,
// Capabilities & truthfulness
`Tools define your only capabilities. If a capability is not provided by a tool, you cannot do it.`,
`Context from widgets is read-only unless a tool explicitly grants interaction.`,
`Never fabricate data. If you lack data or access, say so and offer the next best step (e.g., suggest enabling a tool).`,
// Crisp behavior
`Be concise and direct. Prefer determinism over speculation. If a brief clarifying question eliminates guesswork, ask it.`,
// Attached text files
// (the tag formats below describe how attachments are rendered into the user message;
// NOTE(review): keep in sync with the attachment conversion code — confirm location)
`User-attached text files may appear inline as <AttachedTextFile_xxxxxxxx file_name="...">\ncontent\n</AttachedTextFile_xxxxxxxx>.`,
`User-attached directories use the tag <AttachedDirectoryListing_xxxxxxxx directory_name="...">JSON DirInfo</AttachedDirectoryListing_xxxxxxxx>.`,
`If multiple attached files exist, treat each as a separate source file with its own file_name.`,
`When the user refers to these files, use their inline content directly; do NOT call any read_text_file or file-access tools to re-read them unless asked.`,
// Output & formatting
`When presenting commands or any runnable multi-line code, always use fenced Markdown code blocks.`,
`Use an appropriate language hint after the opening fence (e.g., "bash" for shell commands, "go" for Go, "json" for JSON).`,
`For shell commands, do NOT prefix lines with "$" or shell prompts. Use placeholders in ALL_CAPS (e.g., PROJECT_ID) and explain them once after the block if needed.`,
// (double-quoted because the fragment itself contains backticks)
"Reserve inline code (single backticks) for short references like command names (`grep`, `less`), flags, env vars, file paths, or tiny snippets not meant to be executed.",
`You may use Markdown (lists, tables, bold/italics) to improve readability.`,
`Never comment on or justify your formatting choices; just follow these rules.`,
`When generating code or command blocks, try to keep lines under ~100 characters wide where practical (soft wrap; do not break tokens mid-word). Favor indentation and short variable names to stay compact, but correctness always takes priority.`,
// Safety & limits
`If a request would execute dangerous or destructive actions, warn briefly and provide a safer alternative.`,
`If output is very long, prefer a brief summary plus a copy-ready fenced block or offer a follow-up chunking strategy.`,
`You can write and edit local files on disk using available tools, but you cannot execute shell commands, run code in the terminal, or read/write remote files.`,
`If the user asks you to execute commands or run code, or deal with remote files say that these features aren't available yet AND make sure to say that they are coming soon (stay tuned for updates).`,
`Instead, show them exactly what command or code they could copy-paste to run manually.`,
// Final reminder
`You have NO API access to widgets or Wave unless provided via an explicit tool.`,
}, " ")
// getWaveAISettings builds the request options for a Wave AI chat.
//
//   - premium: when false the quick thinking mode is forced; when true the mode
//     defaults to balanced and may be overridden by rtInfo.WaveAIThinkingMode.
//   - builderMode: raises the default output-token budget to BuilderMaxTokens.
//   - rtInfo: optional (may be nil); a positive WaveAIMaxOutputTokens overrides the
//     output-token budget.
//
// The endpoint is uctypes.DefaultAIEndpoint unless the WAVETERM_WAVEAI_ENDPOINT
// environment variable is set to a non-empty value. For OpenAI, the thinking mode
// selects the model and thinking level. An error is returned only when DefaultAPI is
// not a recognized API type.
func getWaveAISettings(premium bool, builderMode bool, rtInfo *waveobj.ObjRTInfo) (*uctypes.AIOptsType, error) {
	endpoint := uctypes.DefaultAIEndpoint
	if envEndpoint := os.Getenv("WAVETERM_WAVEAI_ENDPOINT"); envEndpoint != "" {
		endpoint = envEndpoint
	}

	// An explicit runtime override wins over the builder-mode default.
	tokenBudget := DefaultMaxTokens
	switch {
	case rtInfo != nil && rtInfo.WaveAIMaxOutputTokens > 0:
		tokenBudget = rtInfo.WaveAIMaxOutputTokens
	case builderMode:
		tokenBudget = BuilderMaxTokens
	}

	mode := uctypes.ThinkingModeQuick
	if premium {
		mode = uctypes.ThinkingModeBalanced
		if rtInfo != nil && rtInfo.WaveAIThinkingMode != "" {
			mode = rtInfo.WaveAIThinkingMode
		}
	}

	switch DefaultAPI {
	case APIType_Anthropic:
		return &uctypes.AIOptsType{
			APIType:       APIType_Anthropic,
			Model:         uctypes.DefaultAnthropicModel,
			MaxTokens:     tokenBudget,
			ThinkingLevel: uctypes.ThinkingLevelMedium,
			ThinkingMode:  mode,
			BaseURL:       endpoint,
		}, nil
	case APIType_OpenAI:
		// Premium model at low effort unless the mode says otherwise:
		// quick -> default model; deep -> medium effort; anything else -> as initialized.
		model, level := uctypes.PremiumOpenAIModel, uctypes.ThinkingLevelLow
		switch mode {
		case uctypes.ThinkingModeQuick:
			model = uctypes.DefaultOpenAIModel
		case uctypes.ThinkingModeDeep:
			level = uctypes.ThinkingLevelMedium
		}
		return &uctypes.AIOptsType{
			APIType:       APIType_OpenAI,
			Model:         model,
			MaxTokens:     tokenBudget,
			ThinkingLevel: level,
			ThinkingMode:  mode,
			BaseURL:       endpoint,
		}, nil
	default:
		return nil, fmt.Errorf("invalid API type: %s", DefaultAPI)
	}
}
// shouldUseChatCompletionsAPI reports whether model (compared case-insensitively)
// belongs to an older OpenAI family that requires the Chat Completions API:
// gpt-3.5-*, gpt-4 itself, gpt-4-* (e.g. gpt-4-turbo), or o1-*. Families such as
// gpt-4o or gpt-4.1 do not match.
func shouldUseChatCompletionsAPI(model string) bool {
	name := strings.ToLower(model)
	if name == "gpt-4" {
		return true
	}
	for _, legacyPrefix := range []string{"gpt-3.5", "gpt-4-", "o1-"} {
		if strings.HasPrefix(name, legacyPrefix) {
			return true
		}
	}
	return false
}
// shouldUsePremium reports whether the next request may use premium settings. It
// says yes when the rate-limit state is missing or unknown, when premium requests
// remain (PReq > 0), or when the current Unix time (seconds) has reached ResetEpoch.
func shouldUsePremium() bool {
	limit := GetGlobalRateLimit()
	switch {
	case limit == nil || limit.Unknown:
		return true
	case limit.PReq > 0:
		return true
	default:
		return time.Now().Unix() >= limit.ResetEpoch
	}
}
// updateRateLimit stores info as the current global rate-limit state and broadcasts
// it as an Event_WaveAIRateLimit event. A nil info is ignored (nothing is stored or
// published).
func updateRateLimit(info *uctypes.RateLimitInfo) {
	if info == nil {
		return
	}
	rateLimitLock.Lock()
	globalRateLimitInfo = info
	rateLimitLock.Unlock()

	// Publish on a separate goroutine so the caller does not wait on event delivery.
	go func() {
		wps.Broker.Publish(wps.WaveEvent{Event: wps.Event_WaveAIRateLimit, Data: info})
	}()
}
// GetGlobalRateLimit returns the most recently recorded rate-limit info. The pointer
// is shared with package state and with published events, so callers should treat
// the value as read-only.
func GetGlobalRateLimit() *uctypes.RateLimitInfo {
	rateLimitLock.Lock()
	current := globalRateLimitInfo
	rateLimitLock.Unlock()
	return current
}
// runAIChatStep performs a single model request through backend. It refuses up front
// when the configuration is OpenAI with a model that would need the legacy Chat
// Completions API. Any rate-limit info the backend reports is recorded, even when the
// step itself fails.
func runAIChatStep(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseChatBackend, chatOpts uctypes.WaveChatOpts, cont *uctypes.WaveContinueResponse) (*uctypes.WaveStopReason, []uctypes.GenAIMessage, error) {
	cfg := chatOpts.Config
	needsLegacyAPI := cfg.APIType == APIType_OpenAI && shouldUseChatCompletionsAPI(cfg.Model)
	if needsLegacyAPI {
		return nil, nil, fmt.Errorf("Chat completions API not available (must use newer OpenAI models)")
	}
	stop, msgs, limitInfo, stepErr := backend.RunChatStep(ctx, sseHandler, chatOpts, cont)
	updateRateLimit(limitInfo)
	return stop, msgs, stepErr
}
// getUsage sums the usage reported by msgs. The first message that reports usage is
// copied whole, so its non-counter fields (such as Model) are kept. Each later one only
// adds its InputTokens, OutputTokens and NativeWebSearchCount. Nil messages and
// messages without usage are skipped. Returns the zero value if no message reports usage.
func getUsage(msgs []uctypes.GenAIMessage) uctypes.AIUsage {
	var total uctypes.AIUsage
	seeded := false
	for _, msg := range msgs {
		// RunAIChat skips nil entries when storing step results, and it calls getUsage
		// before that check. Tolerate nils here too instead of panicking on the
		// method call.
		if msg == nil {
			continue
		}
		usage := msg.GetUsage()
		if usage == nil {
			continue
		}
		if !seeded {
			total = *usage
			seeded = true
			continue
		}
		total.InputTokens += usage.InputTokens
		total.OutputTokens += usage.OutputTokens
		total.NativeWebSearchCount += usage.NativeWebSearchCount
	}
	return total
}
// GetChatUsage returns the aggregated usage of every stored native message in chat.
// The APIType and Model fields are set from the chat itself, replacing whatever the
// first reporting message carried.
func GetChatUsage(chat *uctypes.AIChat) uctypes.AIUsage {
	total := getUsage(chat.NativeMessages)
	total.APIType, total.Model = chat.APIType, chat.Model
	return total
}
// updateToolUseDataInChat persists toolUseData for toolCallID in the chat
// chatOpts.ChatId through the backend. Failures are logged rather than returned;
// callers treat this as a best-effort sync.
func updateToolUseDataInChat(backend UseChatBackend, chatOpts uctypes.WaveChatOpts, toolCallID string, toolUseData *uctypes.UIMessageDataToolUse) {
	err := backend.UpdateToolUseData(chatOpts.ChatId, toolCallID, toolUseData)
	if err == nil {
		return
	}
	log.Printf("failed to update tool use data in chat: %v\n", err)
}
// processToolCallInternal takes one tool call through its gates and returns the result
// to report back to the model. The steps run in this order:
//   - a call without ToolUseData is rejected as "Invalid Tool Call";
//   - a call whose ToolUseData already has error status is returned as that error;
//   - if toolDef has ToolVerifyInput, the input is validated. Validation may modify
//     ToolUseData, so it is re-sent to the frontend and re-persisted;
//   - if approval is needed, it waits on WaitForToolApproval. NOTE(review): presumably
//     this blocks until the user decides or the wait times out (see ApprovalTimeout).
//     An approved call is re-sent so the frontend shows the approval;
//   - finally the tool runs via ResolveToolCall, which also handles a nil toolDef.
//
// toolCall is passed by value, but ToolUseData is a pointer. The Status, ErrorMessage,
// Approval and RunTs updates made here are therefore visible to the caller. The final
// ToolUseData state is not sent from here; processToolCall does that after this returns.
func processToolCallInternal(backend UseChatBackend, toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts, toolDef *uctypes.ToolDefinition, sseHandler *sse.SSEHandlerCh) uctypes.AIToolResult {
if toolCall.ToolUseData == nil {
return uctypes.AIToolResult{
ToolName: toolCall.Name,
ToolUseID: toolCall.ID,
ErrorText: "Invalid Tool Call",
}
}
// The call was already marked as failed before execution; surface that error as-is.
if toolCall.ToolUseData.Status == uctypes.ToolUseStatusError {
errorMsg := toolCall.ToolUseData.ErrorMessage
if errorMsg == "" {
errorMsg = "Unspecified Tool Error"
}
return uctypes.AIToolResult{
ToolName: toolCall.Name,
ToolUseID: toolCall.ID,
ErrorText: errorMsg,
}
}
if toolDef != nil && toolDef.ToolVerifyInput != nil {
if err := toolDef.ToolVerifyInput(toolCall.Input, toolCall.ToolUseData); err != nil {
errorMsg := fmt.Sprintf("Input validation failed: %v", err)
toolCall.ToolUseData.Status = uctypes.ToolUseStatusError
toolCall.ToolUseData.ErrorMessage = errorMsg
return uctypes.AIToolResult{
ToolName: toolCall.Name,
ToolUseID: toolCall.ID,
ErrorText: errorMsg,
}
}
// ToolVerifyInput can modify the toolusedata. re-send it here.
_ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData)
updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolCall.ToolUseData)
}
if toolCall.ToolUseData.Approval == uctypes.ApprovalNeedsApproval {
log.Printf(" waiting for approval...\n")
approval := WaitForToolApproval(toolCall.ID)
log.Printf(" approval result: %q\n", approval)
// An empty result leaves Approval at "needs approval", which is treated as not approved below.
if approval != "" {
toolCall.ToolUseData.Approval = approval
}
if !toolCall.ToolUseData.IsApproved() {
errorMsg := "Tool use denied or timed out"
if approval == uctypes.ApprovalUserDenied {
errorMsg = "Tool use denied by user"
} else if approval == uctypes.ApprovalTimeout {
errorMsg = "Tool approval timed out"
}
toolCall.ToolUseData.Status = uctypes.ToolUseStatusError
toolCall.ToolUseData.ErrorMessage = errorMsg
return uctypes.AIToolResult{
ToolName: toolCall.Name,
ToolUseID: toolCall.ID,
ErrorText: errorMsg,
}
}
// this still happens here because we need to update the FE to say the tool call was approved
_ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData)
updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolCall.ToolUseData)
}
// Record the start time (Unix ms) of the actual execution.
toolCall.ToolUseData.RunTs = time.Now().UnixMilli()
result := ResolveToolCall(toolDef, toolCall, chatOpts)
if result.ErrorText != "" {
toolCall.ToolUseData.Status = uctypes.ToolUseStatusError
toolCall.ToolUseData.ErrorMessage = result.ErrorText
} else {
toolCall.ToolUseData.Status = uctypes.ToolUseStatusCompleted
}
return result
}
// processToolCall executes one tool call and records the outcome in metrics. It
// increments ToolUseErrorCount on failure and ToolDetail[ToolLogName] when the tool
// defines a log name. It then pushes the final ToolUseData (if any) to the frontend
// and the chat store.
//
// metrics must be non-nil with an initialized ToolDetail map. Failures are reported in
// the returned result's ErrorText, not as a Go error.
func processToolCall(backend UseChatBackend, toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) uctypes.AIToolResult {
	// ToolUseData may be nil; processToolCallInternal turns such calls into an
	// "Invalid Tool Call" result. It must not be dereferenced just for the debug log.
	var approval any = ""
	if toolCall.ToolUseData != nil {
		approval = toolCall.ToolUseData.Approval
	}
	// The marshaled input is only used for debug logging, so a marshal error is ignored.
	inputJSON, _ := json.Marshal(toolCall.Input)
	logutil.DevPrintf("TOOLUSE name=%s id=%s input=%s approval=%q\n", toolCall.Name, toolCall.ID, utilfn.TruncateString(string(inputJSON), 40), approval)

	toolDef := chatOpts.GetToolDefinition(toolCall.Name)
	result := processToolCallInternal(backend, toolCall, chatOpts, toolDef, sseHandler)
	if result.ErrorText != "" {
		log.Printf(" error=%s\n", result.ErrorText)
		metrics.ToolUseErrorCount++
	} else {
		log.Printf(" result=%s\n", utilfn.TruncateString(result.Text, 40))
	}
	if toolDef != nil && toolDef.ToolLogName != "" {
		metrics.ToolDetail[toolDef.ToolLogName]++
	}
	// Send the final state (status, error, description) set by processToolCallInternal.
	if toolCall.ToolUseData != nil {
		_ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData)
		updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolCall.ToolUseData)
	}
	return result
}
// processToolCalls handles every tool call in stopReason in this order:
//  1. marks all of them active in activeToolMap;
//  2. announces all of them to the frontend up front and registers approvals where
//     needed;
//  3. runs them one after another;
//  4. converts the results into backend-native messages and posts them to the chat
//     store so the next step can continue the conversation.
//
// The calls stay marked active until this function returns.
func processToolCalls(backend UseChatBackend, stopReason *uctypes.WaveStopReason, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) {
	// One batch cleanup instead of a defer per loop iteration. The entries are still
	// removed only after every call has been processed and its results stored.
	activeIDs := make([]string, 0, len(stopReason.ToolCalls))
	for _, toolCall := range stopReason.ToolCalls {
		activeToolMap.Set(toolCall.ID, true)
		activeIDs = append(activeIDs, toolCall.ID)
	}
	defer func() {
		for _, id := range activeIDs {
			activeToolMap.Delete(id)
		}
	}()

	// Send all data-tooluse packets at the beginning so the UI shows every pending call.
	for _, toolCall := range stopReason.ToolCalls {
		if toolCall.ToolUseData == nil {
			continue
		}
		log.Printf("AI data-tooluse %s\n", toolCall.ID)
		_ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData)
		updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolCall.ToolUseData)
		if toolCall.ToolUseData.Approval == uctypes.ApprovalNeedsApproval && chatOpts.RegisterToolApproval != nil {
			chatOpts.RegisterToolApproval(toolCall.ID)
		}
	}

	var toolResults []uctypes.AIToolResult
	for _, toolCall := range stopReason.ToolCalls {
		toolResults = append(toolResults, processToolCall(backend, toolCall, chatOpts, sseHandler, metrics))
	}

	toolResultMsgs, err := backend.ConvertToolResultsToNativeChatMessage(toolResults)
	if err != nil {
		log.Printf("Failed to convert tool results to native chat messages: %v", err)
		return
	}
	for _, msg := range toolResultMsgs {
		// Previously ignored. A lost tool result leaves the stored conversation
		// inconsistent, so at least make it visible in the logs.
		if postErr := chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, msg); postErr != nil {
			log.Printf("failed to store tool result message for chat %s: %v\n", chatOpts.ChatId, postErr)
		}
	}
}
// RunAIChat drives the request loop for one chat. It stops when a step fails or when
// the model stops for a reason other than tool use or a premium rate limit.
//
// Only one loop may run per chat id at a time; a concurrent call fails immediately
// with a nil metrics value. Before every step, tab state/tools and builder app data
// are refreshed through the optional generators on chatOpts; a generator error keeps
// the previous values. Messages from each successful step are stored in the chat store.
//
// A premium-rate-limit stop on the premium OpenAI model continues the turn on
// uctypes.DefaultOpenAIModel. A tool-use stop runs the tools and continues on the
// configured model.
//
// Error handling: a failure on the first step is returned to the caller. A failure on
// a later step is reported to the client over SSE instead, and RunAIChat then returns
// the metrics with a nil error.
func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseChatBackend, chatOpts uctypes.WaveChatOpts) (*uctypes.AIMetrics, error) {
	if !activeChats.SetUnless(chatOpts.ChatId, true) {
		return nil, fmt.Errorf("chat %s is already running", chatOpts.ChatId)
	}
	defer activeChats.Delete(chatOpts.ChatId)

	metrics := &uctypes.AIMetrics{
		Usage: uctypes.AIUsage{
			APIType: chatOpts.Config.APIType,
			Model:   chatOpts.Config.Model,
		},
		WidgetAccess:  chatOpts.WidgetAccess,
		ToolDetail:    make(map[string]int),
		ThinkingLevel: chatOpts.Config.ThinkingLevel,
		ThinkingMode:  chatOpts.Config.ThinkingMode,
	}
	firstStep := true
	var cont *uctypes.WaveContinueResponse
	for {
		// Refresh per-step context; on a generator error the previous values are kept.
		if chatOpts.TabStateGenerator != nil {
			tabState, tabTools, tabId, tabErr := chatOpts.TabStateGenerator()
			if tabErr == nil {
				chatOpts.TabState = tabState
				chatOpts.TabTools = tabTools
				chatOpts.TabId = tabId
			}
		}
		if chatOpts.BuilderAppGenerator != nil {
			appGoFile, appStaticFiles, appErr := chatOpts.BuilderAppGenerator()
			if appErr == nil {
				chatOpts.AppGoFile = appGoFile
				chatOpts.AppStaticFiles = appStaticFiles
			}
		}

		stopReason, rtnMessage, err := runAIChatStep(ctx, sseHandler, backend, chatOpts, cont)
		// Every attempted step counts toward the request metrics, including failed ones.
		metrics.RequestCount++
		if chatOpts.Config.IsPremiumModel() {
			metrics.PremiumReqCount++
		}
		if chatOpts.Config.IsWaveProxy() {
			metrics.ProxyReqCount++
		}
		if len(rtnMessage) > 0 {
			usage := getUsage(rtnMessage)
			log.Printf("usage: input=%d output=%d websearch=%d\n", usage.InputTokens, usage.OutputTokens, usage.NativeWebSearchCount)
			metrics.Usage.InputTokens += usage.InputTokens
			metrics.Usage.OutputTokens += usage.OutputTokens
			metrics.Usage.NativeWebSearchCount += usage.NativeWebSearchCount
			if usage.Model != "" && metrics.Usage.Model != usage.Model {
				metrics.Usage.Model = "mixed"
			}
		}
		if err != nil {
			metrics.HadError = true
			if firstStep {
				return metrics, fmt.Errorf("failed to stream %s chat: %w", chatOpts.Config.APIType, err)
			}
			_ = sseHandler.AiMsgError(err.Error())
			_ = sseHandler.AiMsgFinish("", nil)
			break
		}
		for _, msg := range rtnMessage {
			if msg == nil {
				continue
			}
			// Previously ignored. A lost message corrupts the stored conversation, so
			// make the failure visible.
			if postErr := chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, msg); postErr != nil {
				log.Printf("failed to store chat message for chat %s: %v\n", chatOpts.ChatId, postErr)
			}
		}
		firstStep = false

		if stopReason == nil {
			break
		}
		if stopReason.Kind == uctypes.StopKindPremiumRateLimit && chatOpts.Config.APIType == APIType_OpenAI && chatOpts.Config.Model == uctypes.PremiumOpenAIModel {
			// Continue the same turn on the non-premium default model.
			log.Printf("premium rate limit hit with %s, switching to %s\n", chatOpts.Config.Model, uctypes.DefaultOpenAIModel)
			cont = &uctypes.WaveContinueResponse{
				Model:            uctypes.DefaultOpenAIModel,
				ContinueFromKind: uctypes.StopKindPremiumRateLimit,
			}
			continue
		}
		if stopReason.Kind == uctypes.StopKindToolUse {
			metrics.ToolUseCount += len(stopReason.ToolCalls)
			processToolCalls(backend, stopReason, chatOpts, sseHandler, metrics)
			cont = &uctypes.WaveContinueResponse{
				Model:            chatOpts.Config.Model,
				ContinueFromKind: uctypes.StopKindToolUse,
			}
			continue
		}
		break
	}
	return metrics, nil
}
// ResolveToolCall executes toolCall with toolDef and returns the result for the model.
//
// ToolTextCallback takes precedence over ToolAnyCallback. The output of an
// any-callback is JSON-encoded into Text. On success, if toolDef has ToolCallDesc and
// the call carries ToolUseData, ToolUseData.ToolDesc is recomputed from the output.
//
// Failures are reported through ErrorText: unknown tool, no callbacks, callback error,
// marshal error, or a panic inside the tool. A recovered panic also clears Text.
// chatOpts is currently unused.
func ResolveToolCall(toolDef *uctypes.ToolDefinition, toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts) (result uctypes.AIToolResult) {
	result = uctypes.AIToolResult{ToolName: toolCall.Name, ToolUseID: toolCall.ID}
	// The named result lets this deferred recover convert a tool panic into an error result.
	defer func() {
		if r := recover(); r != nil {
			result.ErrorText = fmt.Sprintf("panic in tool execution: %v", r)
			result.Text = ""
		}
	}()

	switch {
	case toolDef == nil:
		result.ErrorText = fmt.Sprintf("tool '%s' not found", toolCall.Name)

	case toolDef.ToolTextCallback != nil:
		text, callErr := toolDef.ToolTextCallback(toolCall.Input)
		if callErr != nil {
			result.ErrorText = callErr.Error()
			break
		}
		result.Text = text
		if toolDef.ToolCallDesc != nil && toolCall.ToolUseData != nil {
			toolCall.ToolUseData.ToolDesc = toolDef.ToolCallDesc(toolCall.Input, text, toolCall.ToolUseData)
		}

	case toolDef.ToolAnyCallback != nil:
		output, callErr := toolDef.ToolAnyCallback(toolCall.Input, toolCall.ToolUseData)
		if callErr != nil {
			result.ErrorText = callErr.Error()
			break
		}
		encoded, marshalErr := json.Marshal(output)
		if marshalErr != nil {
			result.ErrorText = fmt.Sprintf("failed to marshal tool output: %v", marshalErr)
			break
		}
		result.Text = string(encoded)
		if toolDef.ToolCallDesc != nil && toolCall.ToolUseData != nil {
			toolCall.ToolUseData.ToolDesc = toolDef.ToolCallDesc(toolCall.Input, output, toolCall.ToolUseData)
		}

	default:
		result.ErrorText = fmt.Sprintf("tool '%s' has no callback functions", toolCall.Name)
	}
	return result
}
// WaveAIPostMessageWrap converts message to the backend's native format, stores it in
// the chat, and then runs the chat loop.
//
// When the loop returns metrics, they are completed with the total duration and with
// text/attachment statistics from message. File parts are classified by MIME type as
// image, PDF, or other (counted as a text document). The metrics are then logged and
// sent as telemetry, even if the loop also returned an error, which is passed through.
func WaveAIPostMessageWrap(ctx context.Context, sseHandler *sse.SSEHandlerCh, message *uctypes.AIMessage, chatOpts uctypes.WaveChatOpts) error {
	started := time.Now()

	backend, err := GetBackendByAPIType(chatOpts.Config.APIType)
	if err != nil {
		return err
	}
	nativeMsg, convErr := backend.ConvertAIMessageToNativeChatMessage(*message)
	if convErr != nil {
		return fmt.Errorf("message conversion failed: %w", convErr)
	}
	if storeErr := chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, nativeMsg); storeErr != nil {
		return fmt.Errorf("failed to store message: %w", storeErr)
	}

	metrics, runErr := RunAIChat(ctx, sseHandler, backend, chatOpts)
	if metrics == nil {
		return runErr
	}

	metrics.RequestDuration = int(time.Since(started).Milliseconds())
	for _, part := range message.Parts {
		switch part.Type {
		case uctypes.AIMessagePartTypeText:
			metrics.TextLen += len(part.Text)
		case uctypes.AIMessagePartTypeFile:
			mime := strings.ToLower(part.MimeType)
			switch {
			case strings.HasPrefix(mime, "image/"):
				metrics.ImageCount++
			case mime == "application/pdf":
				metrics.PDFCount++
			default:
				metrics.TextDocCount++
			}
		}
	}
	log.Printf("WaveAI call metrics: requests=%d tools=%d premium=%d proxy=%d images=%d pdfs=%d textdocs=%d textlen=%d duration=%dms error=%v\n",
		metrics.RequestCount, metrics.ToolUseCount, metrics.PremiumReqCount, metrics.ProxyReqCount,
		metrics.ImageCount, metrics.PDFCount, metrics.TextDocCount, metrics.TextLen, metrics.RequestDuration, metrics.HadError)
	sendAIMetricsTelemetry(ctx, metrics)
	return runErr
}
// sendAIMetricsTelemetry records a "waveai:post" telemetry event built from metrics.
// The recording error is deliberately ignored; telemetry is best-effort.
func sendAIMetricsTelemetry(ctx context.Context, metrics *uctypes.AIMetrics) {
	usage := metrics.Usage
	props := telemetrydata.TEventProps{
		// model and token accounting
		WaveAIAPIType:              usage.APIType,
		WaveAIModel:                usage.Model,
		WaveAIInputTokens:          usage.InputTokens,
		WaveAIOutputTokens:         usage.OutputTokens,
		WaveAINativeWebSearchCount: usage.NativeWebSearchCount,
		// request and tool counters
		WaveAIRequestCount:      metrics.RequestCount,
		WaveAIToolUseCount:      metrics.ToolUseCount,
		WaveAIToolUseErrorCount: metrics.ToolUseErrorCount,
		WaveAIToolDetail:        metrics.ToolDetail,
		WaveAIPremiumReq:        metrics.PremiumReqCount,
		WaveAIProxyReq:          metrics.ProxyReqCount,
		WaveAIHadError:          metrics.HadError,
		// user input statistics
		WaveAIImageCount:   metrics.ImageCount,
		WaveAIPDFCount:     metrics.PDFCount,
		WaveAITextDocCount: metrics.TextDocCount,
		WaveAITextLen:      metrics.TextLen,
		// timing and settings
		WaveAIFirstByteMs:   metrics.FirstByteLatency,
		WaveAIRequestDurMs:  metrics.RequestDuration,
		WaveAIWidgetAccess:  metrics.WidgetAccess,
		WaveAIThinkingLevel: metrics.ThinkingLevel,
		WaveAIThinkingMode:  metrics.ThinkingMode,
	}
	_ = telemetry.RecordTEvent(ctx, telemetrydata.MakeTEvent("waveai:post", props))
}
// PostMessageRequest is the JSON request body accepted by WaveAIPostMessageHandler.
type PostMessageRequest struct {
// TabId, if set, selects the tab whose runtime info and widget state are used.
TabId string `json:"tabid,omitempty"`
// BuilderId, if set (and TabId is empty), selects builder runtime info; a non-empty
// BuilderId also switches the chat into builder mode.
BuilderId string `json:"builderid,omitempty"`
// BuilderAppId, if set, enables the builder app file tools and app context for this app.
BuilderAppId string `json:"builderappid,omitempty"`
// ChatID is required and must be a UUID.
ChatID string `json:"chatid"`
// Msg is the user message to post; it must pass Validate.
Msg uctypes.AIMessage `json:"msg"`
// WidgetAccess is forwarded to tab-state generation and recorded in metrics.
WidgetAccess bool `json:"widgetaccess,omitempty"`
}
// WaveAIPostMessageHandler serves POST requests that append a user message to a Wave
// AI chat and stream the model's reply back over server-sent events.
//
// The body is a PostMessageRequest. chatid must be a UUID, and the message must pass
// validation; both are checked before any store lookups and rejected with 400.
// Settings come from three sources:
//   - the current rate-limit state, which decides premium eligibility;
//   - builder mode, enabled by a non-empty builderid;
//   - the runtime info of the referenced tab (preferred) or builder.
func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req PostMessageRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
		return
	}
	if req.ChatID == "" {
		http.Error(w, "chatid is required in request body", http.StatusBadRequest)
		return
	}
	if _, err := uuid.Parse(req.ChatID); err != nil {
		http.Error(w, "chatid must be a valid UUID", http.StatusBadRequest)
		return
	}
	// An invalid message is a client error (400, not 500). Check it before doing any work.
	if err := req.Msg.Validate(); err != nil {
		http.Error(w, fmt.Sprintf("Message validation failed: %v", err), http.StatusBadRequest)
		return
	}

	// Runtime info comes from the tab when given, otherwise from the builder.
	var rtInfo *waveobj.ObjRTInfo
	if req.TabId != "" {
		rtInfo = wstore.GetRTInfo(waveobj.MakeORef(waveobj.OType_Tab, req.TabId))
	} else if req.BuilderId != "" {
		rtInfo = wstore.GetRTInfo(waveobj.MakeORef(waveobj.OType_Builder, req.BuilderId))
	}

	premium := shouldUsePremium()
	builderMode := req.BuilderId != ""
	aiOpts, err := getWaveAISettings(premium, builderMode, rtInfo)
	if err != nil {
		http.Error(w, fmt.Sprintf("WaveAI configuration error: %v", err), http.StatusInternalServerError)
		return
	}

	client, err := wstore.DBGetSingleton[*waveobj.Client](r.Context())
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to get client: %v", err), http.StatusInternalServerError)
		return
	}

	chatOpts := uctypes.WaveChatOpts{
		ChatId:               req.ChatID,
		ClientId:             client.OID,
		Config:               *aiOpts,
		WidgetAccess:         req.WidgetAccess,
		RegisterToolApproval: RegisterToolApproval,
		AllowNativeWebSearch: true,
		BuilderId:            req.BuilderId,
		BuilderAppId:         req.BuilderAppId,
	}
	switch {
	case chatOpts.Config.APIType != APIType_OpenAI:
		chatOpts.SystemPrompt = []string{SystemPromptText}
	case chatOpts.BuilderId != "":
		// OpenAI builder chats get no base system prompt here.
		chatOpts.SystemPrompt = []string{}
	default:
		chatOpts.SystemPrompt = []string{SystemPromptText_OpenAI}
	}
	if req.TabId != "" {
		chatOpts.TabStateGenerator = func() (string, []uctypes.ToolDefinition, string, error) {
			tabState, tabTools, err := GenerateTabStateAndTools(r.Context(), req.TabId, req.WidgetAccess)
			return tabState, tabTools, req.TabId, err
		}
	}
	if req.BuilderAppId != "" {
		chatOpts.BuilderAppGenerator = func() (string, string, error) {
			return generateBuilderAppData(req.BuilderAppId)
		}
		chatOpts.Tools = append(chatOpts.Tools,
			GetBuilderWriteAppFileToolDefinition(req.BuilderAppId, req.BuilderId),
			GetBuilderEditAppFileToolDefinition(req.BuilderAppId, req.BuilderId),
			GetBuilderListFilesToolDefinition(req.BuilderAppId),
		)
	}

	sseHandler := sse.MakeSSEHandlerCh(w, r.Context())
	defer sseHandler.Close()
	if err := WaveAIPostMessageWrap(r.Context(), sseHandler, &req.Msg, chatOpts); err != nil {
		http.Error(w, fmt.Sprintf("Failed to post message: %v", err), http.StatusInternalServerError)
	}
}
// WaveAIGetChatHandler serves GET requests that return the stored chat named by the
// "chatid" query parameter as JSON. Responses:
//   - 400 if chatid is missing or not a UUID;
//   - 404 if no such chat exists;
//   - 500 if the chat cannot be encoded.
func WaveAIGetChatHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	chatID := r.URL.Query().Get("chatid")
	if chatID == "" {
		http.Error(w, "chatid parameter is required", http.StatusBadRequest)
		return
	}
	if _, err := uuid.Parse(chatID); err != nil {
		http.Error(w, "chatid must be a valid UUID", http.StatusBadRequest)
		return
	}
	chat := chatstore.DefaultChatStore.Get(chatID)
	if chat == nil {
		http.Error(w, "chat not found", http.StatusNotFound)
		return
	}

	// Encode into memory first. Once any body bytes are written, the status is fixed
	// at 200, so an encoding failure must be detected before writing.
	body, err := json.Marshal(chat)
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
	// Match json.Encoder's output, which ends with a newline.
	body = append(body, '\n')
	w.Header().Set("Content-Type", "application/json")
	// A write error means the client went away; there is no one left to report it to.
	_, _ = w.Write(body)
}
// CreateWriteTextFileDiff returns the (original, modified) contents for the
// write_text_file or edit_text_file tool call toolCallId in chat chatId, for display
// as a diff.
//
// The call's arguments are looked up through the chat's backend. If
// GetFunctionCallInputByToolCallId returns nil, the result is a "tool call not found"
// error. NOTE(review): an older comment said Anthropic returns an unimplemented
// error; confirm against the backend implementations.
//
// For edit_text_file, both contents come from EditTextFileDryRun. For write_text_file:
//   - the modified content is the call's Contents;
//   - the original is read from the backup file recorded in ToolUseData when present,
//     otherwise from the (home-expanded) target path;
//   - a missing target file yields nil original content.
//
// ctx is currently unused.
func CreateWriteTextFileDiff(ctx context.Context, chatId string, toolCallId string) ([]byte, []byte, error) {
	chat := chatstore.DefaultChatStore.Get(chatId)
	if chat == nil {
		return nil, nil, fmt.Errorf("chat not found: %s", chatId)
	}
	backend, err := GetBackendByAPIType(chat.APIType)
	if err != nil {
		return nil, nil, err
	}
	call := backend.GetFunctionCallInputByToolCallId(*chat, toolCallId)
	if call == nil {
		return nil, nil, fmt.Errorf("tool call not found: %s", toolCallId)
	}

	isEdit := call.Name == "edit_text_file"
	if !isEdit && call.Name != "write_text_file" {
		return nil, nil, fmt.Errorf("tool call %s is not a write_text_file or edit_text_file (got: %s)", toolCallId, call.Name)
	}

	backupPath := ""
	if call.ToolUseData != nil {
		backupPath = call.ToolUseData.WriteBackupFileName
	}

	var args any
	if unmarshalErr := json.Unmarshal([]byte(call.Arguments), &args); unmarshalErr != nil {
		return nil, nil, fmt.Errorf("failed to unmarshal arguments: %w", unmarshalErr)
	}

	if isEdit {
		before, after, dryRunErr := EditTextFileDryRun(args, backupPath)
		if dryRunErr != nil {
			return nil, nil, fmt.Errorf("failed to generate diff: %w", dryRunErr)
		}
		return before, after, nil
	}

	params, err := parseWriteTextFileInput(args)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to parse write_text_file input: %w", err)
	}

	var before []byte
	switch {
	case backupPath != "":
		before, err = os.ReadFile(backupPath)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to read backup file: %w", err)
		}
	default:
		target, expandErr := wavebase.ExpandHomeDir(params.Filename)
		if expandErr != nil {
			return nil, nil, fmt.Errorf("failed to expand path: %w", expandErr)
		}
		// A not-yet-existing file is a valid "original": it diffs as empty.
		before, err = os.ReadFile(target)
		if err != nil && !os.IsNotExist(err) {
			return nil, nil, fmt.Errorf("failed to read original file: %w", err)
		}
	}
	return before, []byte(params.Contents), nil
}
// StaticFileInfo describes one file under a builder app's "static/" directory. A JSON
// array of these is built by generateBuilderAppData for the model's context, so the
// field order here determines the JSON key order.
type StaticFileInfo struct {
// Name is the app-relative path (it starts with "static/").
Name string `json:"name"`
// Size is copied from the app file listing; NOTE(review): presumably bytes — confirm.
Size int64 `json:"size"`
// Modified and ModifiedTime are copied verbatim from the app file listing entry.
Modified string `json:"modified"`
ModifiedTime string `json:"modified_time"`
}
// generateBuilderAppData collects the builder-app context for the model. It returns:
//   - the text of the app's app.go;
//   - a JSON array of StaticFileInfo for every file whose name starts with "static/".
//
// Each part is best-effort: a read, list or marshal failure leaves that part empty. The
// static-file JSON is also empty when there are no static files. The returned error is
// currently always nil.
func generateBuilderAppData(appId string) (string, string, error) {
	var goSource string
	if data, readErr := waveappstore.ReadAppFile(appId, "app.go"); readErr == nil {
		goSource = string(data.Contents)
	}

	var staticJSON string
	if listing, listErr := waveappstore.ListAllAppFiles(appId); listErr == nil {
		var infos []StaticFileInfo
		for _, entry := range listing.Entries {
			if !strings.HasPrefix(entry.Name, "static/") {
				continue
			}
			infos = append(infos, StaticFileInfo{
				Name:         entry.Name,
				Size:         entry.Size,
				Modified:     entry.Modified,
				ModifiedTime: entry.ModifiedTime,
			})
		}
		if len(infos) > 0 {
			if encoded, marshalErr := json.Marshal(infos); marshalErr == nil {
				staticJSON = string(encoded)
			}
		}
	}
	return goSource, staticJSON, nil
}