waveterm/pkg/web/sse/ssehandler.go
Mike Sawka fd0e75a984
New data-tooluse AI SDK packet and Tool Approvals Implemented (#2407)
Provides richer information for the frontend (FE) to use when displaying tools. Also implements a full approve/deny flow for tools that require approval (readfile).
2025-10-09 10:21:01 -07:00

475 lines
11 KiB
Go

// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package sse
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"
)
// Wire-level constants for the SSE stream.
// See /aiprompts/usechat-streamingproto.md for the protocol.
const (
	// Header values.
	SSEContentType  = "text/event-stream"
	SSECacheControl = "no-cache"
	SSEConnection   = "keep-alive"

	// Raw SSE comment frames (lines beginning with ':').
	SSEKeepaliveMsg   = ": keepalive\n\n"
	SSEStreamStartMsg = ": stream-start\n\n"

	// SSEKeepaliveInterval is how often the writer emits a keepalive comment.
	SSEKeepaliveInterval = time.Second
)
// SSEMessageType selects how a queued SSEMessage is framed when written.
type SSEMessageType string

// Recognized SSEMessageType values.
const (
	SSEMsgData    SSEMessageType = "data"    // written as a "data: ..." line
	SSEMsgEvent   SSEMessageType = "event"   // optional "event: ..." line, then "data: ..."
	SSEMsgComment SSEMessageType = "comment" // written as a ": ..." comment line
	SSEMsgError   SSEMessageType = "error"   // framed exactly like SSEMsgData
)
// AI message type constants: values for the "type" field of the JSON parts
// streamed to the client.
const (
	// Message lifecycle.
	AiMsgStart  = "start"
	AiMsgFinish = "finish"
	AiMsgError  = "error"

	// Step boundaries.
	AiMsgStartStep  = "start-step"
	AiMsgFinishStep = "finish-step"

	// Streamed text.
	AiMsgTextStart = "text-start"
	AiMsgTextDelta = "text-delta"
	AiMsgTextEnd   = "text-end"

	// Streamed reasoning.
	AiMsgReasoningStart = "reasoning-start"
	AiMsgReasoningDelta = "reasoning-delta"
	AiMsgReasoningEnd   = "reasoning-end"

	// Tool calls.
	AiMsgToolInputStart      = "tool-input-start"
	AiMsgToolInputDelta      = "tool-input-delta"
	AiMsgToolInputAvailable  = "tool-input-available"
	AiMsgToolOutputAvailable = "tool-output-available" // not used here, but reserved
)
// SSEMessage represents a message to be written to the SSE stream.
// Messages are queued on the handler's write channel; writeMessage chooses
// the wire framing from Type.
type SSEMessage struct {
	Type      SSEMessageType // framing selector: data, event, comment or error
	Data      string         // payload text written after the SSE field prefix
	EventType string         // Only used for SSEMsgEvent
}
// SSEHandlerCh provides channel-based Server-Sent Events functionality.
// The public Write* methods enqueue SSEMessages on writeCh without blocking.
// After SetupSSE, a single writer goroutine (writerLoop) performs every write
// to w and emits periodic keepalive comments.
type SSEHandlerCh struct {
	w           http.ResponseWriter
	rc          *http.ResponseController // used to Flush and to clear the write deadline on w
	ctx         context.Context          // writerLoop exits when this is done
	writeCh     chan SSEMessage          // queued messages; closed by Close to end the stream
	errCh       chan error               // receives the first write error (capacity 1); NOTE(review): nothing in this file reads it
	mu          sync.RWMutex             // guards closed, initialized and err
	closed      bool                     // set by Close; Write* methods refuse to queue afterwards
	initialized bool                     // set by SetupSSE; writeDirectly/writeEvent panic while false
	err         error                    // first error recorded by setError
	wg          sync.WaitGroup           // tracks the writerLoop goroutine so Close can wait for it
}
// MakeSSEHandlerCh creates a channel-based SSE handler that writes to w.
// Its writer goroutine stops once ctx is done. Nothing is written to w
// until SetupSSE is called.
func MakeSSEHandlerCh(w http.ResponseWriter, ctx context.Context) *SSEHandlerCh {
	handler := new(SSEHandlerCh)
	handler.w = w
	handler.rc = http.NewResponseController(w)
	handler.ctx = ctx
	// A small buffer lets producers queue a few messages without blocking.
	handler.writeCh = make(chan SSEMessage, 10)
	// Room for exactly one error; setError never blocks on it.
	handler.errCh = make(chan error, 1)
	return handler
}
// SetupSSE configures the response for streaming and starts the writer goroutine.
//
// It performs these steps in order:
//   - clears the write deadline;
//   - sets the SSE headers and writes a 200 status;
//   - writes an initial ": stream-start" comment and flushes it.
//
// It returns an error if the handler is already closed or if any of those
// steps fail. In that case the writer goroutine is not started. It is intended
// to be called once; a second call is not guarded against.
func (h *SSEHandlerCh) SetupSSE() error {
	h.mu.Lock()
	defer h.mu.Unlock()
	if h.closed {
		return fmt.Errorf("SSE handler is closed")
	}
	h.initialized = true
	// The stream is long-lived, so remove any server-imposed write deadline.
	if err := h.rc.SetWriteDeadline(time.Time{}); err != nil {
		return fmt.Errorf("failed to reset write deadline: %w", err)
	}
	hdr := h.w.Header()
	hdr.Set("Content-Type", SSEContentType)
	hdr.Set("Cache-Control", "no-cache, no-store, must-revalidate, no-transform")
	hdr.Set("Connection", SSEConnection)
	hdr.Set("x-vercel-ai-ui-message-stream", "v1")
	// Presumably disables buffering in reverse proxies such as nginx.
	hdr.Set("X-Accel-Buffering", "no")
	// Send headers and the opening comment so the client sees the stream start.
	h.w.WriteHeader(http.StatusOK)
	if _, err := fmt.Fprint(h.w, SSEStreamStartMsg); err != nil {
		return fmt.Errorf("failed to write stream start: %w", err)
	}
	if err := h.flush(); err != nil {
		return err
	}
	h.wg.Add(1)
	go h.writerLoop()
	return nil
}
// writerLoop runs on its own goroutine (started by SetupSSE). Once it is
// running, all writes to the response happen here.
//
// It drains writeCh and writes a keepalive comment every SSEKeepaliveInterval.
// It returns when:
//   - writeCh is closed, after a best-effort "[DONE]" data line;
//   - any write fails, after recording the error with setError;
//   - ctx is done.
func (h *SSEHandlerCh) writerLoop() {
	defer h.wg.Done()
	ticker := time.NewTicker(SSEKeepaliveInterval)
	defer ticker.Stop()
	for {
		var err error
		select {
		case <-h.ctx.Done():
			return
		case <-ticker.C:
			err = h.writeDirectly("keepalive", SSEMsgComment)
		case msg, open := <-h.writeCh:
			if !open {
				// Close was called: terminate the stream. The error is ignored
				// because the loop exits either way.
				h.writeDirectly("[DONE]", SSEMsgData)
				return
			}
			err = h.writeMessage(msg)
		}
		if err != nil {
			h.setError(err)
			return
		}
	}
}
// writeMessage writes one queued message using the framing chosen by its Type.
// SSEMsgData and SSEMsgError both produce a "data:" line; SSEMsgEvent may add
// an "event:" line; SSEMsgComment produces a ": " comment line. Any other
// type yields an error.
func (h *SSEHandlerCh) writeMessage(msg SSEMessage) error {
	switch msg.Type {
	case SSEMsgEvent:
		return h.writeEvent(msg.EventType, msg.Data)
	case SSEMsgData, SSEMsgError:
		return h.writeDirectly(msg.Data, SSEMsgData)
	case SSEMsgComment:
		return h.writeDirectly(msg.Data, SSEMsgComment)
	}
	return fmt.Errorf("unknown message type: %s", msg.Type)
}
// isInitialized reports whether SetupSSE has marked the handler initialized.
func (h *SSEHandlerCh) isInitialized() bool {
	h.mu.RLock()
	ready := h.initialized
	h.mu.RUnlock()
	return ready
}
// writeDirectly writes data to the response in SSE framing, then flushes.
// SSEMsgData is written as "data: <data>\n\n" and SSEMsgComment as
// ": <data>\n\n".
//
// It panics if SetupSSE has not run or if msgType is any other type; both are
// programming errors. Nothing is written in those cases.
func (h *SSEHandlerCh) writeDirectly(data string, msgType SSEMessageType) error {
	if !h.isInitialized() {
		panic("SSEHandlerCh not initialized - call SetupSSE first")
	}
	var format string
	switch msgType {
	case SSEMsgData:
		format = "data: %s\n\n"
	case SSEMsgComment:
		format = ": %s\n\n"
	default:
		panic(fmt.Sprintf("unsupported direct write type: %s", msgType))
	}
	if _, err := fmt.Fprintf(h.w, format, data); err != nil {
		return err
	}
	return h.flush()
}
// writeEvent writes an SSE event and flushes. The event is an optional
// "event: <eventType>" line, then "data: <data>", then a blank line.
// An empty eventType omits the event line. It panics if SetupSSE has not run.
func (h *SSEHandlerCh) writeEvent(eventType, data string) error {
	if !h.isInitialized() {
		panic("SSEHandlerCh not initialized - call SetupSSE first")
	}
	hasEventLine := eventType != ""
	if hasEventLine {
		_, err := fmt.Fprintf(h.w, "event: %s\n", eventType)
		if err != nil {
			return err
		}
	}
	_, err := fmt.Fprintf(h.w, "data: %s\n\n", data)
	if err != nil {
		return err
	}
	return h.flush()
}
// flush pushes buffered response bytes to the client through the
// ResponseController and returns its error unchanged.
func (h *SSEHandlerCh) flush() error {
	controller := h.rc
	return controller.Flush()
}
// setError records err as the handler's error if none has been recorded yet.
// It also offers err on errCh without blocking. Later errors are discarded.
func (h *SSEHandlerCh) setError(err error) {
	h.mu.Lock()
	defer h.mu.Unlock()
	if h.err != nil {
		return
	}
	h.err = err
	select {
	case h.errCh <- err:
	default:
		// errCh is full; never block while holding mu.
	}
}
// WriteData queues data to be written as an SSE "data:" line. It never blocks.
// It returns an error if the handler is closed or the queue is full. If ctx is
// already done, it may return ctx.Err() instead of queueing.
func (h *SSEHandlerCh) WriteData(data string) error {
	// Hold the read lock through the send. Close closes writeCh while holding
	// the write lock, so this rules out a "send on closed channel" panic
	// between the closed check and the send. The send is non-blocking, so the
	// lock is held only briefly.
	h.mu.RLock()
	defer h.mu.RUnlock()
	if h.closed {
		return fmt.Errorf("SSE handler is closed")
	}
	select {
	case h.writeCh <- SSEMessage{Type: SSEMsgData, Data: data}:
		return nil
	case <-h.ctx.Done():
		return h.ctx.Err()
	default:
		return fmt.Errorf("write channel is full")
	}
}
// WriteJsonData marshals data to JSON and queues it as an SSE "data:" line
// via WriteData. A marshal error is wrapped, so callers can inspect it with
// errors.Is / errors.As.
func (h *SSEHandlerCh) WriteJsonData(data interface{}) error {
	jsonData, err := json.Marshal(data)
	if err != nil {
		return fmt.Errorf("failed to marshal JSON: %w", err)
	}
	return h.WriteData(string(jsonData))
}
// WriteError queues an AI "error" part carrying errorMsg and then closes the
// handler. Closing shuts the write queue; if the writer goroutine is still
// running, it then sends the terminal [DONE] line.
//
// The handler is closed even if the error part could not be queued, so the
// stream always terminates. The queueing error, if any, is returned.
func (h *SSEHandlerCh) WriteError(errorMsg string) error {
	errorResp := map[string]interface{}{
		"type":      AiMsgError,
		"errorText": errorMsg,
	}
	err := h.WriteJsonData(errorResp)
	h.Close()
	return err
}
// WriteEvent queues an SSE event. A non-empty eventType is written as an
// "event:" line before the "data:" line. It never blocks. It returns an error
// if the handler is closed or the queue is full. If ctx is already done, it
// may return ctx.Err() instead of queueing.
func (h *SSEHandlerCh) WriteEvent(eventType, data string) error {
	// Hold the read lock through the send. Close closes writeCh while holding
	// the write lock, so this rules out a "send on closed channel" panic.
	h.mu.RLock()
	defer h.mu.RUnlock()
	if h.closed {
		return fmt.Errorf("SSE handler is closed")
	}
	select {
	case h.writeCh <- SSEMessage{Type: SSEMsgEvent, Data: data, EventType: eventType}:
		return nil
	case <-h.ctx.Done():
		return h.ctx.Err()
	default:
		return fmt.Errorf("write channel is full")
	}
}
// WriteComment queues an SSE comment, written as ": <comment>". It never
// blocks. It returns an error if the handler is closed or the queue is full.
// If ctx is already done, it may return ctx.Err() instead of queueing.
func (h *SSEHandlerCh) WriteComment(comment string) error {
	// Hold the read lock through the send. Close closes writeCh while holding
	// the write lock, so this rules out a "send on closed channel" panic.
	h.mu.RLock()
	defer h.mu.RUnlock()
	if h.closed {
		return fmt.Errorf("SSE handler is closed")
	}
	select {
	case h.writeCh <- SSEMessage{Type: SSEMsgComment, Data: comment}:
		return nil
	case <-h.ctx.Done():
		return h.ctx.Err()
	default:
		return fmt.Errorf("write channel is full")
	}
}
// Err returns the first error recorded by the writer goroutine, or nil.
func (h *SSEHandlerCh) Err() error {
	h.mu.RLock()
	firstErr := h.err
	h.mu.RUnlock()
	return firstErr
}
// Close marks the handler closed and closes the write channel, which makes
// writerLoop send [DONE] and exit. It then waits for the writer goroutine.
// It is a no-op if the handler is already closed or SetupSSE never ran.
func (h *SSEHandlerCh) Close() {
	h.mu.Lock()
	shutdown := h.initialized && !h.closed
	if shutdown {
		h.closed = true
		// Closed under the write lock so no Write* call can be mid-send.
		close(h.writeCh)
	}
	h.mu.Unlock()
	if !shutdown {
		return
	}
	// Wait without holding mu: the writer may need it (e.g. setError).
	h.wg.Wait()
}
// AI message writing methods: each builds one JSON part and queues it with
// WriteJsonData.

