diff --git a/lib/ai/ollama.js b/lib/ai/ollama.js index fd5dcf4..f485587 100644 --- a/lib/ai/ollama.js +++ b/lib/ai/ollama.js @@ -22,3 +22,40 @@ export function padTo(vector, dim) { while (out.length < dim) out.push(0); return out; } + +// Split text into chunks of at most `size` chars, breaking on line boundaries +// where possible (never mid-word-loss): accumulate lines until adding the next +// would exceed `size`. A single over-long line is hard-split. Returns [] for empty. +export function chunkText(text, size = 1500) { + const s = (text || '').trim(); + if (!s) return []; + const chunks = []; + let cur = ''; + for (const line of s.split('\n')) { + if (line.length > size) { + if (cur) { chunks.push(cur); cur = ''; } + for (let i = 0; i < line.length; i += size) chunks.push(line.slice(i, i + size)); + continue; + } + if (cur.length + line.length + 1 > size) { if (cur) chunks.push(cur); cur = line; } + else { cur = cur ? cur + '\n' + line : line; } + } + if (cur) chunks.push(cur); + return chunks; +} + +// Embed possibly-long text by chunking, embedding each chunk, and mean-pooling +// the resulting vectors element-wise. Returns a single embedding vector. +// 1 chunk => identical to embedText. Caps the number of chunks to bound cost. +export async function embedTextPooled(text, { model = 'nomic-embed-text', timeoutMs = 60_000, maxChunks = 64, chunkSize = 1500 } = {}) { + let chunks = chunkText(text, chunkSize); + if (chunks.length === 0) chunks = ['']; + if (chunks.length > maxChunks) chunks = chunks.slice(0, maxChunks); + const vecs = []; + for (const c of chunks) vecs.push(await embedText(c, { model, timeoutMs })); + const dim = vecs[0].length; + const pooled = new Array(dim).fill(0); + for (const v of vecs) for (let i = 0; i < dim; i++) pooled[i] += (v[i] || 0); + for (let i = 0; i < dim; i++) pooled[i] /= vecs.length; + return pooled; +} diff --git a/lib/jobs/workers/embed.js b/lib/jobs/workers/embed.js index 3f596b0..2ce415b 100644 --- a/lib/jobs/workers/embed.js +++ b/lib/jobs/workers/embed.js @@ -1,4 +1,4 @@ -import { embedText, padTo } from '../../ai/ollama.js'; +import { embedTextPooled, padTo } from '../../ai/ollama.js'; import { pool } from '../../db/pool.js'; import { recordAudit } from '../../db/repos/audit.js'; @@ -19,8 +19,8 @@ export async function handler(job) { if (!table) throw new Error(`unknown entity_type: ${entity_type}`); const { rows: [row] } = await pool.query(`SELECT * FROM ${table} WHERE id=$1`, [entity_id]); if (!row) return { skipped: 'gone' }; - const text = STRING_BUILDERS[entity_type](row).slice(0, 6_000); - const v = await embedText(text); + const text = STRING_BUILDERS[entity_type](row); + const v = await embedTextPooled(text); const padded = padTo(v, 1024); const literal = '[' + padded.join(',') + ']'; await pool.query(`UPDATE ${table} SET embedding=$1::vector WHERE id=$2`, [literal, entity_id]); diff --git a/tests/ai/embed_chunking.test.js b/tests/ai/embed_chunking.test.js new file mode 100644 index 0000000..2069cdd --- /dev/null +++ b/tests/ai/embed_chunking.test.js @@ -0,0 +1,36 @@ +import { describe, it, expect, vi, afterEach } from 'vitest'; +import { chunkText, embedTextPooled } from '../../lib/ai/ollama.js'; + +afterEach(() => { vi.unstubAllGlobals(); }); + +describe('chunkText', () => { + it('returns [] for empty', () => { expect(chunkText('')).toEqual([]); }); + it('keeps short text as one chunk', () => { expect(chunkText('hello\nworld', 1500)).toEqual(['hello\nworld']); }); + it('splits long text into <=size chunks covering all chars', () => { + const text = Array.from({length: 50}, (_,i)=>`line ${i} ${'x'.repeat(40)}`).join('\n'); + const chunks = chunkText(text, 200); + expect(chunks.length).toBeGreaterThan(1); + for (const c of chunks) expect(c.length).toBeLessThanOrEqual(200); + }); + it('hard-splits a single over-long line', () => { + const chunks = chunkText('y'.repeat(500), 100); + expect(chunks.length).toBe(5); + expect(chunks.every(c => c.length <= 100)).toBe(true); + }); +}); + +describe('embedTextPooled', () => { + it('mean-pools chunk vectors', async () => { + // two chunks (size 5 forces split), fetch returns embedding = [callCount, callCount] + let n = 0; + vi.stubGlobal('fetch', vi.fn(async () => { n++; return { ok: true, json: async () => ({ embedding: [n, n] }) }; })); + const v = await embedTextPooled('aaaaa\nbbbbb', { chunkSize: 5 }); + // chunks: ['aaaaa','bbbbb'] -> vectors [1,1],[2,2] -> mean [1.5,1.5] + expect(v).toEqual([1.5, 1.5]); + }); + it('single chunk equals single embed', async () => { + vi.stubGlobal('fetch', vi.fn(async () => ({ ok: true, json: async () => ({ embedding: [7, 8, 9] }) }))); + const v = await embedTextPooled('short', { chunkSize: 1500 }); + expect(v).toEqual([7, 8, 9]); + }); +});