waveterm/pkg/aiusechat/usechat.go

833 lines
27 KiB
Go

// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package aiusechat
import (
"context"
_ "embed"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"os/user"
"regexp"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/wavetermdev/waveterm/pkg/aiusechat/aiutil"
"github.com/wavetermdev/waveterm/pkg/aiusechat/chatstore"
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
"github.com/wavetermdev/waveterm/pkg/secretstore"
"github.com/wavetermdev/waveterm/pkg/telemetry"
"github.com/wavetermdev/waveterm/pkg/telemetry/telemetrydata"
"github.com/wavetermdev/waveterm/pkg/util/ds"
"github.com/wavetermdev/waveterm/pkg/util/logutil"
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/waveappstore"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/web/sse"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wstore"
)
const DefaultAPI = uctypes.APIType_OpenAIResponses
const DefaultMaxTokens = 4 * 1024
const BuilderMaxTokens = 24 * 1024
const WaveAIEndpointEnvName = "WAVETERM_WAVEAI_ENDPOINT"
var (
globalRateLimitInfo = &uctypes.RateLimitInfo{Unknown: true}
rateLimitLock sync.Mutex
activeToolMap = ds.MakeSyncMap[bool]() // key is toolcallid
activeChats = ds.MakeSyncMap[bool]() // key is chatid
)
func getSystemPrompt(apiType string, model string, isBuilder bool) []string {
if isBuilder {
return []string{}
}
basePrompt := SystemPromptText_OpenAI
modelLower := strings.ToLower(model)
needsStrictToolAddOn, _ := regexp.MatchString(`(?i)\b(mistral|o?llama|qwen|mixtral|yi|phi|deepseek)\b`, modelLower)
if needsStrictToolAddOn {
return []string{basePrompt, SystemPromptText_StrictToolAddOn}
}
return []string{basePrompt}
}
func getWaveAISettings(premium bool, builderMode bool, rtInfo waveobj.ObjRTInfo) (*uctypes.AIOptsType, error) {
maxTokens := DefaultMaxTokens
if builderMode {
maxTokens = BuilderMaxTokens
}
if rtInfo.WaveAIMaxOutputTokens > 0 {
maxTokens = rtInfo.WaveAIMaxOutputTokens
}
aiMode, config, err := resolveAIMode(rtInfo.WaveAIMode, premium)
if err != nil {
return nil, err
}
apiToken := config.APIToken
if apiToken == "" && config.APITokenSecretName != "" {
secret, exists, err := secretstore.GetSecret(config.APITokenSecretName)
if err != nil {
return nil, fmt.Errorf("failed to retrieve secret %s: %w", config.APITokenSecretName, err)
}
if !exists || secret == "" {
return nil, fmt.Errorf("secret %s not found or empty", config.APITokenSecretName)
}
apiToken = secret
}
var baseUrl string
if config.WaveAICloud {
baseUrl = uctypes.DefaultAIEndpoint
if os.Getenv(WaveAIEndpointEnvName) != "" {
baseUrl = os.Getenv(WaveAIEndpointEnvName)
}
} else if config.BaseURL != "" {
baseUrl = config.BaseURL
} else {
return nil, fmt.Errorf("no BaseURL configured for AI mode %s", aiMode)
}
opts := &uctypes.AIOptsType{
APIType: config.APIType,
Model: config.Model,
MaxTokens: maxTokens,
ThinkingLevel: config.ThinkingLevel,
AIMode: aiMode,
BaseURL: baseUrl,
Capabilities: config.Capabilities,
}
if apiToken != "" {
opts.APIToken = apiToken
}
return opts, nil
}
func shouldUseChatCompletionsAPI(model string) bool {
m := strings.ToLower(model)
// Chat Completions API is required for older models: gpt-3.5-*, gpt-4, gpt-4-turbo, o1-*
return strings.HasPrefix(m, "gpt-3.5") ||
strings.HasPrefix(m, "gpt-4-") ||
m == "gpt-4" ||
strings.HasPrefix(m, "o1-")
}
func shouldUsePremium() bool {
info := GetGlobalRateLimit()
if info == nil || info.Unknown {
return true
}
if info.PReq > 0 {
return true
}
nowEpoch := time.Now().Unix()
if nowEpoch >= info.ResetEpoch {
return true
}
return false
}
func updateRateLimit(info *uctypes.RateLimitInfo) {
if info == nil {
return
}
rateLimitLock.Lock()
defer rateLimitLock.Unlock()
globalRateLimitInfo = info
go func() {
wps.Broker.Publish(wps.WaveEvent{
Event: wps.Event_WaveAIRateLimit,
Data: info,
})
}()
}
func GetGlobalRateLimit() *uctypes.RateLimitInfo {
rateLimitLock.Lock()
defer rateLimitLock.Unlock()
return globalRateLimitInfo
}
func runAIChatStep(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseChatBackend, chatOpts uctypes.WaveChatOpts, cont *uctypes.WaveContinueResponse) (*uctypes.WaveStopReason, []uctypes.GenAIMessage, error) {
if chatOpts.Config.APIType == uctypes.APIType_OpenAIResponses && shouldUseChatCompletionsAPI(chatOpts.Config.Model) {
return nil, nil, fmt.Errorf("Chat completions API not available (must use newer OpenAI models)")
}
stopReason, messages, rateLimitInfo, err := backend.RunChatStep(ctx, sseHandler, chatOpts, cont)
updateRateLimit(rateLimitInfo)
return stopReason, messages, err
}
func getUsage(msgs []uctypes.GenAIMessage) uctypes.AIUsage {
var rtn uctypes.AIUsage
var found bool
for _, msg := range msgs {
if usage := msg.GetUsage(); usage != nil {
if !found {
rtn = *usage
found = true
} else {
rtn.InputTokens += usage.InputTokens
rtn.OutputTokens += usage.OutputTokens
rtn.NativeWebSearchCount += usage.NativeWebSearchCount
}
}
}
return rtn
}
func GetChatUsage(chat *uctypes.AIChat) uctypes.AIUsage {
usage := getUsage(chat.NativeMessages)
usage.APIType = chat.APIType
usage.Model = chat.Model
return usage
}
func updateToolUseDataInChat(backend UseChatBackend, chatOpts uctypes.WaveChatOpts, toolCallID string, toolUseData uctypes.UIMessageDataToolUse) {
if err := backend.UpdateToolUseData(chatOpts.ChatId, toolCallID, toolUseData); err != nil {
log.Printf("failed to update tool use data in chat: %v\n", err)
}
}
func processToolCallInternal(backend UseChatBackend, toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts, toolDef *uctypes.ToolDefinition, sseHandler *sse.SSEHandlerCh) uctypes.AIToolResult {
if toolCall.ToolUseData == nil {
return uctypes.AIToolResult{
ToolName: toolCall.Name,
ToolUseID: toolCall.ID,
ErrorText: "Invalid Tool Call",
}
}
if toolCall.ToolUseData.Status == uctypes.ToolUseStatusError {
errorMsg := toolCall.ToolUseData.ErrorMessage
if errorMsg == "" {
errorMsg = "Unspecified Tool Error"
}
return uctypes.AIToolResult{
ToolName: toolCall.Name,
ToolUseID: toolCall.ID,
ErrorText: errorMsg,
}
}
if toolDef != nil && toolDef.ToolVerifyInput != nil {
if err := toolDef.ToolVerifyInput(toolCall.Input, toolCall.ToolUseData); err != nil {
errorMsg := fmt.Sprintf("Input validation failed: %v", err)
toolCall.ToolUseData.Status = uctypes.ToolUseStatusError
toolCall.ToolUseData.ErrorMessage = errorMsg
return uctypes.AIToolResult{
ToolName: toolCall.Name,
ToolUseID: toolCall.ID,
ErrorText: errorMsg,
}
}
// ToolVerifyInput can modify the toolusedata. re-send it here.
_ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData)
updateToolUseDataInChat(backend, chatOpts, toolCall.ID, *toolCall.ToolUseData)
}
if toolCall.ToolUseData.Approval == uctypes.ApprovalNeedsApproval {
log.Printf(" waiting for approval...\n")
approval := WaitForToolApproval(toolCall.ID)
log.Printf(" approval result: %q\n", approval)
if approval != "" {
toolCall.ToolUseData.Approval = approval
}
if !toolCall.ToolUseData.IsApproved() {
errorMsg := "Tool use denied or timed out"
if approval == uctypes.ApprovalUserDenied {
errorMsg = "Tool use denied by user"
} else if approval == uctypes.ApprovalTimeout {
errorMsg = "Tool approval timed out"
}
toolCall.ToolUseData.Status = uctypes.ToolUseStatusError
toolCall.ToolUseData.ErrorMessage = errorMsg
return uctypes.AIToolResult{
ToolName: toolCall.Name,
ToolUseID: toolCall.ID,
ErrorText: errorMsg,
}
}
// this still happens here because we need to update the FE to say the tool call was approved
_ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData)
updateToolUseDataInChat(backend, chatOpts, toolCall.ID, *toolCall.ToolUseData)
}
toolCall.ToolUseData.RunTs = time.Now().UnixMilli()
result := ResolveToolCall(toolDef, toolCall, chatOpts)
if result.ErrorText != "" {
toolCall.ToolUseData.Status = uctypes.ToolUseStatusError
toolCall.ToolUseData.ErrorMessage = result.ErrorText
} else {
toolCall.ToolUseData.Status = uctypes.ToolUseStatusCompleted
}
return result
}
func processToolCall(backend UseChatBackend, toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) uctypes.AIToolResult {
inputJSON, _ := json.Marshal(toolCall.Input)
logutil.DevPrintf("TOOLUSE name=%s id=%s input=%s approval=%q\n", toolCall.Name, toolCall.ID, utilfn.TruncateString(string(inputJSON), 40), toolCall.ToolUseData.Approval)
toolDef := chatOpts.GetToolDefinition(toolCall.Name)
result := processToolCallInternal(backend, toolCall, chatOpts, toolDef, sseHandler)
if result.ErrorText != "" {
log.Printf(" error=%s\n", result.ErrorText)
metrics.ToolUseErrorCount++
} else {
log.Printf(" result=%s\n", utilfn.TruncateString(result.Text, 40))
}
if toolDef != nil && toolDef.ToolLogName != "" {
metrics.ToolDetail[toolDef.ToolLogName]++
}
if toolCall.ToolUseData != nil {
_ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, *toolCall.ToolUseData)
updateToolUseDataInChat(backend, chatOpts, toolCall.ID, *toolCall.ToolUseData)
}
return result
}
func processToolCalls(backend UseChatBackend, stopReason *uctypes.WaveStopReason, chatOpts uctypes.WaveChatOpts, sseHandler *sse.SSEHandlerCh, metrics *uctypes.AIMetrics) {
for _, toolCall := range stopReason.ToolCalls {
activeToolMap.Set(toolCall.ID, true)
defer activeToolMap.Delete(toolCall.ID)
}
// Create and send all data-tooluse packets at the beginning
for i := range stopReason.ToolCalls {
toolCall := &stopReason.ToolCalls[i]
// Create toolUseData from the tool call input
var argsJSON string
if toolCall.Input != nil {
argsBytes, err := json.Marshal(toolCall.Input)
if err == nil {
argsJSON = string(argsBytes)
}
}
toolUseData := aiutil.CreateToolUseData(toolCall.ID, toolCall.Name, argsJSON, chatOpts)
stopReason.ToolCalls[i].ToolUseData = &toolUseData
log.Printf("AI data-tooluse %s\n", toolCall.ID)
_ = sseHandler.AiMsgData("data-tooluse", toolCall.ID, toolUseData)
updateToolUseDataInChat(backend, chatOpts, toolCall.ID, toolUseData)
if toolUseData.Approval == uctypes.ApprovalNeedsApproval && chatOpts.RegisterToolApproval != nil {
chatOpts.RegisterToolApproval(toolCall.ID)
}
}
// At this point, all ToolCalls are guaranteed to have non-nil ToolUseData
var toolResults []uctypes.AIToolResult
for _, toolCall := range stopReason.ToolCalls {
result := processToolCall(backend, toolCall, chatOpts, sseHandler, metrics)
toolResults = append(toolResults, result)
}
toolResultMsgs, err := backend.ConvertToolResultsToNativeChatMessage(toolResults)
if err != nil {
log.Printf("Failed to convert tool results to native chat messages: %v", err)
} else {
for _, msg := range toolResultMsgs {
chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, msg)
}
}
}
func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, backend UseChatBackend, chatOpts uctypes.WaveChatOpts) (*uctypes.AIMetrics, error) {
if !activeChats.SetUnless(chatOpts.ChatId, true) {
return nil, fmt.Errorf("chat %s is already running", chatOpts.ChatId)
}
defer activeChats.Delete(chatOpts.ChatId)
stepNum := chatstore.DefaultChatStore.CountUserMessages(chatOpts.ChatId)
metrics := &uctypes.AIMetrics{
ChatId: chatOpts.ChatId,
StepNum: stepNum,
Usage: uctypes.AIUsage{
APIType: chatOpts.Config.APIType,
Model: chatOpts.Config.Model,
},
WidgetAccess: chatOpts.WidgetAccess,
ToolDetail: make(map[string]int),
ThinkingLevel: chatOpts.Config.ThinkingLevel,
AIMode: chatOpts.Config.AIMode,
}
firstStep := true
var cont *uctypes.WaveContinueResponse
for {
if chatOpts.TabStateGenerator != nil {
tabState, tabTools, tabId, tabErr := chatOpts.TabStateGenerator()
if tabErr == nil {
chatOpts.TabState = tabState
chatOpts.TabTools = tabTools
chatOpts.TabId = tabId
}
}
if chatOpts.BuilderAppGenerator != nil {
appGoFile, appStaticFiles, platformInfo, appErr := chatOpts.BuilderAppGenerator()
if appErr == nil {
chatOpts.AppGoFile = appGoFile
chatOpts.AppStaticFiles = appStaticFiles
chatOpts.PlatformInfo = platformInfo
}
}
stopReason, rtnMessages, err := runAIChatStep(ctx, sseHandler, backend, chatOpts, cont)
metrics.RequestCount++
if chatOpts.Config.IsPremiumModel() {
metrics.PremiumReqCount++
}
if chatOpts.Config.IsWaveProxy() {
metrics.ProxyReqCount++
}
if len(rtnMessages) > 0 {
usage := getUsage(rtnMessages)
log.Printf("usage: input=%d output=%d websearch=%d\n", usage.InputTokens, usage.OutputTokens, usage.NativeWebSearchCount)
metrics.Usage.InputTokens += usage.InputTokens
metrics.Usage.OutputTokens += usage.OutputTokens
metrics.Usage.NativeWebSearchCount += usage.NativeWebSearchCount
if usage.Model != "" && metrics.Usage.Model != usage.Model {
metrics.Usage.Model = "mixed"
}
}
if firstStep && err != nil {
metrics.HadError = true
return metrics, fmt.Errorf("failed to stream %s chat: %w", chatOpts.Config.APIType, err)
}
if err != nil {
metrics.HadError = true
_ = sseHandler.AiMsgError(err.Error())
_ = sseHandler.AiMsgFinish("", nil)
break
}
for _, msg := range rtnMessages {
if msg != nil {
chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, msg)
}
}
firstStep = false
if stopReason != nil && stopReason.Kind == uctypes.StopKindPremiumRateLimit && chatOpts.Config.APIType == uctypes.APIType_OpenAIResponses && chatOpts.Config.Model == uctypes.PremiumOpenAIModel {
log.Printf("Premium rate limit hit with %s, switching to %s\n", uctypes.PremiumOpenAIModel, uctypes.DefaultOpenAIModel)
cont = &uctypes.WaveContinueResponse{
Model: uctypes.DefaultOpenAIModel,
ContinueFromKind: uctypes.StopKindPremiumRateLimit,
}
continue
}
if stopReason != nil && stopReason.Kind == uctypes.StopKindToolUse {
metrics.ToolUseCount += len(stopReason.ToolCalls)
processToolCalls(backend, stopReason, chatOpts, sseHandler, metrics)
cont = &uctypes.WaveContinueResponse{
Model: chatOpts.Config.Model,
ContinueFromKind: uctypes.StopKindToolUse,
}
continue
}
break
}
return metrics, nil
}
func ResolveToolCall(toolDef *uctypes.ToolDefinition, toolCall uctypes.WaveToolCall, chatOpts uctypes.WaveChatOpts) (result uctypes.AIToolResult) {
result = uctypes.AIToolResult{
ToolName: toolCall.Name,
ToolUseID: toolCall.ID,
}
defer func() {
if r := recover(); r != nil {
result.ErrorText = fmt.Sprintf("panic in tool execution: %v", r)
result.Text = ""
}
}()
if toolDef == nil {
result.ErrorText = fmt.Sprintf("tool '%s' not found", toolCall.Name)
return
}
// Try ToolTextCallback first, then ToolAnyCallback
if toolDef.ToolTextCallback != nil {
text, err := toolDef.ToolTextCallback(toolCall.Input)
if err != nil {
result.ErrorText = err.Error()
} else {
result.Text = text
// Recompute tool description with the result
if toolDef.ToolCallDesc != nil && toolCall.ToolUseData != nil {
toolCall.ToolUseData.ToolDesc = toolDef.ToolCallDesc(toolCall.Input, text, toolCall.ToolUseData)
}
}
} else if toolDef.ToolAnyCallback != nil {
output, err := toolDef.ToolAnyCallback(toolCall.Input, toolCall.ToolUseData)
if err != nil {
result.ErrorText = err.Error()
} else {
// Marshal the result to JSON
jsonBytes, marshalErr := json.Marshal(output)
if marshalErr != nil {
result.ErrorText = fmt.Sprintf("failed to marshal tool output: %v", marshalErr)
} else {
result.Text = string(jsonBytes)
// Recompute tool description with the result
if toolDef.ToolCallDesc != nil && toolCall.ToolUseData != nil {
toolCall.ToolUseData.ToolDesc = toolDef.ToolCallDesc(toolCall.Input, output, toolCall.ToolUseData)
}
}
}
} else {
result.ErrorText = fmt.Sprintf("tool '%s' has no callback functions", toolCall.Name)
}
return
}
func WaveAIPostMessageWrap(ctx context.Context, sseHandler *sse.SSEHandlerCh, message *uctypes.AIMessage, chatOpts uctypes.WaveChatOpts) error {
startTime := time.Now()
// Convert AIMessage to native chat message using backend
backend, err := GetBackendByAPIType(chatOpts.Config.APIType)
if err != nil {
return err
}
convertedMessage, err := backend.ConvertAIMessageToNativeChatMessage(*message)
if err != nil {
return fmt.Errorf("message conversion failed: %w", err)
}
// Post message to chat store
if err := chatstore.DefaultChatStore.PostMessage(chatOpts.ChatId, &chatOpts.Config, convertedMessage); err != nil {
return fmt.Errorf("failed to store message: %w", err)
}
metrics, err := RunAIChat(ctx, sseHandler, backend, chatOpts)
if metrics != nil {
metrics.RequestDuration = int(time.Since(startTime).Milliseconds())
for _, part := range message.Parts {
if part.Type == uctypes.AIMessagePartTypeText {
metrics.TextLen += len(part.Text)
} else if part.Type == uctypes.AIMessagePartTypeFile {
mimeType := strings.ToLower(part.MimeType)
if strings.HasPrefix(mimeType, "image/") {
metrics.ImageCount++
} else if mimeType == "application/pdf" {
metrics.PDFCount++
} else {
metrics.TextDocCount++
}
}
}
log.Printf("WaveAI call metrics: requests=%d tools=%d premium=%d proxy=%d images=%d pdfs=%d textdocs=%d textlen=%d duration=%dms error=%v\n",
metrics.RequestCount, metrics.ToolUseCount, metrics.PremiumReqCount, metrics.ProxyReqCount,
metrics.ImageCount, metrics.PDFCount, metrics.TextDocCount, metrics.TextLen, metrics.RequestDuration, metrics.HadError)
sendAIMetricsTelemetry(ctx, metrics)
}
return err
}
func sendAIMetricsTelemetry(ctx context.Context, metrics *uctypes.AIMetrics) {
event := telemetrydata.MakeTEvent("waveai:post", telemetrydata.TEventProps{
WaveAIAPIType: metrics.Usage.APIType,
WaveAIModel: metrics.Usage.Model,
WaveAIChatId: metrics.ChatId,
WaveAIStepNum: metrics.StepNum,
WaveAIInputTokens: metrics.Usage.InputTokens,
WaveAIOutputTokens: metrics.Usage.OutputTokens,
WaveAINativeWebSearchCount: metrics.Usage.NativeWebSearchCount,
WaveAIRequestCount: metrics.RequestCount,
WaveAIToolUseCount: metrics.ToolUseCount,
WaveAIToolUseErrorCount: metrics.ToolUseErrorCount,
WaveAIToolDetail: metrics.ToolDetail,
WaveAIPremiumReq: metrics.PremiumReqCount,
WaveAIProxyReq: metrics.ProxyReqCount,
WaveAIHadError: metrics.HadError,
WaveAIImageCount: metrics.ImageCount,
WaveAIPDFCount: metrics.PDFCount,
WaveAITextDocCount: metrics.TextDocCount,
WaveAITextLen: metrics.TextLen,
WaveAIFirstByteMs: metrics.FirstByteLatency,
WaveAIRequestDurMs: metrics.RequestDuration,
WaveAIWidgetAccess: metrics.WidgetAccess,
WaveAIThinkingLevel: metrics.ThinkingLevel,
WaveAIMode: metrics.AIMode,
})
_ = telemetry.RecordTEvent(ctx, event)
}
// PostMessageRequest represents the request body for posting a message
type PostMessageRequest struct {
TabId string `json:"tabid,omitempty"`
BuilderId string `json:"builderid,omitempty"`
BuilderAppId string `json:"builderappid,omitempty"`
ChatID string `json:"chatid"`
Msg uctypes.AIMessage `json:"msg"`
WidgetAccess bool `json:"widgetaccess,omitempty"`
}
func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) {
// Only allow POST method
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Parse request body
var req PostMessageRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest)
return
}
// Validate chatid is present and is a UUID
if req.ChatID == "" {
http.Error(w, "chatid is required in request body", http.StatusBadRequest)
return
}
if _, err := uuid.Parse(req.ChatID); err != nil {
http.Error(w, "chatid must be a valid UUID", http.StatusBadRequest)
return
}
// Get RTInfo from TabId or BuilderId
var rtInfo *waveobj.ObjRTInfo
if req.TabId != "" {
oref := waveobj.MakeORef(waveobj.OType_Tab, req.TabId)
rtInfo = wstore.GetRTInfo(oref)
} else if req.BuilderId != "" {
oref := waveobj.MakeORef(waveobj.OType_Builder, req.BuilderId)
rtInfo = wstore.GetRTInfo(oref)
}
if rtInfo == nil {
rtInfo = &waveobj.ObjRTInfo{}
}
// Get WaveAI settings
premium := shouldUsePremium()
builderMode := req.BuilderId != ""
aiOpts, err := getWaveAISettings(premium, builderMode, *rtInfo)
if err != nil {
http.Error(w, fmt.Sprintf("WaveAI configuration error: %v", err), http.StatusInternalServerError)
return
}
// Get client ID from database
client, err := wstore.DBGetSingleton[*waveobj.Client](r.Context())
if err != nil {
http.Error(w, fmt.Sprintf("Failed to get client: %v", err), http.StatusInternalServerError)
return
}
// Call the core WaveAIPostMessage function
chatOpts := uctypes.WaveChatOpts{
ChatId: req.ChatID,
ClientId: client.OID,
Config: *aiOpts,
WidgetAccess: req.WidgetAccess,
RegisterToolApproval: RegisterToolApproval,
AllowNativeWebSearch: true,
BuilderId: req.BuilderId,
BuilderAppId: req.BuilderAppId,
}
chatOpts.SystemPrompt = getSystemPrompt(chatOpts.Config.APIType, chatOpts.Config.Model, chatOpts.BuilderId != "")
if req.TabId != "" {
chatOpts.TabStateGenerator = func() (string, []uctypes.ToolDefinition, string, error) {
tabState, tabTools, err := GenerateTabStateAndTools(r.Context(), req.TabId, req.WidgetAccess)
return tabState, tabTools, req.TabId, err
}
}
if req.BuilderAppId != "" {
chatOpts.BuilderAppGenerator = func() (string, string, string, error) {
return generateBuilderAppData(req.BuilderAppId)
}
}
if req.BuilderAppId != "" {
chatOpts.Tools = append(chatOpts.Tools,
GetBuilderWriteAppFileToolDefinition(req.BuilderAppId, req.BuilderId),
GetBuilderEditAppFileToolDefinition(req.BuilderAppId, req.BuilderId),
GetBuilderListFilesToolDefinition(req.BuilderAppId),
)
}
// Validate the message
if err := req.Msg.Validate(); err != nil {
http.Error(w, fmt.Sprintf("Message validation failed: %v", err), http.StatusInternalServerError)
return
}
// Create SSE handler and set up streaming
sseHandler := sse.MakeSSEHandlerCh(w, r.Context())
defer sseHandler.Close()
if err := WaveAIPostMessageWrap(r.Context(), sseHandler, &req.Msg, chatOpts); err != nil {
http.Error(w, fmt.Sprintf("Failed to post message: %v", err), http.StatusInternalServerError)
return
}
}
func WaveAIGetChatHandler(w http.ResponseWriter, r *http.Request) {
// Only allow GET method
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Get chatid from URL parameters
chatID := r.URL.Query().Get("chatid")
if chatID == "" {
http.Error(w, "chatid parameter is required", http.StatusBadRequest)
return
}
// Validate chatid is a UUID
if _, err := uuid.Parse(chatID); err != nil {
http.Error(w, "chatid must be a valid UUID", http.StatusBadRequest)
return
}
// Get chat from store
chat := chatstore.DefaultChatStore.Get(chatID)
if chat == nil {
http.Error(w, "chat not found", http.StatusNotFound)
return
}
// Set response headers for JSON
w.Header().Set("Content-Type", "application/json")
// Encode and return the chat
if err := json.NewEncoder(w).Encode(chat); err != nil {
http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
return
}
}
// CreateWriteTextFileDiff generates a diff for write_text_file or edit_text_file tool calls.
// Returns the original content, modified content, and any error.
// For Anthropic, this returns an unimplemented error.
func CreateWriteTextFileDiff(ctx context.Context, chatId string, toolCallId string) ([]byte, []byte, error) {
aiChat := chatstore.DefaultChatStore.Get(chatId)
if aiChat == nil {
return nil, nil, fmt.Errorf("chat not found: %s", chatId)
}
backend, err := GetBackendByAPIType(aiChat.APIType)
if err != nil {
return nil, nil, err
}
funcCallInput := backend.GetFunctionCallInputByToolCallId(*aiChat, toolCallId)
if funcCallInput == nil {
return nil, nil, fmt.Errorf("tool call not found: %s", toolCallId)
}
toolName := funcCallInput.Name
if toolName != "write_text_file" && toolName != "edit_text_file" {
return nil, nil, fmt.Errorf("tool call %s is not a write_text_file or edit_text_file (got: %s)", toolCallId, toolName)
}
var backupFileName string
if funcCallInput.ToolUseData != nil {
backupFileName = funcCallInput.ToolUseData.WriteBackupFileName
}
var parsedArguments any
if err := json.Unmarshal([]byte(funcCallInput.Arguments), &parsedArguments); err != nil {
return nil, nil, fmt.Errorf("failed to unmarshal arguments: %w", err)
}
if toolName == "edit_text_file" {
originalContent, modifiedContent, err := EditTextFileDryRun(parsedArguments, backupFileName)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate diff: %w", err)
}
return originalContent, modifiedContent, nil
}
params, err := parseWriteTextFileInput(parsedArguments)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse write_text_file input: %w", err)
}
var originalContent []byte
if backupFileName != "" {
originalContent, err = os.ReadFile(backupFileName)
if err != nil {
return nil, nil, fmt.Errorf("failed to read backup file: %w", err)
}
} else {
expandedPath, err := wavebase.ExpandHomeDir(params.Filename)
if err != nil {
return nil, nil, fmt.Errorf("failed to expand path: %w", err)
}
originalContent, err = os.ReadFile(expandedPath)
if err != nil && !os.IsNotExist(err) {
return nil, nil, fmt.Errorf("failed to read original file: %w", err)
}
}
modifiedContent := []byte(params.Contents)
return originalContent, modifiedContent, nil
}
type StaticFileInfo struct {
Name string `json:"name"`
Size int64 `json:"size"`
Modified string `json:"modified"`
ModifiedTime string `json:"modified_time"`
}
func generateBuilderAppData(appId string) (string, string, string, error) {
appGoFile := ""
fileData, err := waveappstore.ReadAppFile(appId, "app.go")
if err == nil {
appGoFile = string(fileData.Contents)
}
staticFilesJSON := ""
allFiles, err := waveappstore.ListAllAppFiles(appId)
if err == nil {
var staticFiles []StaticFileInfo
for _, entry := range allFiles.Entries {
if strings.HasPrefix(entry.Name, "static/") {
staticFiles = append(staticFiles, StaticFileInfo{
Name: entry.Name,
Size: entry.Size,
Modified: entry.Modified,
ModifiedTime: entry.ModifiedTime,
})
}
}
if len(staticFiles) > 0 {
staticFilesBytes, marshalErr := json.Marshal(staticFiles)
if marshalErr == nil {
staticFilesJSON = string(staticFilesBytes)
}
}
}
platformInfo := wavebase.GetSystemSummary()
if currentUser, userErr := user.Current(); userErr == nil && currentUser.Username != "" {
platformInfo = fmt.Sprintf("Local Machine: %s, User: %s", platformInfo, currentUser.Username)
} else {
platformInfo = fmt.Sprintf("Local Machine: %s", platformInfo)
}
return appGoFile, staticFilesJSON, platformInfo, nil
}