mirror of
				https://github.com/zadam/trilium.git
				synced 2025-11-03 21:19:01 +01:00 
			
		
		
		
	feat(llm): also improve the llm streaming service, to make it cooperate with unit tests better
This commit is contained in:
		
							parent
							
								
									40cad2e886
								
							
						
					
					
						commit
						c1bcb73337
					
				@ -195,7 +195,7 @@ async function updateSession(req: Request, res: Response) {
 | 
				
			|||||||
        // Get the chat
 | 
					        // Get the chat
 | 
				
			||||||
        const chat = await chatStorageService.getChat(chatNoteId);
 | 
					        const chat = await chatStorageService.getChat(chatNoteId);
 | 
				
			||||||
        if (!chat) {
 | 
					        if (!chat) {
 | 
				
			||||||
            throw new Error(`Chat with ID ${chatNoteId} not found`);
 | 
					            return [404, { error: `Chat with ID ${chatNoteId} not found` }];
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Update title if provided
 | 
					        // Update title if provided
 | 
				
			||||||
@ -211,7 +211,7 @@ async function updateSession(req: Request, res: Response) {
 | 
				
			|||||||
        };
 | 
					        };
 | 
				
			||||||
    } catch (error) {
 | 
					    } catch (error) {
 | 
				
			||||||
        log.error(`Error updating chat: ${error}`);
 | 
					        log.error(`Error updating chat: ${error}`);
 | 
				
			||||||
        throw new Error(`Failed to update chat: ${error}`);
 | 
					        return [500, { error: `Failed to update chat: ${error}` }];
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -264,7 +264,7 @@ async function listSessions(req: Request, res: Response) {
 | 
				
			|||||||
        };
 | 
					        };
 | 
				
			||||||
    } catch (error) {
 | 
					    } catch (error) {
 | 
				
			||||||
        log.error(`Error listing sessions: ${error}`);
 | 
					        log.error(`Error listing sessions: ${error}`);
 | 
				
			||||||
        throw new Error(`Failed to list sessions: ${error}`);
 | 
					        return [500, { error: `Failed to list sessions: ${error}` }];
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -419,35 +419,60 @@ async function sendMessage(req: Request, res: Response) {
 | 
				
			|||||||
 */
 | 
					 */
 | 
				
			||||||
async function streamMessage(req: Request, res: Response) {
 | 
					async function streamMessage(req: Request, res: Response) {
 | 
				
			||||||
    log.info("=== Starting streamMessage ===");
 | 
					    log.info("=== Starting streamMessage ===");
 | 
				
			||||||
    try {
 | 
					    
 | 
				
			||||||
    const chatNoteId = req.params.chatNoteId;
 | 
					    const chatNoteId = req.params.chatNoteId;
 | 
				
			||||||
    const { content, useAdvancedContext, showThinking, mentions } = req.body;
 | 
					    const { content, useAdvancedContext, showThinking, mentions } = req.body;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Input validation
 | 
				
			||||||
    if (!content || typeof content !== 'string' || content.trim().length === 0) {
 | 
					    if (!content || typeof content !== 'string' || content.trim().length === 0) {
 | 
				
			||||||
            return res.status(400).json({
 | 
					        return [400, {
 | 
				
			||||||
            success: false,
 | 
					            success: false,
 | 
				
			||||||
            error: 'Content cannot be empty'
 | 
					            error: 'Content cannot be empty'
 | 
				
			||||||
            });
 | 
					        }];
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
        // IMPORTANT: Immediately send a success response to the initial POST request
 | 
					    // Start background streaming process immediately (before sending response)
 | 
				
			||||||
        // The client is waiting for this to confirm streaming has been initiated
 | 
					    const backgroundPromise = handleStreamingProcess(chatNoteId, content, useAdvancedContext, showThinking, mentions)
 | 
				
			||||||
        res.status(200).json({
 | 
					        .catch(error => {
 | 
				
			||||||
            success: true,
 | 
					            log.error(`Background streaming error: ${error.message}`);
 | 
				
			||||||
            message: 'Streaming initiated successfully'
 | 
					            
 | 
				
			||||||
 | 
					            // Send error via WebSocket since HTTP response was already sent
 | 
				
			||||||
 | 
					            import('../../services/ws.js').then(wsModule => {
 | 
				
			||||||
 | 
					                wsModule.default.sendMessageToAllClients({
 | 
				
			||||||
 | 
					                    type: 'llm-stream',
 | 
				
			||||||
 | 
					                    chatNoteId: chatNoteId,
 | 
				
			||||||
 | 
					                    error: `Error during streaming: ${error.message}`,
 | 
				
			||||||
 | 
					                    done: true
 | 
				
			||||||
 | 
					                });
 | 
				
			||||||
 | 
					            }).catch(wsError => {
 | 
				
			||||||
 | 
					                log.error(`Could not send WebSocket error: ${wsError}`);
 | 
				
			||||||
 | 
					            });
 | 
				
			||||||
        });
 | 
					        });
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
        // Mark response as handled to prevent apiResultHandler from processing it again
 | 
					    // Return immediate acknowledgment that streaming has been initiated
 | 
				
			||||||
        (res as any).triliumResponseHandled = true;
 | 
					    // The background process will handle the actual streaming
 | 
				
			||||||
 | 
					    return {
 | 
				
			||||||
 | 
					        success: true,
 | 
				
			||||||
 | 
					        message: 'Streaming initiated successfully'
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * Handle the streaming process in the background
 | 
				
			||||||
 | 
					 * This is separate from the HTTP request/response cycle
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					async function handleStreamingProcess(
 | 
				
			||||||
 | 
					    chatNoteId: string, 
 | 
				
			||||||
 | 
					    content: string, 
 | 
				
			||||||
 | 
					    useAdvancedContext: boolean, 
 | 
				
			||||||
 | 
					    showThinking: boolean, 
 | 
				
			||||||
 | 
					    mentions: any[]
 | 
				
			||||||
 | 
					) {
 | 
				
			||||||
 | 
					    log.info("=== Starting background streaming process ===");
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
        // Create a new response object for streaming through WebSocket only
 | 
					    // Get or create chat directly from storage
 | 
				
			||||||
        // We won't use HTTP streaming since we've already sent the HTTP response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // Get or create chat directly from storage (simplified approach)
 | 
					 | 
				
			||||||
    let chat = await chatStorageService.getChat(chatNoteId);
 | 
					    let chat = await chatStorageService.getChat(chatNoteId);
 | 
				
			||||||
    if (!chat) {
 | 
					    if (!chat) {
 | 
				
			||||||
            // Create a new chat if it doesn't exist
 | 
					 | 
				
			||||||
        chat = await chatStorageService.createChat('New Chat');
 | 
					        chat = await chatStorageService.createChat('New Chat');
 | 
				
			||||||
        log.info(`Created new chat with ID: ${chat.id} for stream request`);
 | 
					        log.info(`Created new chat with ID: ${chat.id} for stream request`);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@ -457,7 +482,6 @@ async function streamMessage(req: Request, res: Response) {
 | 
				
			|||||||
        role: 'user',
 | 
					        role: 'user',
 | 
				
			||||||
        content
 | 
					        content
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
        // Save the chat to ensure the user message is recorded
 | 
					 | 
				
			||||||
    await chatStorageService.updateChat(chat.id, chat.messages, chat.title);
 | 
					    await chatStorageService.updateChat(chat.id, chat.messages, chat.title);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Process mentions if provided
 | 
					    // Process mentions if provided
 | 
				
			||||||
@ -465,7 +489,6 @@ async function streamMessage(req: Request, res: Response) {
 | 
				
			|||||||
    if (mentions && Array.isArray(mentions) && mentions.length > 0) {
 | 
					    if (mentions && Array.isArray(mentions) && mentions.length > 0) {
 | 
				
			||||||
        log.info(`Processing ${mentions.length} note mentions`);
 | 
					        log.info(`Processing ${mentions.length} note mentions`);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            // Import note service to get note content
 | 
					 | 
				
			||||||
        const becca = (await import('../../becca/becca.js')).default;
 | 
					        const becca = (await import('../../becca/becca.js')).default;
 | 
				
			||||||
        const mentionContexts: string[] = [];
 | 
					        const mentionContexts: string[] = [];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -486,14 +509,13 @@ async function streamMessage(req: Request, res: Response) {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            // Enhance the content with note references
 | 
					 | 
				
			||||||
        if (mentionContexts.length > 0) {
 | 
					        if (mentionContexts.length > 0) {
 | 
				
			||||||
            enhancedContent = `${content}\n\n=== Referenced Notes ===\n${mentionContexts.join('\n')}`;
 | 
					            enhancedContent = `${content}\n\n=== Referenced Notes ===\n${mentionContexts.join('\n')}`;
 | 
				
			||||||
            log.info(`Enhanced content with ${mentionContexts.length} note references`);
 | 
					            log.info(`Enhanced content with ${mentionContexts.length} note references`);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Import the WebSocket service to send immediate feedback
 | 
					    // Import WebSocket service for streaming
 | 
				
			||||||
    const wsService = (await import('../../services/ws.js')).default;
 | 
					    const wsService = (await import('../../services/ws.js')).default;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Let the client know streaming has started
 | 
					    // Let the client know streaming has started
 | 
				
			||||||
@ -503,40 +525,106 @@ async function streamMessage(req: Request, res: Response) {
 | 
				
			|||||||
        thinking: showThinking ? 'Initializing streaming LLM response...' : undefined
 | 
					        thinking: showThinking ? 'Initializing streaming LLM response...' : undefined
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Process the LLM request using the existing service but with streaming setup
 | 
					    // Instead of calling the complex handleSendMessage service, 
 | 
				
			||||||
        // Since we've already sent the initial HTTP response, we'll use the WebSocket for streaming
 | 
					    // let's implement streaming directly to avoid response conflicts
 | 
				
			||||||
        try {
 | 
					 | 
				
			||||||
            // Call restChatService with streaming mode enabled
 | 
					 | 
				
			||||||
            // The important part is setting method to GET to indicate streaming mode
 | 
					 | 
				
			||||||
            await restChatService.handleSendMessage({
 | 
					 | 
				
			||||||
                ...req,
 | 
					 | 
				
			||||||
                method: 'GET', // Indicate streaming mode
 | 
					 | 
				
			||||||
                query: {
 | 
					 | 
				
			||||||
                    ...req.query,
 | 
					 | 
				
			||||||
                    stream: 'true' // Add the required stream parameter
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                body: {
 | 
					 | 
				
			||||||
                    content: enhancedContent,
 | 
					 | 
				
			||||||
                    useAdvancedContext: useAdvancedContext === true,
 | 
					 | 
				
			||||||
                    showThinking: showThinking === true
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                params: { chatNoteId }
 | 
					 | 
				
			||||||
            } as unknown as Request, res);
 | 
					 | 
				
			||||||
        } catch (streamError) {
 | 
					 | 
				
			||||||
            log.error(`Error during WebSocket streaming: ${streamError}`);
 | 
					 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
            // Send error message through WebSocket
 | 
					    try {
 | 
				
			||||||
 | 
					        // Check if AI is enabled
 | 
				
			||||||
 | 
					        const optionsModule = await import('../../services/options.js');
 | 
				
			||||||
 | 
					        const aiEnabled = optionsModule.default.getOptionBool('aiEnabled');
 | 
				
			||||||
 | 
					        if (!aiEnabled) {
 | 
				
			||||||
 | 
					            throw new Error("AI features are disabled. Please enable them in the settings.");
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Get AI service
 | 
				
			||||||
 | 
					        const aiServiceManager = await import('../ai_service_manager.js');
 | 
				
			||||||
 | 
					        await aiServiceManager.default.getOrCreateAnyService();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Use the chat pipeline directly for streaming
 | 
				
			||||||
 | 
					        const { ChatPipeline } = await import('../pipeline/chat_pipeline.js');
 | 
				
			||||||
 | 
					        const pipeline = new ChatPipeline({
 | 
				
			||||||
 | 
					            enableStreaming: true,
 | 
				
			||||||
 | 
					            enableMetrics: true,
 | 
				
			||||||
 | 
					            maxToolCallIterations: 5
 | 
				
			||||||
 | 
					        });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Get selected model
 | 
				
			||||||
 | 
					        const { getSelectedModelConfig } = await import('../config/configuration_helpers.js');
 | 
				
			||||||
 | 
					        const modelConfig = await getSelectedModelConfig();
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        if (!modelConfig) {
 | 
				
			||||||
 | 
					            throw new Error("No valid AI model configuration found");
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        const pipelineInput = {
 | 
				
			||||||
 | 
					            messages: chat.messages.map(msg => ({
 | 
				
			||||||
 | 
					                role: msg.role as 'user' | 'assistant' | 'system',
 | 
				
			||||||
 | 
					                content: msg.content
 | 
				
			||||||
 | 
					            })),
 | 
				
			||||||
 | 
					            query: enhancedContent,
 | 
				
			||||||
 | 
					            noteId: undefined,
 | 
				
			||||||
 | 
					            showThinking: showThinking,
 | 
				
			||||||
 | 
					            options: {
 | 
				
			||||||
 | 
					                useAdvancedContext: useAdvancedContext === true,
 | 
				
			||||||
 | 
					                model: modelConfig.model,
 | 
				
			||||||
 | 
					                stream: true,
 | 
				
			||||||
 | 
					                chatNoteId: chatNoteId
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            streamCallback: (data, done, rawChunk) => {
 | 
				
			||||||
 | 
					                const message = {
 | 
				
			||||||
 | 
					                    type: 'llm-stream' as const,
 | 
				
			||||||
 | 
					                    chatNoteId: chatNoteId,
 | 
				
			||||||
 | 
					                    done: done
 | 
				
			||||||
 | 
					                };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if (data) {
 | 
				
			||||||
 | 
					                    (message as any).content = data;
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if (rawChunk && 'thinking' in rawChunk && rawChunk.thinking) {
 | 
				
			||||||
 | 
					                    (message as any).thinking = rawChunk.thinking as string;
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if (rawChunk && 'toolExecution' in rawChunk && rawChunk.toolExecution) {
 | 
				
			||||||
 | 
					                    const toolExec = rawChunk.toolExecution;
 | 
				
			||||||
 | 
					                    (message as any).toolExecution = {
 | 
				
			||||||
 | 
					                        tool: typeof toolExec.tool === 'string' ? toolExec.tool : toolExec.tool?.name,
 | 
				
			||||||
 | 
					                        result: toolExec.result,
 | 
				
			||||||
 | 
					                        args: 'arguments' in toolExec ?
 | 
				
			||||||
 | 
					                            (typeof toolExec.arguments === 'object' ? toolExec.arguments as Record<string, unknown> : {}) : {},
 | 
				
			||||||
 | 
					                        action: 'action' in toolExec ? toolExec.action as string : undefined,
 | 
				
			||||||
 | 
					                        toolCallId: 'toolCallId' in toolExec ? toolExec.toolCallId as string : undefined,
 | 
				
			||||||
 | 
					                        error: 'error' in toolExec ? toolExec.error as string : undefined
 | 
				
			||||||
 | 
					                    };
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                wsService.sendMessageToAllClients(message);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                // Save final response when done
 | 
				
			||||||
 | 
					                if (done && data) {
 | 
				
			||||||
 | 
					                    chat.messages.push({
 | 
				
			||||||
 | 
					                        role: 'assistant',
 | 
				
			||||||
 | 
					                        content: data
 | 
				
			||||||
 | 
					                    });
 | 
				
			||||||
 | 
					                    chatStorageService.updateChat(chat.id, chat.messages, chat.title).catch(err => {
 | 
				
			||||||
 | 
					                        log.error(`Error saving streamed response: ${err}`);
 | 
				
			||||||
 | 
					                    });
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Execute the pipeline
 | 
				
			||||||
 | 
					        await pipeline.execute(pipelineInput);
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					    } catch (error: any) {
 | 
				
			||||||
 | 
					        log.error(`Error in direct streaming: ${error.message}`);
 | 
				
			||||||
        wsService.sendMessageToAllClients({
 | 
					        wsService.sendMessageToAllClients({
 | 
				
			||||||
            type: 'llm-stream',
 | 
					            type: 'llm-stream',
 | 
				
			||||||
            chatNoteId: chatNoteId,
 | 
					            chatNoteId: chatNoteId,
 | 
				
			||||||
                error: `Error during streaming: ${streamError}`,
 | 
					            error: `Error during streaming: ${error.message}`,
 | 
				
			||||||
            done: true
 | 
					            done: true
 | 
				
			||||||
        });
 | 
					        });
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    } catch (error: any) {
 | 
					 | 
				
			||||||
        log.error(`Error starting message stream: ${error.message}`);
 | 
					 | 
				
			||||||
        log.error(`Error starting message stream, can't communicate via WebSocket: ${error.message}`);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
export default {
 | 
					export default {
 | 
				
			||||||
 | 
				
			|||||||
@ -325,13 +325,13 @@ class RestChatService {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            const chat = await chatStorageService.getChat(sessionId);
 | 
					            const chat = await chatStorageService.getChat(sessionId);
 | 
				
			||||||
            if (!chat) {
 | 
					            if (!chat) {
 | 
				
			||||||
                res.status(404).json({
 | 
					                // Return error in Express route format [statusCode, response]
 | 
				
			||||||
 | 
					                return [404, {
 | 
				
			||||||
                    error: true,
 | 
					                    error: true,
 | 
				
			||||||
                    message: `Session with ID ${sessionId} not found`,
 | 
					                    message: `Session with ID ${sessionId} not found`,
 | 
				
			||||||
                    code: 'session_not_found',
 | 
					                    code: 'session_not_found',
 | 
				
			||||||
                    sessionId
 | 
					                    sessionId
 | 
				
			||||||
                });
 | 
					                }];
 | 
				
			||||||
                return null;
 | 
					 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            return {
 | 
					            return {
 | 
				
			||||||
@ -344,7 +344,7 @@ class RestChatService {
 | 
				
			|||||||
            };
 | 
					            };
 | 
				
			||||||
        } catch (error: any) {
 | 
					        } catch (error: any) {
 | 
				
			||||||
            log.error(`Error getting chat session: ${error.message || 'Unknown error'}`);
 | 
					            log.error(`Error getting chat session: ${error.message || 'Unknown error'}`);
 | 
				
			||||||
            throw new Error(`Failed to get session: ${error.message || 'Unknown error'}`);
 | 
					            return [500, { error: `Failed to get session: ${error.message || 'Unknown error'}` }];
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user