mirror of
https://github.com/zadam/trilium.git
synced 2026-01-03 21:24:24 +01:00
improve embedding precedence
This commit is contained in:
parent
37f1dcdaab
commit
3268c435e2
@ -1,7 +1,7 @@
|
||||
import sql from "../../../services/sql.js";
|
||||
import sql from "../../sql.js";
|
||||
import { randomString } from "../../../services/utils.js";
|
||||
import dateUtils from "../../../services/date_utils.js";
|
||||
import log from "../../../services/log.js";
|
||||
import log from "../../log.js";
|
||||
import { embeddingToBuffer, bufferToEmbedding, cosineSimilarity } from "./vector_utils.js";
|
||||
import type { EmbeddingResult } from "./types.js";
|
||||
import entityChangesService from "../../../services/entity_changes.js";
|
||||
@ -120,12 +120,17 @@ export async function findSimilarNotes(
|
||||
providerId: string,
|
||||
modelId: string,
|
||||
limit = 10,
|
||||
threshold?: number // Made optional to use constants
|
||||
threshold?: number, // Made optional to use constants
|
||||
useFallback = true // Whether to try other providers if no embeddings found
|
||||
): Promise<{noteId: string, similarity: number}[]> {
|
||||
// Import constants dynamically to avoid circular dependencies
|
||||
const { LLM_CONSTANTS } = await import('../../../routes/api/llm.js');
|
||||
// Use provided threshold or default from constants
|
||||
const similarityThreshold = threshold ?? LLM_CONSTANTS.SIMILARITY.DEFAULT_THRESHOLD;
|
||||
|
||||
// Add logging for debugging
|
||||
log.info(`Finding similar notes for provider: ${providerId}, model: ${modelId}`);
|
||||
|
||||
// Get all embeddings for the given provider and model
|
||||
const rows = await sql.getRows(`
|
||||
SELECT embedId, noteId, providerId, modelId, dimension, embedding
|
||||
@ -134,7 +139,103 @@ export async function findSimilarNotes(
|
||||
[providerId, modelId]
|
||||
);
|
||||
|
||||
if (!rows.length) {
|
||||
log.info(`Found ${rows.length} embeddings in database for provider: ${providerId}, model: ${modelId}`);
|
||||
|
||||
// If no embeddings found for this provider/model and fallback is enabled
|
||||
if (rows.length === 0 && useFallback) {
|
||||
log.info(`No embeddings found for ${providerId}/${modelId}. Attempting fallback...`);
|
||||
|
||||
// Define type for available embeddings
|
||||
interface EmbeddingMetadata {
|
||||
providerId: string;
|
||||
modelId: string;
|
||||
count: number;
|
||||
}
|
||||
|
||||
// Get all available embedding providers and models
|
||||
const availableEmbeddings = await sql.getRows(`
|
||||
SELECT DISTINCT providerId, modelId, COUNT(*) as count
|
||||
FROM note_embeddings
|
||||
GROUP BY providerId, modelId
|
||||
ORDER BY count DESC`
|
||||
) as EmbeddingMetadata[];
|
||||
|
||||
if (availableEmbeddings.length > 0) {
|
||||
log.info(`Available embeddings: ${JSON.stringify(availableEmbeddings)}`);
|
||||
|
||||
// Import the AIServiceManager to get provider precedence
|
||||
const { default: aiManager } = await import('../ai_service_manager.js');
|
||||
|
||||
// Get providers in user-defined precedence order
|
||||
// This uses the internal providerOrder property that's set from user preferences
|
||||
const availableProviderIds = availableEmbeddings.map(e => e.providerId);
|
||||
|
||||
// Get dedicated embedding provider precedence from options
|
||||
const options = (await import('../../options.js')).default;
|
||||
let preferredProviders: string[] = [];
|
||||
|
||||
const embeddingPrecedence = await options.getOption('embeddingProviderPrecedence');
|
||||
|
||||
if (embeddingPrecedence) {
|
||||
// Parse the precedence string (similar to aiProviderPrecedence parsing)
|
||||
if (embeddingPrecedence.startsWith('[') && embeddingPrecedence.endsWith(']')) {
|
||||
preferredProviders = JSON.parse(embeddingPrecedence);
|
||||
} else if (typeof embeddingPrecedence === 'string') {
|
||||
if (embeddingPrecedence.includes(',')) {
|
||||
preferredProviders = embeddingPrecedence.split(',').map(p => p.trim());
|
||||
} else {
|
||||
preferredProviders = [embeddingPrecedence];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fall back to the AI provider precedence if embedding-specific one isn't set
|
||||
// Get the AIServiceManager instance to access its properties
|
||||
const aiManagerInstance = aiManager.getInstance();
|
||||
|
||||
// @ts-ignore - Accessing private property
|
||||
preferredProviders = aiManagerInstance.providerOrder || ['openai', 'anthropic', 'ollama'];
|
||||
}
|
||||
|
||||
log.info(`Embedding provider precedence order: ${preferredProviders.join(', ')}`);
|
||||
|
||||
// Try each provider in order of precedence
|
||||
for (const provider of preferredProviders) {
|
||||
// Skip the original provider we already tried
|
||||
if (provider === providerId) continue;
|
||||
|
||||
// Skip providers that don't have embeddings
|
||||
if (!availableProviderIds.includes(provider)) continue;
|
||||
|
||||
// Find the model with the most embeddings for this provider
|
||||
const providerEmbeddings = availableEmbeddings.filter(e => e.providerId === provider);
|
||||
|
||||
if (providerEmbeddings.length > 0) {
|
||||
// Use the model with the most embeddings
|
||||
const bestModel = providerEmbeddings.sort((a, b) => b.count - a.count)[0];
|
||||
|
||||
log.info(`Trying fallback provider: ${provider}, model: ${bestModel.modelId}`);
|
||||
|
||||
// Recursive call with the new provider/model, but disable further fallbacks
|
||||
return findSimilarNotes(
|
||||
embedding,
|
||||
provider,
|
||||
bestModel.modelId,
|
||||
limit,
|
||||
threshold,
|
||||
false // Prevent infinite recursion
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
log.info(`No suitable fallback providers found. Available embeddings: ${JSON.stringify(availableEmbeddings)}`);
|
||||
} else {
|
||||
log.info(`No embeddings found in the database at all. You need to generate embeddings first.`);
|
||||
}
|
||||
|
||||
return [];
|
||||
} else if (rows.length === 0) {
|
||||
// No embeddings found and fallback disabled
|
||||
log.info(`No embeddings found for ${providerId}/${modelId} and fallback is disabled.`);
|
||||
return [];
|
||||
}
|
||||
|
||||
@ -149,10 +250,13 @@ export async function findSimilarNotes(
|
||||
});
|
||||
|
||||
// Filter by threshold and sort by similarity (highest first)
|
||||
return similarities
|
||||
const results = similarities
|
||||
.filter(item => item.similarity >= similarityThreshold)
|
||||
.sort((a, b) => b.similarity - a.similarity)
|
||||
.slice(0, limit);
|
||||
|
||||
log.info(`Returning ${results.length} similar notes with similarity >= ${similarityThreshold}`);
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user