// AiMsgStart queues a "start" part for messageId.
func (h *SSEHandlerCh) AiMsgStart(messageId string) error {
	return h.WriteJsonData(map[string]any{
		"messageId": messageId,
		"type":      AiMsgStart,
	})
}
// AiMsgTextStart queues a "text-start" part opening the text block textId.
func (h *SSEHandlerCh) AiMsgTextStart(textId string) error {
	return h.WriteJsonData(map[string]any{
		"id":   textId,
		"type": AiMsgTextStart,
	})
}
// AiMsgTextDelta queues a "text-delta" part appending text to block textId.
func (h *SSEHandlerCh) AiMsgTextDelta(textId string, text string) error {
	part := make(map[string]any, 3)
	part["type"] = AiMsgTextDelta
	part["id"] = textId
	part["delta"] = text
	return h.WriteJsonData(part)
}
// AiMsgTextEnd queues a "text-end" part closing the text block textId.
func (h *SSEHandlerCh) AiMsgTextEnd(textId string) error {
	return h.WriteJsonData(map[string]any{
		"id":   textId,
		"type": AiMsgTextEnd,
	})
}
// AiMsgFinish queues a "finish" part. finishReason and usage are accepted but
// are not currently included in the emitted JSON.
func (h *SSEHandlerCh) AiMsgFinish(finishReason string, usage interface{}) error {
	return h.WriteJsonData(map[string]any{"type": AiMsgFinish})
}
// AiMsgReasoningStart queues a "reasoning-start" part opening block reasoningId.
func (h *SSEHandlerCh) AiMsgReasoningStart(reasoningId string) error {
	return h.WriteJsonData(map[string]any{
		"id":   reasoningId,
		"type": AiMsgReasoningStart,
	})
}
// AiMsgReasoningDelta queues a "reasoning-delta" part appending reasoning to
// block reasoningId.
func (h *SSEHandlerCh) AiMsgReasoningDelta(reasoningId string, reasoning string) error {
	part := make(map[string]any, 3)
	part["type"] = AiMsgReasoningDelta
	part["id"] = reasoningId
	part["delta"] = reasoning
	return h.WriteJsonData(part)
}
// AiMsgReasoningEnd queues a "reasoning-end" part closing block reasoningId.
func (h *SSEHandlerCh) AiMsgReasoningEnd(reasoningId string) error {
	return h.WriteJsonData(map[string]any{
		"id":   reasoningId,
		"type": AiMsgReasoningEnd,
	})
}
// AiMsgToolInputStart queues a "tool-input-start" part announcing that input
// for tool toolName (call toolCallId) is about to stream.
func (h *SSEHandlerCh) AiMsgToolInputStart(toolCallId, toolName string) error {
	return h.WriteJsonData(map[string]any{
		"toolCallId": toolCallId,
		"toolName":   toolName,
		"type":       AiMsgToolInputStart,
	})
}
// AiMsgToolInputDelta queues a "tool-input-delta" part carrying the next chunk
// of raw input text for tool call toolCallId.
func (h *SSEHandlerCh) AiMsgToolInputDelta(toolCallId, inputTextDelta string) error {
	part := make(map[string]any, 3)
	part["type"] = AiMsgToolInputDelta
	part["toolCallId"] = toolCallId
	part["inputTextDelta"] = inputTextDelta
	return h.WriteJsonData(part)
}
// AiMsgToolInputAvailable queues a "tool-input-available" part carrying the
// complete, already-encoded JSON input for tool call toolCallId.
//
// A nil or zero-length input is emitted as JSON null. A non-nil but empty
// json.RawMessage would otherwise make json.Marshal fail, and the part would
// never be sent.
func (h *SSEHandlerCh) AiMsgToolInputAvailable(toolCallId, toolName string, input json.RawMessage) error {
	if len(input) == 0 {
		input = nil // a nil json.RawMessage marshals as null
	}
	resp := map[string]interface{}{
		"type":       AiMsgToolInputAvailable,
		"toolCallId": toolCallId,
		"toolName":   toolName,
		"input":      input,
	}
	return h.WriteJsonData(resp)
}
// AiMsgStartStep queues a "start-step" part marking the beginning of a step.
func (h *SSEHandlerCh) AiMsgStartStep() error {
	return h.WriteJsonData(map[string]any{"type": AiMsgStartStep})
}
// AiMsgFinishStep queues a "finish-step" part marking the end of a step.
func (h *SSEHandlerCh) AiMsgFinishStep() error {
	return h.WriteJsonData(map[string]any{"type": AiMsgFinishStep})
}
// AiMsgError queues an "error" part carrying errText. Unlike WriteError, it
// does not close the handler.
func (h *SSEHandlerCh) AiMsgError(errText string) error {
	return h.WriteJsonData(map[string]any{
		"errorText": errText,
		"type":      AiMsgError,
	})
}
// AiMsgData queues a custom data part {"type": dataType, "id": id, "data": data}.
// dataType must start with "data-"; any other value is a programming error
// and panics.
func (h *SSEHandlerCh) AiMsgData(dataType string, id string, data interface{}) error {
	if _, hasPrefix := strings.CutPrefix(dataType, "data-"); !hasPrefix {
		panic(fmt.Sprintf("AiMsgData type must start with 'data-', got: %s", dataType))
	}
	return h.WriteJsonData(map[string]any{
		"data": data,
		"id":   id,
		"type": dataType,
	})
}