Mirror of https://github.com/zadam/trilium.git (synced 2025-11-04 05:28:59 +01:00)

feat(llm): also improve the llm streaming service, to make it cooperate with unit tests better

Commit c1bcb73337 (parent 40cad2e886)
@@ -195,7 +195,7 @@ async function updateSession(req: Request, res: Response) {
         // Get the chat
         const chat = await chatStorageService.getChat(chatNoteId);
         if (!chat) {
-            throw new Error(`Chat with ID ${chatNoteId} not found`);
+            return [404, { error: `Chat with ID ${chatNoteId} not found` }];
         }
 
         // Update title if provided
@@ -211,7 +211,7 @@ async function updateSession(req: Request, res: Response) {
         };
     } catch (error) {
         log.error(`Error updating chat: ${error}`);
-        throw new Error(`Failed to update chat: ${error}`);
+        return [500, { error: `Failed to update chat: ${error}` }];
     }
 }
 
@@ -264,7 +264,7 @@ async function listSessions(req: Request, res: Response) {
         };
     } catch (error) {
         log.error(`Error listing sessions: ${error}`);
-        throw new Error(`Failed to list sessions: ${error}`);
+        return [500, { error: `Failed to list sessions: ${error}` }];
     }
 }
 
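The three hunks above make the same change: where the handlers previously logged and then threw, they now return a [statusCode, body] tuple that the route layer can translate into an HTTP response. Besides avoiding unhandled route errors, this shape is what lets the handlers cooperate with unit tests: a test can call the function and assert on the returned tuple without standing up Express or mocking a Response object. A minimal sketch of that kind of test, assuming a Vitest-style runner; `updateSessionLike` and the injected `lookup` are illustrative stand-ins, not the project's actual code:

import { describe, expect, it } from "vitest";
import type { Request } from "express";

// Illustrative handler following the same convention as updateSession above:
// a missing chat becomes a [statusCode, payload] tuple instead of a thrown error.
async function updateSessionLike(
    req: Request,
    lookup: (id: string) => Promise<{ title: string } | null>
) {
    const chat = await lookup(req.params.chatNoteId);
    if (!chat) {
        return [404, { error: `Chat with ID ${req.params.chatNoteId} not found` }] as const;
    }
    return { title: chat.title };
}

describe("tuple-style error results", () => {
    it("returns [404, ...] when the chat does not exist", async () => {
        const req = { params: { chatNoteId: "missing" } } as unknown as Request;

        // No Express app and no mocked Response object: call the handler and assert.
        const result = await updateSessionLike(req, async () => null);

        expect(result).toEqual([404, { error: "Chat with ID missing not found" }]);
    });
});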
@@ -419,123 +419,211 @@ async function sendMessage(req: Request, res: Response) {
 */
async function streamMessage(req: Request, res: Response) {
    log.info("=== Starting streamMessage ===");
    try {
        const chatNoteId = req.params.chatNoteId;
        const { content, useAdvancedContext, showThinking, mentions } = req.body;

    const chatNoteId = req.params.chatNoteId;
    const { content, useAdvancedContext, showThinking, mentions } = req.body;

        if (!content || typeof content !== 'string' || content.trim().length === 0) {
            return res.status(400).json({
                success: false,
                error: 'Content cannot be empty'
    // Input validation
    if (!content || typeof content !== 'string' || content.trim().length === 0) {
        return [400, {
            success: false,
            error: 'Content cannot be empty'
        }];
    }

    // Start background streaming process immediately (before sending response)
    const backgroundPromise = handleStreamingProcess(chatNoteId, content, useAdvancedContext, showThinking, mentions)
        .catch(error => {
            log.error(`Background streaming error: ${error.message}`);

            // Send error via WebSocket since HTTP response was already sent
            import('../../services/ws.js').then(wsModule => {
                wsModule.default.sendMessageToAllClients({
                    type: 'llm-stream',
                    chatNoteId: chatNoteId,
                    error: `Error during streaming: ${error.message}`,
                    done: true
                });
            }).catch(wsError => {
                log.error(`Could not send WebSocket error: ${wsError}`);
            });
        }

        // IMPORTANT: Immediately send a success response to the initial POST request
        // The client is waiting for this to confirm streaming has been initiated
        res.status(200).json({
            success: true,
            message: 'Streaming initiated successfully'
        });

        // Mark response as handled to prevent apiResultHandler from processing it again
        (res as any).triliumResponseHandled = true;


        // Create a new response object for streaming through WebSocket only
        // We won't use HTTP streaming since we've already sent the HTTP response

    // Return immediate acknowledgment that streaming has been initiated
    // The background process will handle the actual streaming
    return {
        success: true,
        message: 'Streaming initiated successfully'
    };
}

        // Get or create chat directly from storage (simplified approach)
        let chat = await chatStorageService.getChat(chatNoteId);
        if (!chat) {
            // Create a new chat if it doesn't exist
            chat = await chatStorageService.createChat('New Chat');
            log.info(`Created new chat with ID: ${chat.id} for stream request`);
        }

        // Add the user message to the chat immediately
        chat.messages.push({
            role: 'user',
            content
        });
        // Save the chat to ensure the user message is recorded
        await chatStorageService.updateChat(chat.id, chat.messages, chat.title);
/**
 * Handle the streaming process in the background
 * This is separate from the HTTP request/response cycle
 */
async function handleStreamingProcess(
    chatNoteId: string,
    content: string,
    useAdvancedContext: boolean,
    showThinking: boolean,
    mentions: any[]
) {
    log.info("=== Starting background streaming process ===");

    // Get or create chat directly from storage
    let chat = await chatStorageService.getChat(chatNoteId);
    if (!chat) {
        chat = await chatStorageService.createChat('New Chat');
        log.info(`Created new chat with ID: ${chat.id} for stream request`);
    }

    // Add the user message to the chat immediately
    chat.messages.push({
        role: 'user',
        content
    });
    await chatStorageService.updateChat(chat.id, chat.messages, chat.title);

        // Process mentions if provided
        let enhancedContent = content;
        if (mentions && Array.isArray(mentions) && mentions.length > 0) {
            log.info(`Processing ${mentions.length} note mentions`);
    // Process mentions if provided
    let enhancedContent = content;
    if (mentions && Array.isArray(mentions) && mentions.length > 0) {
        log.info(`Processing ${mentions.length} note mentions`);

            // Import note service to get note content
            const becca = (await import('../../becca/becca.js')).default;
            const mentionContexts: string[] = [];
        const becca = (await import('../../becca/becca.js')).default;
        const mentionContexts: string[] = [];

            for (const mention of mentions) {
                try {
                    const note = becca.getNote(mention.noteId);
                    if (note && !note.isDeleted) {
                        const noteContent = note.getContent();
                        if (noteContent && typeof noteContent === 'string' && noteContent.trim()) {
                            mentionContexts.push(`\n\n--- Content from "${mention.title}" (${mention.noteId}) ---\n${noteContent}\n--- End of "${mention.title}" ---`);
                            log.info(`Added content from note "${mention.title}" (${mention.noteId})`);
                        }
                    } else {
                        log.info(`Referenced note not found or deleted: ${mention.noteId}`);
        for (const mention of mentions) {
            try {
                const note = becca.getNote(mention.noteId);
                if (note && !note.isDeleted) {
                    const noteContent = note.getContent();
                    if (noteContent && typeof noteContent === 'string' && noteContent.trim()) {
                        mentionContexts.push(`\n\n--- Content from "${mention.title}" (${mention.noteId}) ---\n${noteContent}\n--- End of "${mention.title}" ---`);
                        log.info(`Added content from note "${mention.title}" (${mention.noteId})`);
                    }
                } catch (error) {
                    log.error(`Error retrieving content for note ${mention.noteId}: ${error}`);
                } else {
                    log.info(`Referenced note not found or deleted: ${mention.noteId}`);
                }
            } catch (error) {
                log.error(`Error retrieving content for note ${mention.noteId}: ${error}`);
            }
        }

        if (mentionContexts.length > 0) {
            enhancedContent = `${content}\n\n=== Referenced Notes ===\n${mentionContexts.join('\n')}`;
            log.info(`Enhanced content with ${mentionContexts.length} note references`);
        }
    }

    // Import WebSocket service for streaming
    const wsService = (await import('../../services/ws.js')).default;

    // Let the client know streaming has started
    wsService.sendMessageToAllClients({
        type: 'llm-stream',
        chatNoteId: chatNoteId,
        thinking: showThinking ? 'Initializing streaming LLM response...' : undefined
    });

    // Instead of calling the complex handleSendMessage service,
    // let's implement streaming directly to avoid response conflicts

    try {
        // Check if AI is enabled
        const optionsModule = await import('../../services/options.js');
        const aiEnabled = optionsModule.default.getOptionBool('aiEnabled');
        if (!aiEnabled) {
            throw new Error("AI features are disabled. Please enable them in the settings.");
        }

        // Get AI service
        const aiServiceManager = await import('../ai_service_manager.js');
        await aiServiceManager.default.getOrCreateAnyService();

        // Use the chat pipeline directly for streaming
        const { ChatPipeline } = await import('../pipeline/chat_pipeline.js');
        const pipeline = new ChatPipeline({
            enableStreaming: true,
            enableMetrics: true,
            maxToolCallIterations: 5
        });

        // Get selected model
        const { getSelectedModelConfig } = await import('../config/configuration_helpers.js');
        const modelConfig = await getSelectedModelConfig();

        if (!modelConfig) {
            throw new Error("No valid AI model configuration found");
        }

        const pipelineInput = {
            messages: chat.messages.map(msg => ({
                role: msg.role as 'user' | 'assistant' | 'system',
                content: msg.content
            })),
            query: enhancedContent,
            noteId: undefined,
            showThinking: showThinking,
            options: {
                useAdvancedContext: useAdvancedContext === true,
                model: modelConfig.model,
                stream: true,
                chatNoteId: chatNoteId
            },
            streamCallback: (data, done, rawChunk) => {
                const message = {
                    type: 'llm-stream' as const,
                    chatNoteId: chatNoteId,
                    done: done
                };

                if (data) {
                    (message as any).content = data;
                }

                if (rawChunk && 'thinking' in rawChunk && rawChunk.thinking) {
                    (message as any).thinking = rawChunk.thinking as string;
                }

                if (rawChunk && 'toolExecution' in rawChunk && rawChunk.toolExecution) {
                    const toolExec = rawChunk.toolExecution;
                    (message as any).toolExecution = {
                        tool: typeof toolExec.tool === 'string' ? toolExec.tool : toolExec.tool?.name,
                        result: toolExec.result,
                        args: 'arguments' in toolExec ?
                            (typeof toolExec.arguments === 'object' ? toolExec.arguments as Record<string, unknown> : {}) : {},
                        action: 'action' in toolExec ? toolExec.action as string : undefined,
                        toolCallId: 'toolCallId' in toolExec ? toolExec.toolCallId as string : undefined,
                        error: 'error' in toolExec ? toolExec.error as string : undefined
                    };
                }

                wsService.sendMessageToAllClients(message);

                // Save final response when done
                if (done && data) {
                    chat.messages.push({
                        role: 'assistant',
                        content: data
                    });
                    chatStorageService.updateChat(chat.id, chat.messages, chat.title).catch(err => {
                        log.error(`Error saving streamed response: ${err}`);
                    });
                }
            }
        };

            // Enhance the content with note references
            if (mentionContexts.length > 0) {
                enhancedContent = `${content}\n\n=== Referenced Notes ===\n${mentionContexts.join('\n')}`;
                log.info(`Enhanced content with ${mentionContexts.length} note references`);
            }
        }

        // Import the WebSocket service to send immediate feedback
        const wsService = (await import('../../services/ws.js')).default;

        // Let the client know streaming has started
        // Execute the pipeline
        await pipeline.execute(pipelineInput);

    } catch (error: any) {
        log.error(`Error in direct streaming: ${error.message}`);
        wsService.sendMessageToAllClients({
            type: 'llm-stream',
            chatNoteId: chatNoteId,
            thinking: showThinking ? 'Initializing streaming LLM response...' : undefined
            error: `Error during streaming: ${error.message}`,
            done: true
        });

        // Process the LLM request using the existing service but with streaming setup
        // Since we've already sent the initial HTTP response, we'll use the WebSocket for streaming
        try {
            // Call restChatService with streaming mode enabled
            // The important part is setting method to GET to indicate streaming mode
            await restChatService.handleSendMessage({
                ...req,
                method: 'GET', // Indicate streaming mode
                query: {
                    ...req.query,
                    stream: 'true' // Add the required stream parameter
                },
                body: {
                    content: enhancedContent,
                    useAdvancedContext: useAdvancedContext === true,
                    showThinking: showThinking === true
                },
                params: { chatNoteId }
            } as unknown as Request, res);
        } catch (streamError) {
            log.error(`Error during WebSocket streaming: ${streamError}`);

            // Send error message through WebSocket
            wsService.sendMessageToAllClients({
                type: 'llm-stream',
                chatNoteId: chatNoteId,
                error: `Error during streaming: ${streamError}`,
                done: true
            });
        }
    } catch (error: any) {
        log.error(`Error starting message stream: ${error.message}`);
        log.error(`Error starting message stream, can't communicate via WebSocket: ${error.message}`);
    }
}

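In the hunk above, the new streamMessage no longer does the streaming work inside the request handler. It validates input, kicks off handleStreamingProcess(...) as a detached promise, and immediately returns { success: true, message: 'Streaming initiated successfully' }; any failure of the background work is reported to clients over the WebSocket channel with done: true, since the HTTP response has already been sent. The sketch below distills that acknowledge-then-stream control flow with stand-in names (`startStreaming`, `worker`, and `notify` are illustrative, not Trilium's API):

// Minimal sketch of the acknowledge-then-stream pattern; the notifier stands in
// for wsService.sendMessageToAllClients and the worker for handleStreamingProcess.
type StreamNotifier = (msg: {
    type: "llm-stream";
    chatNoteId: string;
    error?: string;
    done?: boolean;
}) => void;

function startStreaming(
    chatNoteId: string,
    worker: (chatNoteId: string) => Promise<void>,
    notify: StreamNotifier
): { success: boolean; message: string } {
    // Kick off the background work without awaiting it; the HTTP caller only
    // needs to know that streaming was initiated.
    worker(chatNoteId).catch((error: Error) => {
        // The HTTP response has already gone out, so errors travel over the WebSocket.
        notify({
            type: "llm-stream",
            chatNoteId,
            error: `Error during streaming: ${error.message}`,
            done: true
        });
    });

    return { success: true, message: "Streaming initiated successfully" };
}

Because the worker and notifier are plain parameters here, a unit test can pass a rejecting worker and assert that the notifier received an error message with done: true, with no HTTP plumbing involved; the real code reaches the same decoupling by extracting handleStreamingProcess out of the request/response cycle.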
@@ -325,13 +325,13 @@ class RestChatService {
 
             const chat = await chatStorageService.getChat(sessionId);
             if (!chat) {
-                res.status(404).json({
+                // Return error in Express route format [statusCode, response]
+                return [404, {
                     error: true,
                     message: `Session with ID ${sessionId} not found`,
                     code: 'session_not_found',
                     sessionId
-                });
-                return null;
+                }];
             }
 
             return {
@@ -344,7 +344,7 @@ class RestChatService {
             };
         } catch (error: any) {
             log.error(`Error getting chat session: ${error.message || 'Unknown error'}`);
-            throw new Error(`Failed to get session: ${error.message || 'Unknown error'}`);
+            return [500, { error: `Failed to get session: ${error.message || 'Unknown error'}` }];
         }
     }
 
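For context on why returning [statusCode, body] works from these handlers at all: the comments in the diff refer to an apiResultHandler on the route layer that turns a handler's return value into the HTTP response (and that is skipped when triliumResponseHandled is set). The actual implementation is not shown in this commit; the adapter below is only an illustration of that kind of wrapper, with made-up names:

import type { Request, Response } from "express";

// Illustrative adapter only: lets a handler return either a plain payload
// (sent as 200) or a [statusCode, payload] tuple, so handlers never need to
// touch the Express Response themselves.
type ApiHandler = (req: Request, res: Response) => Promise<unknown>;

function wrapHandler(handler: ApiHandler) {
    return async (req: Request, res: Response) => {
        const result = await handler(req, res);

        // Handlers that already streamed or sent their own response are skipped.
        if ((res as any).triliumResponseHandled) {
            return;
        }

        if (Array.isArray(result) && typeof result[0] === "number") {
            const [statusCode, payload] = result as [number, unknown];
            res.status(statusCode).json(payload);
        } else {
            res.status(200).json(result);
        }
    };
}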