From 1361e4d43839c213278100f8c439c869d7eb0d70 Mon Sep 17 00:00:00 2001 From: perf3ct Date: Sat, 8 Mar 2025 22:04:10 +0000 Subject: [PATCH] set up embedding API endpoints --- src/routes/api/embeddings.ts | 233 +++++++++++++++++++++++++++++++++++ src/routes/routes.ts | 9 ++ 2 files changed, 242 insertions(+) create mode 100644 src/routes/api/embeddings.ts diff --git a/src/routes/api/embeddings.ts b/src/routes/api/embeddings.ts new file mode 100644 index 000000000..178074c8d --- /dev/null +++ b/src/routes/api/embeddings.ts @@ -0,0 +1,233 @@ +import options from "../../services/options.js"; +import vectorStore from "../../services/llm/embeddings/vector_store.js"; +import providerManager from "../../services/llm/embeddings/providers.js"; +import becca from "../../becca/becca.js"; +import type { Request, Response } from "express"; + +/** + * Get similar notes based on note ID + */ +async function findSimilarNotes(req: Request, res: Response) { + const noteId = req.params.noteId; + const providerId = req.query.providerId as string || 'openai'; + const modelId = req.query.modelId as string || 'text-embedding-3-small'; + const limit = parseInt(req.query.limit as string || '10', 10); + const threshold = parseFloat(req.query.threshold as string || '0.7'); + + if (!noteId) { + return res.status(400).send({ + success: false, + message: "Note ID is required" + }); + } + + try { + const embedding = await vectorStore.getEmbeddingForNote(noteId, providerId, modelId); + + if (!embedding) { + // If no embedding exists for this note yet, generate one + const note = becca.getNote(noteId); + if (!note) { + return res.status(404).send({ + success: false, + message: "Note not found" + }); + } + + const context = await vectorStore.getNoteEmbeddingContext(noteId); + const provider = providerManager.getEmbeddingProvider(providerId); + + if (!provider) { + return res.status(400).send({ + success: false, + message: `Embedding provider '${providerId}' not found` + }); + } + + const newEmbedding = await provider.generateNoteEmbeddings(context); + await vectorStore.storeNoteEmbedding(noteId, providerId, modelId, newEmbedding); + + const similarNotes = await vectorStore.findSimilarNotes( + newEmbedding, providerId, modelId, limit, threshold + ); + + return res.send({ + success: true, + similarNotes + }); + } + + const similarNotes = await vectorStore.findSimilarNotes( + embedding.embedding, providerId, modelId, limit, threshold + ); + + return res.send({ + success: true, + similarNotes + }); + } catch (error: any) { + return res.status(500).send({ + success: false, + message: error.message || "Unknown error" + }); + } +} + +/** + * Search notes by text + */ +async function searchByText(req: Request, res: Response) { + const { text } = req.body; + const providerId = req.query.providerId as string || 'openai'; + const modelId = req.query.modelId as string || 'text-embedding-3-small'; + const limit = parseInt(req.query.limit as string || '10', 10); + const threshold = parseFloat(req.query.threshold as string || '0.7'); + + if (!text) { + return res.status(400).send({ + success: false, + message: "Search text is required" + }); + } + + try { + const provider = providerManager.getEmbeddingProvider(providerId); + + if (!provider) { + return res.status(400).send({ + success: false, + message: `Embedding provider '${providerId}' not found` + }); + } + + // Generate embedding for the search text + const embedding = await provider.generateEmbeddings(text); + + // Find similar notes + const similarNotes = await vectorStore.findSimilarNotes( + embedding, providerId, modelId, limit, threshold + ); + + return res.send({ + success: true, + similarNotes + }); + } catch (error: any) { + return res.status(500).send({ + success: false, + message: error.message || "Unknown error" + }); + } +} + +/** + * Get embedding providers + */ +async function getProviders(req: Request, res: Response) { + try { + const providerConfigs = await providerManager.getEmbeddingProviderConfigs(); + return res.send({ + success: true, + providers: providerConfigs + }); + } catch (error: any) { + return res.status(500).send({ + success: false, + message: error.message || "Unknown error" + }); + } +} + +/** + * Update provider configuration + */ +async function updateProvider(req: Request, res: Response) { + const { providerId } = req.params; + const { isEnabled, priority, config } = req.body; + + try { + const success = await providerManager.updateEmbeddingProviderConfig( + providerId, isEnabled, priority, config + ); + + if (!success) { + return res.status(404).send({ + success: false, + message: "Provider not found" + }); + } + + return res.send({ + success: true + }); + } catch (error: any) { + return res.status(500).send({ + success: false, + message: error.message || "Unknown error" + }); + } +} + +/** + * Manually trigger a reprocessing of all notes + */ +async function reprocessAllNotes(req: Request, res: Response) { + try { + await vectorStore.reprocessAllNotes(); + + return res.send({ + success: true, + message: "Notes queued for reprocessing" + }); + } catch (error: any) { + return res.status(500).send({ + success: false, + message: error.message || "Unknown error" + }); + } +} + +/** + * Get embedding queue status + */ +async function getQueueStatus(req: Request, res: Response) { + try { + // Use sql directly instead of becca.sqliteDB + const sql = require("../../services/sql.js").default; + + const queueCount = await sql.getValue( + "SELECT COUNT(*) FROM embedding_queue" + ); + + const failedCount = await sql.getValue( + "SELECT COUNT(*) FROM embedding_queue WHERE attempts > 0" + ); + + const totalEmbeddingsCount = await sql.getValue( + "SELECT COUNT(*) FROM note_embeddings" + ); + + return res.send({ + success: true, + status: { + queueCount, + failedCount, + totalEmbeddingsCount + } + }); + } catch (error: any) { + return res.status(500).send({ + success: false, + message: error.message || "Unknown error" + }); + } +} + +export default { + findSimilarNotes, + searchByText, + getProviders, + updateProvider, + reprocessAllNotes, + getQueueStatus +}; diff --git a/src/routes/routes.ts b/src/routes/routes.ts index de2055d99..b5a35a5f4 100644 --- a/src/routes/routes.ts +++ b/src/routes/routes.ts @@ -60,6 +60,7 @@ import etapiTokensApiRoutes from "./api/etapi_tokens.js"; import relationMapApiRoute from "./api/relation-map.js"; import otherRoute from "./api/other.js"; import shareRoutes from "../share/routes.js"; +import embeddingsRoute from "./api/embeddings.js"; import etapiAuthRoutes from "../etapi/auth.js"; import etapiAppInfoRoutes from "../etapi/app_info.js"; @@ -369,6 +370,14 @@ function register(app: express.Application) { etapiSpecRoute.register(router); etapiBackupRoute.register(router); + // Embeddings API endpoints + route(GET, "/api/embeddings/similar/:noteId", [auth.checkApiAuth], embeddingsRoute.findSimilarNotes, apiResultHandler); + route(PST, "/api/embeddings/search", [auth.checkApiAuth, csrfMiddleware], embeddingsRoute.searchByText, apiResultHandler); + route(GET, "/api/embeddings/providers", [auth.checkApiAuth], embeddingsRoute.getProviders, apiResultHandler); + route(PATCH, "/api/embeddings/providers/:providerId", [auth.checkApiAuth, csrfMiddleware], embeddingsRoute.updateProvider, apiResultHandler); + route(PST, "/api/embeddings/reprocess", [auth.checkApiAuth, csrfMiddleware], embeddingsRoute.reprocessAllNotes, apiResultHandler); + route(GET, "/api/embeddings/queue-status", [auth.checkApiAuth], embeddingsRoute.getQueueStatus, apiResultHandler); + // API Documentation apiDocsRoute.register(app);