waveterm/pkg/aiusechat/chatstore/chatstore.go
Mike Sawka e0ca73ad53
builder secrets, builder config/data tab hooked up (#2581)
builder secrets, builder config/data tab hooked up, and tsunami cors
config env var
2025-11-21 10:36:51 -08:00

111 lines
2.6 KiB
Go

// Copyright 2025, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0
package chatstore
import (
"fmt"
"sync"
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
)
// ChatStore is an in-memory collection of AI chats keyed by chat ID.
//
// Every method on ChatStore acquires lock before touching chats, so the map
// and the chats it holds are only read or mutated under the mutex.
type ChatStore struct {
	// lock serializes all access to chats.
	lock sync.Mutex
	// chats maps a chat ID to its stored chat (history plus the API settings
	// the chat was created with).
	chats map[string]*uctypes.AIChat
}
// DefaultChatStore is the package-level ChatStore instance. It starts with an
// empty, non-nil chats map.
var DefaultChatStore = &ChatStore{chats: map[string]*uctypes.AIChat{}}
// Get returns a snapshot of the chat stored under chatId, or nil if no such
// chat exists.
//
// The returned AIChat carries the ChatId, APIType, Model, APIVersion and
// NativeMessages fields. The struct and its NativeMessages slice are freshly
// allocated, so later appends or in-place replacements made by PostMessage are
// neither visible through, nor racing with, the caller's copy. The copy is
// shallow: slice elements are copied by value, so any data they point to is
// still shared with the store.
func (cs *ChatStore) Get(chatId string) *uctypes.AIChat {
	cs.lock.Lock()
	defer cs.lock.Unlock()

	stored := cs.chats[chatId]
	if stored == nil {
		return nil
	}

	msgs := make([]uctypes.GenAIMessage, len(stored.NativeMessages))
	copy(msgs, stored.NativeMessages)

	return &uctypes.AIChat{
		ChatId:         stored.ChatId,
		APIType:        stored.APIType,
		Model:          stored.Model,
		APIVersion:     stored.APIVersion,
		NativeMessages: msgs,
	}
}
// Delete removes the chat stored under chatId. Deleting an unknown chatId is
// a no-op.
func (cs *ChatStore) Delete(chatId string) {
	cs.lock.Lock()
	delete(cs.chats, chatId)
	cs.lock.Unlock()
}
// CountUserMessages reports how many messages in the chat stored under chatId
// have the role "user". It returns 0 when the chat does not exist.
func (cs *ChatStore) CountUserMessages(chatId string) int {
	cs.lock.Lock()
	defer cs.lock.Unlock()

	var userCount int
	if stored := cs.chats[chatId]; stored != nil {
		for idx := range stored.NativeMessages {
			if stored.NativeMessages[idx].GetRole() == "user" {
				userCount++
			}
		}
	}
	return userCount
}
// PostMessage adds message to the chat stored under chatId. If the chat does
// not exist yet, it is created from aiOpts.
//
// For an existing chat, aiOpts must agree with the settings the chat was
// created with. APIType and APIVersion must be equal, and Model must be
// compatible according to uctypes.AreModelsCompatible. On a mismatch an error
// is returned and the chat is left unchanged.
//
// Posting is idempotent per message ID. If the chat already holds a message
// whose GetMessageId equals message's, that entry is replaced in place.
// Otherwise message is appended.
//
// A nil aiOpts is rejected with an error rather than causing a panic. The
// store's map is created on demand, so PostMessage also works on a zero-value
// ChatStore.
func (cs *ChatStore) PostMessage(chatId string, aiOpts *uctypes.AIOptsType, message uctypes.GenAIMessage) error {
	if aiOpts == nil {
		// Both the create and the verify paths below dereference aiOpts.
		return fmt.Errorf("cannot post to chat %s: AI options are nil", chatId)
	}

	cs.lock.Lock()
	defer cs.lock.Unlock()

	if cs.chats == nil {
		// Writing to a nil map panics; lazily initialize so that a zero-value
		// ChatStore (not just DefaultChatStore) is usable.
		cs.chats = make(map[string]*uctypes.AIChat)
	}

	chat := cs.chats[chatId]
	if chat == nil {
		// First message for this chat: bind it to the caller's API settings.
		chat = &uctypes.AIChat{
			ChatId:         chatId,
			APIType:        aiOpts.APIType,
			Model:          aiOpts.Model,
			APIVersion:     aiOpts.APIVersion,
			NativeMessages: make([]uctypes.GenAIMessage, 0),
		}
		cs.chats[chatId] = chat
	} else {
		// An existing chat is bound to the settings it was created with.
		if chat.APIType != aiOpts.APIType {
			return fmt.Errorf("API type mismatch: expected %s, got %s", chat.APIType, aiOpts.APIType)
		}
		if !uctypes.AreModelsCompatible(chat.APIType, chat.Model, aiOpts.Model) {
			return fmt.Errorf("model mismatch: expected %s, got %s", chat.Model, aiOpts.Model)
		}
		if chat.APIVersion != aiOpts.APIVersion {
			return fmt.Errorf("API version mismatch: expected %s, got %s", chat.APIVersion, aiOpts.APIVersion)
		}
	}

	// Replace an existing message with the same ID (idempotent re-post).
	// NOTE(review): messages with an empty ID all compare equal here, so a
	// second ID-less message replaces the first. Confirm that callers always
	// assign message IDs.
	messageId := message.GetMessageId()
	for i, existing := range chat.NativeMessages {
		if existing.GetMessageId() == messageId {
			chat.NativeMessages[i] = message
			return nil
		}
	}

	chat.NativeMessages = append(chat.NativeMessages, message)
	return nil
}