mirror of
https://github.com/zadam/trilium.git
synced 2025-12-04 22:44:25 +01:00
feat(llm): update llm tests for update tool executions
This commit is contained in:
parent
778f13e2e6
commit
eb2ace41b0
@ -58,13 +58,37 @@ vi.mock("../../services/llm/ai_service_manager.js", () => ({
|
||||
default: mockAiServiceManager
|
||||
}));
|
||||
|
||||
// Mock chat pipeline
|
||||
const mockChatPipelineExecute = vi.fn();
|
||||
const MockChatPipeline = vi.fn().mockImplementation(() => ({
|
||||
execute: mockChatPipelineExecute
|
||||
// Mock simplified pipeline
|
||||
const mockPipelineExecute = vi.fn();
|
||||
vi.mock("../../services/llm/pipeline/simplified_pipeline.js", () => ({
|
||||
default: {
|
||||
execute: mockPipelineExecute
|
||||
}
|
||||
}));
|
||||
vi.mock("../../services/llm/pipeline/chat_pipeline.js", () => ({
|
||||
ChatPipeline: MockChatPipeline
|
||||
|
||||
// Mock logging service
|
||||
vi.mock("../../services/llm/pipeline/logging_service.js", () => ({
|
||||
default: {
|
||||
withRequestId: vi.fn(() => ({
|
||||
log: vi.fn()
|
||||
}))
|
||||
},
|
||||
LogLevel: {
|
||||
ERROR: 'error',
|
||||
WARN: 'warn',
|
||||
INFO: 'info',
|
||||
DEBUG: 'debug'
|
||||
}
|
||||
}));
|
||||
|
||||
// Mock tool registry
|
||||
vi.mock("../../services/llm/tools/tool_registry.js", () => ({
|
||||
default: {
|
||||
getTools: vi.fn(() => []),
|
||||
getTool: vi.fn(),
|
||||
executeTool: vi.fn(),
|
||||
initialize: vi.fn()
|
||||
}
|
||||
}));
|
||||
|
||||
// Mock configuration helpers
|
||||
@ -82,7 +106,10 @@ vi.mock("../../services/llm/pipeline/configuration_service.js", () => ({
|
||||
maxRetries: 3,
|
||||
timeout: 30000,
|
||||
enableSmartProcessing: true,
|
||||
maxToolIterations: 10
|
||||
maxToolIterations: 10,
|
||||
maxIterations: 10,
|
||||
enabled: true,
|
||||
parallelExecution: true
|
||||
})),
|
||||
getAIConfig: vi.fn(() => ({
|
||||
provider: 'test-provider',
|
||||
@ -90,11 +117,32 @@ vi.mock("../../services/llm/pipeline/configuration_service.js", () => ({
|
||||
})),
|
||||
getDebugConfig: vi.fn(() => ({
|
||||
enableMetrics: true,
|
||||
enableLogging: true
|
||||
enableLogging: true,
|
||||
enabled: true,
|
||||
logLevel: 'info',
|
||||
enableTracing: false
|
||||
})),
|
||||
getStreamingConfig: vi.fn(() => ({
|
||||
enableStreaming: true,
|
||||
chunkSize: 1024
|
||||
enabled: true,
|
||||
chunkSize: 1024,
|
||||
flushInterval: 100
|
||||
})),
|
||||
getDefaultSystemPrompt: vi.fn(() => 'You are a helpful assistant.'),
|
||||
getDefaultConfig: vi.fn(() => ({
|
||||
systemPrompt: 'You are a helpful assistant.',
|
||||
temperature: 0.7,
|
||||
maxTokens: 1000,
|
||||
topP: 1.0,
|
||||
presencePenalty: 0,
|
||||
frequencyPenalty: 0
|
||||
})),
|
||||
getDefaultCompletionOptions: vi.fn(() => ({
|
||||
temperature: 0.7,
|
||||
maxTokens: 1000,
|
||||
topP: 1.0,
|
||||
presencePenalty: 0,
|
||||
frequencyPenalty: 0
|
||||
}))
|
||||
}
|
||||
}));
|
||||
@ -384,7 +432,7 @@ describe("LLM API Tests", () => {
|
||||
|
||||
it("should initiate streaming for a chat message", async () => {
|
||||
// Setup streaming simulation
|
||||
mockChatPipelineExecute.mockImplementation(async (input) => {
|
||||
mockPipelineExecute.mockImplementation(async (input) => {
|
||||
const callback = input.streamCallback;
|
||||
// Simulate streaming chunks
|
||||
await callback('Hello', false, {});
|
||||
@ -498,7 +546,7 @@ describe("LLM API Tests", () => {
|
||||
}));
|
||||
|
||||
// Setup streaming with mention context
|
||||
mockChatPipelineExecute.mockImplementation(async (input) => {
|
||||
mockPipelineExecute.mockImplementation(async (input) => {
|
||||
// Verify mention content is included
|
||||
expect(input.query).toContain('Tell me about this note');
|
||||
expect(input.query).toContain('Root note content for testing');
|
||||
@ -541,7 +589,7 @@ describe("LLM API Tests", () => {
|
||||
});
|
||||
|
||||
it("should handle streaming with thinking states", async () => {
|
||||
mockChatPipelineExecute.mockImplementation(async (input) => {
|
||||
mockPipelineExecute.mockImplementation(async (input) => {
|
||||
const callback = input.streamCallback;
|
||||
// Simulate thinking states
|
||||
await callback('', false, { thinking: 'Analyzing the question...' });
|
||||
@ -581,7 +629,7 @@ describe("LLM API Tests", () => {
|
||||
});
|
||||
|
||||
it("should handle streaming with tool executions", async () => {
|
||||
mockChatPipelineExecute.mockImplementation(async (input) => {
|
||||
mockPipelineExecute.mockImplementation(async (input) => {
|
||||
const callback = input.streamCallback;
|
||||
// Simulate tool execution with standardized response format
|
||||
await callback('Let me calculate that', false, {});
|
||||
@ -648,7 +696,7 @@ describe("LLM API Tests", () => {
|
||||
});
|
||||
|
||||
it("should handle streaming errors gracefully", async () => {
|
||||
mockChatPipelineExecute.mockRejectedValue(new Error('Pipeline error'));
|
||||
mockPipelineExecute.mockRejectedValue(new Error('Pipeline error'));
|
||||
|
||||
const response = await supertest(app)
|
||||
.post(`/api/llm/chat/${testChatId}/messages/stream`)
|
||||
@ -703,7 +751,7 @@ describe("LLM API Tests", () => {
|
||||
|
||||
it("should save chat messages after streaming completion", async () => {
|
||||
const completeResponse = 'This is the complete response';
|
||||
mockChatPipelineExecute.mockImplementation(async (input) => {
|
||||
mockPipelineExecute.mockImplementation(async (input) => {
|
||||
const callback = input.streamCallback;
|
||||
await callback(completeResponse, true, {});
|
||||
});
|
||||
@ -723,12 +771,12 @@ describe("LLM API Tests", () => {
|
||||
// Note: Due to the mocked environment, the actual chat storage might not be called
|
||||
// This test verifies the streaming endpoint works correctly
|
||||
// The actual chat storage behavior is tested in the service layer tests
|
||||
expect(mockChatPipelineExecute).toHaveBeenCalled();
|
||||
expect(mockPipelineExecute).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should handle rapid consecutive streaming requests", async () => {
|
||||
let callCount = 0;
|
||||
mockChatPipelineExecute.mockImplementation(async (input) => {
|
||||
mockPipelineExecute.mockImplementation(async (input) => {
|
||||
callCount++;
|
||||
const callback = input.streamCallback;
|
||||
await callback(`Response ${callCount}`, true, {});
|
||||
@ -755,12 +803,12 @@ describe("LLM API Tests", () => {
|
||||
});
|
||||
|
||||
// Verify all were processed
|
||||
expect(mockChatPipelineExecute).toHaveBeenCalledTimes(3);
|
||||
expect(mockPipelineExecute).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it("should handle large streaming responses", async () => {
|
||||
const largeContent = 'x'.repeat(10000); // 10KB of content
|
||||
mockChatPipelineExecute.mockImplementation(async (input) => {
|
||||
mockPipelineExecute.mockImplementation(async (input) => {
|
||||
const callback = input.streamCallback;
|
||||
// Simulate chunked delivery of large content
|
||||
for (let i = 0; i < 10; i++) {
|
||||
|
||||
@ -42,16 +42,32 @@ interface PipelineConfig {
|
||||
* Simplified Chat Pipeline Implementation
|
||||
*/
|
||||
export class SimplifiedChatPipeline {
|
||||
private config: PipelineConfig;
|
||||
private config: PipelineConfig | null = null;
|
||||
private metrics: Map<string, number> = new Map();
|
||||
|
||||
constructor() {
|
||||
// Load configuration from centralized service
|
||||
this.config = {
|
||||
maxToolIterations: configurationService.getToolConfig().maxIterations,
|
||||
enableMetrics: configurationService.getDebugConfig().enableMetrics,
|
||||
enableStreaming: configurationService.getStreamingConfig().enabled
|
||||
};
|
||||
// Configuration will be loaded lazily on first use
|
||||
}
|
||||
|
||||
private getConfig(): PipelineConfig {
|
||||
if (!this.config) {
|
||||
try {
|
||||
// Load configuration from centralized service
|
||||
this.config = {
|
||||
maxToolIterations: configurationService.getToolConfig().maxIterations,
|
||||
enableMetrics: configurationService.getDebugConfig().enableMetrics,
|
||||
enableStreaming: configurationService.getStreamingConfig().enabled
|
||||
};
|
||||
} catch (error) {
|
||||
// Use defaults if configuration not available
|
||||
this.config = {
|
||||
maxToolIterations: 5,
|
||||
enableMetrics: false,
|
||||
enableStreaming: true
|
||||
};
|
||||
}
|
||||
}
|
||||
return this.config;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -83,7 +99,7 @@ export class SimplifiedChatPipeline {
|
||||
const processedResponse = await this.processResponse(finalResponse, input, logger);
|
||||
|
||||
// Record metrics
|
||||
if (this.config.enableMetrics) {
|
||||
if (this.getConfig().enableMetrics) {
|
||||
this.recordMetric('pipeline_duration', Date.now() - startTime);
|
||||
}
|
||||
|
||||
@ -164,7 +180,7 @@ export class SimplifiedChatPipeline {
|
||||
const options: ChatCompletionOptions = {
|
||||
...configurationService.getDefaultCompletionOptions(),
|
||||
...input.options,
|
||||
stream: this.config.enableStreaming && !!input.streamCallback
|
||||
stream: this.getConfig().enableStreaming && !!input.streamCallback
|
||||
};
|
||||
|
||||
// Add tools if enabled
|
||||
@ -219,9 +235,9 @@ export class SimplifiedChatPipeline {
|
||||
let currentMessages = [...messages];
|
||||
let iterations = 0;
|
||||
|
||||
while (iterations < this.config.maxToolIterations && currentResponse.tool_calls?.length) {
|
||||
while (iterations < this.getConfig().maxToolIterations && currentResponse.tool_calls?.length) {
|
||||
iterations++;
|
||||
logger.log(LogLevel.DEBUG, `Tool iteration ${iterations}/${this.config.maxToolIterations}`);
|
||||
logger.log(LogLevel.DEBUG, `Tool iteration ${iterations}/${this.getConfig().maxToolIterations}`);
|
||||
|
||||
// Add assistant message with tool calls
|
||||
currentMessages.push({
|
||||
@ -262,9 +278,9 @@ export class SimplifiedChatPipeline {
|
||||
}
|
||||
}
|
||||
|
||||
if (iterations >= this.config.maxToolIterations) {
|
||||
if (iterations >= this.getConfig().maxToolIterations) {
|
||||
logger.log(LogLevel.WARN, 'Maximum tool iterations reached', {
|
||||
iterations: this.config.maxToolIterations
|
||||
iterations: this.getConfig().maxToolIterations
|
||||
});
|
||||
}
|
||||
|
||||
@ -400,7 +416,7 @@ export class SimplifiedChatPipeline {
|
||||
* Record a metric
|
||||
*/
|
||||
private recordMetric(name: string, value: number): void {
|
||||
if (!this.config.enableMetrics) return;
|
||||
if (!this.getConfig().enableMetrics) return;
|
||||
|
||||
const current = this.metrics.get(name) || 0;
|
||||
const count = this.metrics.get(`${name}_count`) || 0;
|
||||
|
||||
@ -195,24 +195,6 @@ describe('OllamaService', () => {
|
||||
OllamaMock.mockImplementation(() => mockOllamaInstance);
|
||||
|
||||
service = new OllamaService();
|
||||
|
||||
// Replace the formatter with a mock after construction
|
||||
(service as any).formatter = {
|
||||
formatMessages: vi.fn().mockReturnValue([
|
||||
{ role: 'user', content: 'Hello' }
|
||||
]),
|
||||
formatResponse: vi.fn().mockReturnValue({
|
||||
text: 'Hello! How can I help you today?',
|
||||
provider: 'Ollama',
|
||||
model: 'llama2',
|
||||
usage: {
|
||||
promptTokens: 5,
|
||||
completionTokens: 10,
|
||||
totalTokens: 15
|
||||
},
|
||||
tool_calls: null
|
||||
})
|
||||
};
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@ -220,10 +202,9 @@ describe('OllamaService', () => {
|
||||
});
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with provider name and formatter', () => {
|
||||
it('should initialize with provider name', () => {
|
||||
expect(service).toBeDefined();
|
||||
expect((service as any).name).toBe('Ollama');
|
||||
expect((service as any).formatter).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
@ -487,7 +468,7 @@ describe('OllamaService', () => {
|
||||
expect(result.tool_calls).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('should format messages using the formatter', async () => {
|
||||
it('should pass messages to Ollama client', async () => {
|
||||
vi.mocked(options.getOption).mockReturnValue('http://localhost:11434');
|
||||
|
||||
const mockOptions = {
|
||||
@ -497,17 +478,15 @@ describe('OllamaService', () => {
|
||||
};
|
||||
vi.mocked(providers.getOllamaOptions).mockResolvedValueOnce(mockOptions);
|
||||
|
||||
const formattedMessages = [{ role: 'user', content: 'Hello' }];
|
||||
(service as any).formatter.formatMessages.mockReturnValueOnce(formattedMessages);
|
||||
|
||||
const chatSpy = vi.spyOn(mockOllamaInstance, 'chat');
|
||||
|
||||
await service.generateChatCompletion(messages);
|
||||
|
||||
expect((service as any).formatter.formatMessages).toHaveBeenCalled();
|
||||
expect(chatSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
messages: formattedMessages
|
||||
messages: expect.arrayContaining([
|
||||
expect.objectContaining({ role: 'user', content: 'Hello' })
|
||||
])
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user