From 3268c435e2699928e7654a8ad39069c335fa1492 Mon Sep 17 00:00:00 2001 From: perf3ct Date: Mon, 17 Mar 2025 21:03:42 +0000 Subject: [PATCH] improve embedding precedence --- src/services/llm/embeddings/storage.ts | 114 +++++++++++++++++++++++-- 1 file changed, 109 insertions(+), 5 deletions(-) diff --git a/src/services/llm/embeddings/storage.ts b/src/services/llm/embeddings/storage.ts index 06d74f2fe..7192e38d6 100644 --- a/src/services/llm/embeddings/storage.ts +++ b/src/services/llm/embeddings/storage.ts @@ -1,7 +1,7 @@ -import sql from "../../../services/sql.js"; +import sql from "../../sql.js"; import { randomString } from "../../../services/utils.js"; import dateUtils from "../../../services/date_utils.js"; -import log from "../../../services/log.js"; +import log from "../../log.js"; import { embeddingToBuffer, bufferToEmbedding, cosineSimilarity } from "./vector_utils.js"; import type { EmbeddingResult } from "./types.js"; import entityChangesService from "../../../services/entity_changes.js"; @@ -120,12 +120,17 @@ export async function findSimilarNotes( providerId: string, modelId: string, limit = 10, - threshold?: number // Made optional to use constants + threshold?: number, // Made optional to use constants + useFallback = true // Whether to try other providers if no embeddings found ): Promise<{noteId: string, similarity: number}[]> { // Import constants dynamically to avoid circular dependencies const { LLM_CONSTANTS } = await import('../../../routes/api/llm.js'); // Use provided threshold or default from constants const similarityThreshold = threshold ?? LLM_CONSTANTS.SIMILARITY.DEFAULT_THRESHOLD; + + // Add logging for debugging + log.info(`Finding similar notes for provider: ${providerId}, model: ${modelId}`); + // Get all embeddings for the given provider and model const rows = await sql.getRows(` SELECT embedId, noteId, providerId, modelId, dimension, embedding @@ -134,7 +139,103 @@ export async function findSimilarNotes( [providerId, modelId] ); - if (!rows.length) { + log.info(`Found ${rows.length} embeddings in database for provider: ${providerId}, model: ${modelId}`); + + // If no embeddings found for this provider/model and fallback is enabled + if (rows.length === 0 && useFallback) { + log.info(`No embeddings found for ${providerId}/${modelId}. Attempting fallback...`); + + // Define type for available embeddings + interface EmbeddingMetadata { + providerId: string; + modelId: string; + count: number; + } + + // Get all available embedding providers and models + const availableEmbeddings = await sql.getRows(` + SELECT DISTINCT providerId, modelId, COUNT(*) as count + FROM note_embeddings + GROUP BY providerId, modelId + ORDER BY count DESC` + ) as EmbeddingMetadata[]; + + if (availableEmbeddings.length > 0) { + log.info(`Available embeddings: ${JSON.stringify(availableEmbeddings)}`); + + // Import the AIServiceManager to get provider precedence + const { default: aiManager } = await import('../ai_service_manager.js'); + + // Get providers in user-defined precedence order + // This uses the internal providerOrder property that's set from user preferences + const availableProviderIds = availableEmbeddings.map(e => e.providerId); + + // Get dedicated embedding provider precedence from options + const options = (await import('../../options.js')).default; + let preferredProviders: string[] = []; + + const embeddingPrecedence = await options.getOption('embeddingProviderPrecedence'); + + if (embeddingPrecedence) { + // Parse the precedence string (similar to aiProviderPrecedence parsing) + if (embeddingPrecedence.startsWith('[') && embeddingPrecedence.endsWith(']')) { + preferredProviders = JSON.parse(embeddingPrecedence); + } else if (typeof embeddingPrecedence === 'string') { + if (embeddingPrecedence.includes(',')) { + preferredProviders = embeddingPrecedence.split(',').map(p => p.trim()); + } else { + preferredProviders = [embeddingPrecedence]; + } + } + } else { + // Fall back to the AI provider precedence if embedding-specific one isn't set + // Get the AIServiceManager instance to access its properties + const aiManagerInstance = aiManager.getInstance(); + + // @ts-ignore - Accessing private property + preferredProviders = aiManagerInstance.providerOrder || ['openai', 'anthropic', 'ollama']; + } + + log.info(`Embedding provider precedence order: ${preferredProviders.join(', ')}`); + + // Try each provider in order of precedence + for (const provider of preferredProviders) { + // Skip the original provider we already tried + if (provider === providerId) continue; + + // Skip providers that don't have embeddings + if (!availableProviderIds.includes(provider)) continue; + + // Find the model with the most embeddings for this provider + const providerEmbeddings = availableEmbeddings.filter(e => e.providerId === provider); + + if (providerEmbeddings.length > 0) { + // Use the model with the most embeddings + const bestModel = providerEmbeddings.sort((a, b) => b.count - a.count)[0]; + + log.info(`Trying fallback provider: ${provider}, model: ${bestModel.modelId}`); + + // Recursive call with the new provider/model, but disable further fallbacks + return findSimilarNotes( + embedding, + provider, + bestModel.modelId, + limit, + threshold, + false // Prevent infinite recursion + ); + } + } + + log.info(`No suitable fallback providers found. Available embeddings: ${JSON.stringify(availableEmbeddings)}`); + } else { + log.info(`No embeddings found in the database at all. You need to generate embeddings first.`); + } + + return []; + } else if (rows.length === 0) { + // No embeddings found and fallback disabled + log.info(`No embeddings found for ${providerId}/${modelId} and fallback is disabled.`); return []; } @@ -149,10 +250,13 @@ export async function findSimilarNotes( }); // Filter by threshold and sort by similarity (highest first) - return similarities + const results = similarities .filter(item => item.similarity >= similarityThreshold) .sort((a, b) => b.similarity - a.similarity) .slice(0, limit); + + log.info(`Returning ${results.length} similar notes with similarity >= ${similarityThreshold}`); + return results; } /**