Mirror of https://github.com/wavetermdev/waveterm.git
Synced 2025-11-28 21:20:25 +08:00
111 lines, 2.6 KiB, Go
// Copyright 2025, Command Line Inc.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package chatstore
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
|
|
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
|
|
)
|
|
|
|
// ChatStore is an in-memory collection of AI chats keyed by chat ID.
//
// Get, Delete, CountUserMessages and PostMessage each hold lock for their
// whole duration. Because it embeds a sync.Mutex, a ChatStore must not be
// copied after first use; pass it around as *ChatStore.
type ChatStore struct {
	// lock guards chats.
	lock sync.Mutex
	// chats maps a chat ID to the stored chat for that ID.
	chats map[string]*uctypes.AIChat
}
|
|
|
|
var DefaultChatStore = &ChatStore{
|
|
chats: make(map[string]*uctypes.AIChat),
|
|
}
|
|
|
|
func (cs *ChatStore) Get(chatId string) *uctypes.AIChat {
|
|
cs.lock.Lock()
|
|
defer cs.lock.Unlock()
|
|
|
|
chat := cs.chats[chatId]
|
|
if chat == nil {
|
|
return nil
|
|
}
|
|
|
|
// Copy the chat to prevent concurrent access issues
|
|
copyChat := &uctypes.AIChat{
|
|
ChatId: chat.ChatId,
|
|
APIType: chat.APIType,
|
|
Model: chat.Model,
|
|
APIVersion: chat.APIVersion,
|
|
NativeMessages: make([]uctypes.GenAIMessage, len(chat.NativeMessages)),
|
|
}
|
|
copy(copyChat.NativeMessages, chat.NativeMessages)
|
|
|
|
return copyChat
|
|
}
|
|
|
|
func (cs *ChatStore) Delete(chatId string) {
|
|
cs.lock.Lock()
|
|
defer cs.lock.Unlock()
|
|
|
|
delete(cs.chats, chatId)
|
|
}
|
|
|
|
func (cs *ChatStore) CountUserMessages(chatId string) int {
|
|
cs.lock.Lock()
|
|
defer cs.lock.Unlock()
|
|
|
|
chat := cs.chats[chatId]
|
|
if chat == nil {
|
|
return 0
|
|
}
|
|
|
|
count := 0
|
|
for _, msg := range chat.NativeMessages {
|
|
if msg.GetRole() == "user" {
|
|
count++
|
|
}
|
|
}
|
|
return count
|
|
}
|
|
|
|
func (cs *ChatStore) PostMessage(chatId string, aiOpts *uctypes.AIOptsType, message uctypes.GenAIMessage) error {
|
|
cs.lock.Lock()
|
|
defer cs.lock.Unlock()
|
|
|
|
chat := cs.chats[chatId]
|
|
if chat == nil {
|
|
// Create new chat
|
|
chat = &uctypes.AIChat{
|
|
ChatId: chatId,
|
|
APIType: aiOpts.APIType,
|
|
Model: aiOpts.Model,
|
|
APIVersion: aiOpts.APIVersion,
|
|
NativeMessages: make([]uctypes.GenAIMessage, 0),
|
|
}
|
|
cs.chats[chatId] = chat
|
|
} else {
|
|
// Verify that the AI options match
|
|
if chat.APIType != aiOpts.APIType {
|
|
return fmt.Errorf("API type mismatch: expected %s, got %s", chat.APIType, aiOpts.APIType)
|
|
}
|
|
if !uctypes.AreModelsCompatible(chat.APIType, chat.Model, aiOpts.Model) {
|
|
return fmt.Errorf("model mismatch: expected %s, got %s", chat.Model, aiOpts.Model)
|
|
}
|
|
if chat.APIVersion != aiOpts.APIVersion {
|
|
return fmt.Errorf("API version mismatch: expected %s, got %s", chat.APIVersion, aiOpts.APIVersion)
|
|
}
|
|
}
|
|
|
|
// Check for existing message with same ID (idempotency)
|
|
messageId := message.GetMessageId()
|
|
for i, existingMessage := range chat.NativeMessages {
|
|
if existingMessage.GetMessageId() == messageId {
|
|
// Replace existing message with same ID
|
|
chat.NativeMessages[i] = message
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Append the new message if no duplicate found
|
|
chat.NativeMessages = append(chat.NativeMessages, message)
|
|
|
|
return nil
|
|
}
|