mirror of
https://github.com/zadam/trilium.git
synced 2025-11-01 03:59:05 +01:00
set up embedding API endpoints
This commit is contained in:
parent
c442943672
commit
1361e4d438
233
src/routes/api/embeddings.ts
Normal file
233
src/routes/api/embeddings.ts
Normal file
@ -0,0 +1,233 @@
|
||||
import options from "../../services/options.js";
|
||||
import vectorStore from "../../services/llm/embeddings/vector_store.js";
|
||||
import providerManager from "../../services/llm/embeddings/providers.js";
|
||||
import becca from "../../becca/becca.js";
|
||||
import type { Request, Response } from "express";
|
||||
|
||||
/**
|
||||
* Get similar notes based on note ID
|
||||
*/
|
||||
async function findSimilarNotes(req: Request, res: Response) {
|
||||
const noteId = req.params.noteId;
|
||||
const providerId = req.query.providerId as string || 'openai';
|
||||
const modelId = req.query.modelId as string || 'text-embedding-3-small';
|
||||
const limit = parseInt(req.query.limit as string || '10', 10);
|
||||
const threshold = parseFloat(req.query.threshold as string || '0.7');
|
||||
|
||||
if (!noteId) {
|
||||
return res.status(400).send({
|
||||
success: false,
|
||||
message: "Note ID is required"
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
const embedding = await vectorStore.getEmbeddingForNote(noteId, providerId, modelId);
|
||||
|
||||
if (!embedding) {
|
||||
// If no embedding exists for this note yet, generate one
|
||||
const note = becca.getNote(noteId);
|
||||
if (!note) {
|
||||
return res.status(404).send({
|
||||
success: false,
|
||||
message: "Note not found"
|
||||
});
|
||||
}
|
||||
|
||||
const context = await vectorStore.getNoteEmbeddingContext(noteId);
|
||||
const provider = providerManager.getEmbeddingProvider(providerId);
|
||||
|
||||
if (!provider) {
|
||||
return res.status(400).send({
|
||||
success: false,
|
||||
message: `Embedding provider '${providerId}' not found`
|
||||
});
|
||||
}
|
||||
|
||||
const newEmbedding = await provider.generateNoteEmbeddings(context);
|
||||
await vectorStore.storeNoteEmbedding(noteId, providerId, modelId, newEmbedding);
|
||||
|
||||
const similarNotes = await vectorStore.findSimilarNotes(
|
||||
newEmbedding, providerId, modelId, limit, threshold
|
||||
);
|
||||
|
||||
return res.send({
|
||||
success: true,
|
||||
similarNotes
|
||||
});
|
||||
}
|
||||
|
||||
const similarNotes = await vectorStore.findSimilarNotes(
|
||||
embedding.embedding, providerId, modelId, limit, threshold
|
||||
);
|
||||
|
||||
return res.send({
|
||||
success: true,
|
||||
similarNotes
|
||||
});
|
||||
} catch (error: any) {
|
||||
return res.status(500).send({
|
||||
success: false,
|
||||
message: error.message || "Unknown error"
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Search notes by text
|
||||
*/
|
||||
async function searchByText(req: Request, res: Response) {
|
||||
const { text } = req.body;
|
||||
const providerId = req.query.providerId as string || 'openai';
|
||||
const modelId = req.query.modelId as string || 'text-embedding-3-small';
|
||||
const limit = parseInt(req.query.limit as string || '10', 10);
|
||||
const threshold = parseFloat(req.query.threshold as string || '0.7');
|
||||
|
||||
if (!text) {
|
||||
return res.status(400).send({
|
||||
success: false,
|
||||
message: "Search text is required"
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
const provider = providerManager.getEmbeddingProvider(providerId);
|
||||
|
||||
if (!provider) {
|
||||
return res.status(400).send({
|
||||
success: false,
|
||||
message: `Embedding provider '${providerId}' not found`
|
||||
});
|
||||
}
|
||||
|
||||
// Generate embedding for the search text
|
||||
const embedding = await provider.generateEmbeddings(text);
|
||||
|
||||
// Find similar notes
|
||||
const similarNotes = await vectorStore.findSimilarNotes(
|
||||
embedding, providerId, modelId, limit, threshold
|
||||
);
|
||||
|
||||
return res.send({
|
||||
success: true,
|
||||
similarNotes
|
||||
});
|
||||
} catch (error: any) {
|
||||
return res.status(500).send({
|
||||
success: false,
|
||||
message: error.message || "Unknown error"
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get embedding providers
|
||||
*/
|
||||
async function getProviders(req: Request, res: Response) {
|
||||
try {
|
||||
const providerConfigs = await providerManager.getEmbeddingProviderConfigs();
|
||||
return res.send({
|
||||
success: true,
|
||||
providers: providerConfigs
|
||||
});
|
||||
} catch (error: any) {
|
||||
return res.status(500).send({
|
||||
success: false,
|
||||
message: error.message || "Unknown error"
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update provider configuration
|
||||
*/
|
||||
async function updateProvider(req: Request, res: Response) {
|
||||
const { providerId } = req.params;
|
||||
const { isEnabled, priority, config } = req.body;
|
||||
|
||||
try {
|
||||
const success = await providerManager.updateEmbeddingProviderConfig(
|
||||
providerId, isEnabled, priority, config
|
||||
);
|
||||
|
||||
if (!success) {
|
||||
return res.status(404).send({
|
||||
success: false,
|
||||
message: "Provider not found"
|
||||
});
|
||||
}
|
||||
|
||||
return res.send({
|
||||
success: true
|
||||
});
|
||||
} catch (error: any) {
|
||||
return res.status(500).send({
|
||||
success: false,
|
||||
message: error.message || "Unknown error"
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Manually trigger a reprocessing of all notes
|
||||
*/
|
||||
async function reprocessAllNotes(req: Request, res: Response) {
|
||||
try {
|
||||
await vectorStore.reprocessAllNotes();
|
||||
|
||||
return res.send({
|
||||
success: true,
|
||||
message: "Notes queued for reprocessing"
|
||||
});
|
||||
} catch (error: any) {
|
||||
return res.status(500).send({
|
||||
success: false,
|
||||
message: error.message || "Unknown error"
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get embedding queue status
|
||||
*/
|
||||
async function getQueueStatus(req: Request, res: Response) {
|
||||
try {
|
||||
// Use sql directly instead of becca.sqliteDB
|
||||
const sql = require("../../services/sql.js").default;
|
||||
|
||||
const queueCount = await sql.getValue(
|
||||
"SELECT COUNT(*) FROM embedding_queue"
|
||||
);
|
||||
|
||||
const failedCount = await sql.getValue(
|
||||
"SELECT COUNT(*) FROM embedding_queue WHERE attempts > 0"
|
||||
);
|
||||
|
||||
const totalEmbeddingsCount = await sql.getValue(
|
||||
"SELECT COUNT(*) FROM note_embeddings"
|
||||
);
|
||||
|
||||
return res.send({
|
||||
success: true,
|
||||
status: {
|
||||
queueCount,
|
||||
failedCount,
|
||||
totalEmbeddingsCount
|
||||
}
|
||||
});
|
||||
} catch (error: any) {
|
||||
return res.status(500).send({
|
||||
success: false,
|
||||
message: error.message || "Unknown error"
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export default {
|
||||
findSimilarNotes,
|
||||
searchByText,
|
||||
getProviders,
|
||||
updateProvider,
|
||||
reprocessAllNotes,
|
||||
getQueueStatus
|
||||
};
|
||||
@ -60,6 +60,7 @@ import etapiTokensApiRoutes from "./api/etapi_tokens.js";
|
||||
import relationMapApiRoute from "./api/relation-map.js";
|
||||
import otherRoute from "./api/other.js";
|
||||
import shareRoutes from "../share/routes.js";
|
||||
import embeddingsRoute from "./api/embeddings.js";
|
||||
|
||||
import etapiAuthRoutes from "../etapi/auth.js";
|
||||
import etapiAppInfoRoutes from "../etapi/app_info.js";
|
||||
@ -369,6 +370,14 @@ function register(app: express.Application) {
|
||||
etapiSpecRoute.register(router);
|
||||
etapiBackupRoute.register(router);
|
||||
|
||||
// Embeddings API endpoints
|
||||
route(GET, "/api/embeddings/similar/:noteId", [auth.checkApiAuth], embeddingsRoute.findSimilarNotes, apiResultHandler);
|
||||
route(PST, "/api/embeddings/search", [auth.checkApiAuth, csrfMiddleware], embeddingsRoute.searchByText, apiResultHandler);
|
||||
route(GET, "/api/embeddings/providers", [auth.checkApiAuth], embeddingsRoute.getProviders, apiResultHandler);
|
||||
route(PATCH, "/api/embeddings/providers/:providerId", [auth.checkApiAuth, csrfMiddleware], embeddingsRoute.updateProvider, apiResultHandler);
|
||||
route(PST, "/api/embeddings/reprocess", [auth.checkApiAuth, csrfMiddleware], embeddingsRoute.reprocessAllNotes, apiResultHandler);
|
||||
route(GET, "/api/embeddings/queue-status", [auth.checkApiAuth], embeddingsRoute.getQueueStatus, apiResultHandler);
|
||||
|
||||
// API Documentation
|
||||
apiDocsRoute.register(app);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user