mirror of https://github.com/zadam/trilium.git
feat(llm): update LLM tests for updated tool executions
parent 778f13e2e6
commit eb2ace41b0
@@ -58,13 +58,37 @@ vi.mock("../../services/llm/ai_service_manager.js", () => ({
     default: mockAiServiceManager
 }));
 
-// Mock chat pipeline
-const mockChatPipelineExecute = vi.fn();
-const MockChatPipeline = vi.fn().mockImplementation(() => ({
-    execute: mockChatPipelineExecute
-}));
-vi.mock("../../services/llm/pipeline/chat_pipeline.js", () => ({
-    ChatPipeline: MockChatPipeline
+// Mock simplified pipeline
+const mockPipelineExecute = vi.fn();
+vi.mock("../../services/llm/pipeline/simplified_pipeline.js", () => ({
+    default: {
+        execute: mockPipelineExecute
+    }
+}));
+
+// Mock logging service
+vi.mock("../../services/llm/pipeline/logging_service.js", () => ({
+    default: {
+        withRequestId: vi.fn(() => ({
+            log: vi.fn()
+        }))
+    },
+    LogLevel: {
+        ERROR: 'error',
+        WARN: 'warn',
+        INFO: 'info',
+        DEBUG: 'debug'
+    }
+}));
+
+// Mock tool registry
+vi.mock("../../services/llm/tools/tool_registry.js", () => ({
+    default: {
+        getTools: vi.fn(() => []),
+        getTool: vi.fn(),
+        executeTool: vi.fn(),
+        initialize: vi.fn()
+    }
 }));
 
 // Mock configuration helpers
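The old setup stubbed the ChatPipeline class; the new one mocks the default export of simplified_pipeline.js directly and adds standalone mocks for the logging service and the tool registry, so the route can be exercised without any real LLM infrastructure. A minimal sketch of how the tests drive the new mock (the streamCallback signature is taken from the test hunks further down, not from this hunk):

    // Sketch only: make the mocked pipeline emit two chunks through the callback the route passes in.
    mockPipelineExecute.mockImplementation(async (input) => {
        await input.streamCallback('Hello', false, {});    // intermediate chunk
        await input.streamCallback(' world', true, {});    // done = true ends the stream
    });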
@@ -82,7 +106,10 @@ vi.mock("../../services/llm/pipeline/configuration_service.js", () => ({
             maxRetries: 3,
             timeout: 30000,
             enableSmartProcessing: true,
-            maxToolIterations: 10
+            maxToolIterations: 10,
+            maxIterations: 10,
+            enabled: true,
+            parallelExecution: true
         })),
         getAIConfig: vi.fn(() => ({
             provider: 'test-provider',
@@ -90,11 +117,32 @@ vi.mock("../../services/llm/pipeline/configuration_service.js", () => ({
         })),
         getDebugConfig: vi.fn(() => ({
             enableMetrics: true,
-            enableLogging: true
+            enableLogging: true,
+            enabled: true,
+            logLevel: 'info',
+            enableTracing: false
         })),
         getStreamingConfig: vi.fn(() => ({
             enableStreaming: true,
-            chunkSize: 1024
+            enabled: true,
+            chunkSize: 1024,
+            flushInterval: 100
+        })),
+        getDefaultSystemPrompt: vi.fn(() => 'You are a helpful assistant.'),
+        getDefaultConfig: vi.fn(() => ({
+            systemPrompt: 'You are a helpful assistant.',
+            temperature: 0.7,
+            maxTokens: 1000,
+            topP: 1.0,
+            presencePenalty: 0,
+            frequencyPenalty: 0
+        })),
+        getDefaultCompletionOptions: vi.fn(() => ({
+            temperature: 0.7,
+            maxTokens: 1000,
+            topP: 1.0,
+            presencePenalty: 0,
+            frequencyPenalty: 0
         }))
     }
 }));
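The configuration mock grows because the simplified pipeline now reads more getters. How the pipeline consumes it, sketched from the getConfig() and completion-options hunks further down rather than from this file:

    // With this mock: maxIterations -> 10, enableMetrics -> true, enabled -> true.
    const config = {
        maxToolIterations: configurationService.getToolConfig().maxIterations,
        enableMetrics: configurationService.getDebugConfig().enableMetrics,
        enableStreaming: configurationService.getStreamingConfig().enabled
    };
    // getDefaultCompletionOptions() is spread into every ChatCompletionOptions object the pipeline builds.
    const options = { ...configurationService.getDefaultCompletionOptions(), stream: true };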
@@ -384,7 +432,7 @@ describe("LLM API Tests", () => {
 
         it("should initiate streaming for a chat message", async () => {
             // Setup streaming simulation
-            mockChatPipelineExecute.mockImplementation(async (input) => {
+            mockPipelineExecute.mockImplementation(async (input) => {
                 const callback = input.streamCallback;
                 // Simulate streaming chunks
                 await callback('Hello', false, {});
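Beyond the renamed mock, the shape of each streaming test stays the same: stub mockPipelineExecute, POST to the streaming endpoint, then assert on the mock. A rough sketch; the endpoint path and assertions appear in later hunks, while the request payload is not visible in this diff and is left abstract:

    mockPipelineExecute.mockImplementation(async (input) => {
        await input.streamCallback('Hello', true, {});
    });
    const response = await supertest(app)
        .post(`/api/llm/chat/${testChatId}/messages/stream`)
        .send({ /* chat message payload (not shown in this diff) */ });
    expect(mockPipelineExecute).toHaveBeenCalled();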
@@ -498,7 +546,7 @@ describe("LLM API Tests", () => {
             }));
 
             // Setup streaming with mention context
-            mockChatPipelineExecute.mockImplementation(async (input) => {
+            mockPipelineExecute.mockImplementation(async (input) => {
                 // Verify mention content is included
                 expect(input.query).toContain('Tell me about this note');
                 expect(input.query).toContain('Root note content for testing');
@@ -541,7 +589,7 @@ describe("LLM API Tests", () => {
         });
 
         it("should handle streaming with thinking states", async () => {
-            mockChatPipelineExecute.mockImplementation(async (input) => {
+            mockPipelineExecute.mockImplementation(async (input) => {
                 const callback = input.streamCallback;
                 // Simulate thinking states
                 await callback('', false, { thinking: 'Analyzing the question...' });
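These tests pin down the stream callback contract implicitly. The type below is inferred from the calls in this and the surrounding hunks; it is an assumption about the real signature, which is not shown in the diff:

    // (content, done, metadata) — metadata can carry a thinking state alongside or instead of content.
    type StreamCallback = (
        content: string,
        done: boolean,
        metadata: Record<string, unknown>
    ) => void | Promise<void>;
    // e.g. await callback('', false, { thinking: 'Analyzing the question...' });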
@@ -581,7 +629,7 @@ describe("LLM API Tests", () => {
         });
 
         it("should handle streaming with tool executions", async () => {
-            mockChatPipelineExecute.mockImplementation(async (input) => {
+            mockPipelineExecute.mockImplementation(async (input) => {
                 const callback = input.streamCallback;
                 // Simulate tool execution with standardized response format
                 await callback('Let me calculate that', false, {});
@@ -648,7 +696,7 @@ describe("LLM API Tests", () => {
         });
 
         it("should handle streaming errors gracefully", async () => {
-            mockChatPipelineExecute.mockRejectedValue(new Error('Pipeline error'));
+            mockPipelineExecute.mockRejectedValue(new Error('Pipeline error'));
 
             const response = await supertest(app)
                 .post(`/api/llm/chat/${testChatId}/messages/stream`)
@@ -703,7 +751,7 @@ describe("LLM API Tests", () => {
 
         it("should save chat messages after streaming completion", async () => {
            const completeResponse = 'This is the complete response';
-            mockChatPipelineExecute.mockImplementation(async (input) => {
+            mockPipelineExecute.mockImplementation(async (input) => {
                 const callback = input.streamCallback;
                 await callback(completeResponse, true, {});
             });
@@ -723,12 +771,12 @@ describe("LLM API Tests", () => {
             // Note: Due to the mocked environment, the actual chat storage might not be called
             // This test verifies the streaming endpoint works correctly
             // The actual chat storage behavior is tested in the service layer tests
-            expect(mockChatPipelineExecute).toHaveBeenCalled();
+            expect(mockPipelineExecute).toHaveBeenCalled();
         });
 
         it("should handle rapid consecutive streaming requests", async () => {
             let callCount = 0;
-            mockChatPipelineExecute.mockImplementation(async (input) => {
+            mockPipelineExecute.mockImplementation(async (input) => {
                 callCount++;
                 const callback = input.streamCallback;
                 await callback(`Response ${callCount}`, true, {});
@@ -755,12 +803,12 @@ describe("LLM API Tests", () => {
             });
 
             // Verify all were processed
-            expect(mockChatPipelineExecute).toHaveBeenCalledTimes(3);
+            expect(mockPipelineExecute).toHaveBeenCalledTimes(3);
         });
 
         it("should handle large streaming responses", async () => {
             const largeContent = 'x'.repeat(10000); // 10KB of content
-            mockChatPipelineExecute.mockImplementation(async (input) => {
+            mockPipelineExecute.mockImplementation(async (input) => {
                 const callback = input.streamCallback;
                 // Simulate chunked delivery of large content
                 for (let i = 0; i < 10; i++) {
(The hunks below are from a second file in this commit: the SimplifiedChatPipeline implementation.)
@@ -42,16 +42,32 @@ interface PipelineConfig {
  * Simplified Chat Pipeline Implementation
  */
 export class SimplifiedChatPipeline {
-    private config: PipelineConfig;
+    private config: PipelineConfig | null = null;
     private metrics: Map<string, number> = new Map();
 
     constructor() {
-        // Load configuration from centralized service
-        this.config = {
-            maxToolIterations: configurationService.getToolConfig().maxIterations,
-            enableMetrics: configurationService.getDebugConfig().enableMetrics,
-            enableStreaming: configurationService.getStreamingConfig().enabled
-        };
+        // Configuration will be loaded lazily on first use
+    }
+
+    private getConfig(): PipelineConfig {
+        if (!this.config) {
+            try {
+                // Load configuration from centralized service
+                this.config = {
+                    maxToolIterations: configurationService.getToolConfig().maxIterations,
+                    enableMetrics: configurationService.getDebugConfig().enableMetrics,
+                    enableStreaming: configurationService.getStreamingConfig().enabled
+                };
+            } catch (error) {
+                // Use defaults if configuration not available
+                this.config = {
+                    maxToolIterations: 5,
+                    enableMetrics: false,
+                    enableStreaming: true
+                };
+            }
+        }
+        return this.config;
     }
 
     /**
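Moving the configurationService reads out of the constructor and into a lazy, cached getConfig() means the pipeline can be imported and constructed before configuration exists, and tests that do not mock every getter fall back to safe defaults instead of failing at instantiation. A sketch of the resulting behaviour:

    // Construction is now side-effect free.
    const pipeline = new SimplifiedChatPipeline();    // no configurationService access yet
    // The first call that needs settings triggers getConfig(), which caches the result in this.config;
    // if configurationService throws, the defaults apply (maxToolIterations: 5, metrics off, streaming on).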
@@ -83,7 +99,7 @@ export class SimplifiedChatPipeline {
         const processedResponse = await this.processResponse(finalResponse, input, logger);
 
         // Record metrics
-        if (this.config.enableMetrics) {
+        if (this.getConfig().enableMetrics) {
             this.recordMetric('pipeline_duration', Date.now() - startTime);
         }
 
@@ -164,7 +180,7 @@ export class SimplifiedChatPipeline {
         const options: ChatCompletionOptions = {
             ...configurationService.getDefaultCompletionOptions(),
             ...input.options,
-            stream: this.config.enableStreaming && !!input.streamCallback
+            stream: this.getConfig().enableStreaming && !!input.streamCallback
         };
 
         // Add tools if enabled
@@ -219,9 +235,9 @@ export class SimplifiedChatPipeline {
         let currentMessages = [...messages];
         let iterations = 0;
 
-        while (iterations < this.config.maxToolIterations && currentResponse.tool_calls?.length) {
+        while (iterations < this.getConfig().maxToolIterations && currentResponse.tool_calls?.length) {
             iterations++;
-            logger.log(LogLevel.DEBUG, `Tool iteration ${iterations}/${this.config.maxToolIterations}`);
+            logger.log(LogLevel.DEBUG, `Tool iteration ${iterations}/${this.getConfig().maxToolIterations}`);
 
             // Add assistant message with tool calls
             currentMessages.push({
@@ -262,9 +278,9 @@ export class SimplifiedChatPipeline {
             }
         }
 
-        if (iterations >= this.config.maxToolIterations) {
+        if (iterations >= this.getConfig().maxToolIterations) {
             logger.log(LogLevel.WARN, 'Maximum tool iterations reached', {
-                iterations: this.config.maxToolIterations
+                iterations: this.getConfig().maxToolIterations
             });
         }
 
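The two hunks above only swap this.config for this.getConfig(); the bounded tool-calling loop itself is unchanged. Its shape, condensed from the visible lines (the tool-execution body is elided in the diff and therefore also here):

    while (iterations < this.getConfig().maxToolIterations && currentResponse.tool_calls?.length) {
        iterations++;
        logger.log(LogLevel.DEBUG, `Tool iteration ${iterations}/${this.getConfig().maxToolIterations}`);
        // ...execute the requested tools, append assistant/tool messages, re-query the model...
    }
    if (iterations >= this.getConfig().maxToolIterations) {
        logger.log(LogLevel.WARN, 'Maximum tool iterations reached', {
            iterations: this.getConfig().maxToolIterations
        });
    }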
@@ -400,7 +416,7 @@ export class SimplifiedChatPipeline {
      * Record a metric
      */
     private recordMetric(name: string, value: number): void {
-        if (!this.config.enableMetrics) return;
+        if (!this.getConfig().enableMetrics) return;
 
         const current = this.metrics.get(name) || 0;
         const count = this.metrics.get(`${name}_count`) || 0;
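recordMetric keeps two entries per metric name: a running value plus a `${name}_count` counter. Only the reads are visible in this hunk, so the writes below are an assumption about how the rest of the method continues:

    // Assumed continuation (not shown in the diff): accumulate the value and bump the counter.
    this.metrics.set(name, current + value);
    this.metrics.set(`${name}_count`, count + 1);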
(The hunks below are from a third file in this commit: the OllamaService spec.)
@@ -195,24 +195,6 @@ describe('OllamaService', () => {
         OllamaMock.mockImplementation(() => mockOllamaInstance);
 
         service = new OllamaService();
-
-        // Replace the formatter with a mock after construction
-        (service as any).formatter = {
-            formatMessages: vi.fn().mockReturnValue([
-                { role: 'user', content: 'Hello' }
-            ]),
-            formatResponse: vi.fn().mockReturnValue({
-                text: 'Hello! How can I help you today?',
-                provider: 'Ollama',
-                model: 'llama2',
-                usage: {
-                    promptTokens: 5,
-                    completionTokens: 10,
-                    totalTokens: 15
-                },
-                tool_calls: null
-            })
-        };
     });
 
     afterEach(() => {
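With the formatter removed from OllamaService, the spec no longer patches service.formatter after construction; responses come straight from the mocked Ollama client. A hedged sketch of how such a stub typically looks (the actual stub lives elsewhere in this spec and is not part of this hunk):

    // Sketch; the response fields follow the ollama client's chat() shape.
    vi.spyOn(mockOllamaInstance, 'chat').mockResolvedValue({
        model: 'llama2',
        message: { role: 'assistant', content: 'Hello! How can I help you today?' },
        done: true
    } as any);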
@@ -220,10 +202,9 @@ describe('OllamaService', () => {
     });
 
     describe('constructor', () => {
-        it('should initialize with provider name and formatter', () => {
+        it('should initialize with provider name', () => {
             expect(service).toBeDefined();
             expect((service as any).name).toBe('Ollama');
-            expect((service as any).formatter).toBeDefined();
         });
     });
 
@@ -487,7 +468,7 @@ describe('OllamaService', () => {
             expect(result.tool_calls).toHaveLength(1);
         });
 
-        it('should format messages using the formatter', async () => {
+        it('should pass messages to Ollama client', async () => {
             vi.mocked(options.getOption).mockReturnValue('http://localhost:11434');
 
             const mockOptions = {
@@ -497,17 +478,15 @@ describe('OllamaService', () => {
             };
             vi.mocked(providers.getOllamaOptions).mockResolvedValueOnce(mockOptions);
 
-            const formattedMessages = [{ role: 'user', content: 'Hello' }];
-            (service as any).formatter.formatMessages.mockReturnValueOnce(formattedMessages);
-
             const chatSpy = vi.spyOn(mockOllamaInstance, 'chat');
 
             await service.generateChatCompletion(messages);
 
-            expect((service as any).formatter.formatMessages).toHaveBeenCalled();
             expect(chatSpy).toHaveBeenCalledWith(
                 expect.objectContaining({
-                    messages: formattedMessages
+                    messages: expect.arrayContaining([
+                        expect.objectContaining({ role: 'user', content: 'Hello' })
+                    ])
                 })
             );
         });
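Since the message array is no longer produced by a mocked formatter, the assertion switches to vitest's asymmetric matchers: arrayContaining/objectContaining only require the user message to be present, so the test keeps passing if the service prepends a system prompt or adds metadata. A standalone example of the matcher semantics:

    // Passes even though the received array holds more than the matched element.
    expect([
        { role: 'system', content: 'You are a helpful assistant.' },
        { role: 'user', content: 'Hello' }
    ]).toEqual(expect.arrayContaining([
        expect.objectContaining({ role: 'user', content: 'Hello' })
    ]));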