using Rag.Data;
using Rag.Data.Entities;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Logging;
using Rag.Data.Repositories.Contracts;
using Rag.Models;

namespace Rag.Data.Repositories;

/// <summary>
/// Entity Framework Core implementation of <see cref="IRagRepository"/>. Stores RAG documents with
/// their embedded chunks, and caches embeddings and chat completions, all through <see cref="RagDbContext"/>.
/// </summary>
public sealed class EfRagRepository : IRagRepository
{
    private readonly RagDbContext _db;
    private readonly ILogger<EfRagRepository> _logger;

    /// <summary>Creates a repository backed by the given EF Core context.</summary>
    /// <param name="db">Context used for every query and insert.</param>
    /// <param name="logger">Logger for informational messages.</param>
    public EfRagRepository(RagDbContext db, ILogger<EfRagRepository> logger)
    {
        _db = db;
        _logger = logger;
    }

    /// <summary>
    /// Startup hook. It does not create or migrate the schema; it only logs that the schema is assumed to exist.
    /// </summary>
    /// <param name="ct">Cancellation token (unused; nothing asynchronous happens here).</param>
    /// <returns>An already-completed task.</returns>
    public Task InitializeAsync(CancellationToken ct)
    {
        // Database.EnsureCreatedAsync is not called (it was commented out previously).
        // NOTE(review): presumably the schema is provisioned elsewhere (e.g. migrations) — confirm.
        _logger.LogInformation("EF Core schema creation is disabled; assuming the RAG database schema already exists");
        return Task.CompletedTask;
    }

    /// <summary>
    /// Finds the most recently created document whose text hash equals <paramref name="textHash"/>.
    /// </summary>
    /// <param name="textHash">Hash of the document text to look up.</param>
    /// <param name="sourceUrl">When not null/whitespace, only documents with exactly this source URL match.</param>
    /// <param name="ct">Cancellation token.</param>
    /// <returns>The newest matching document, or <c>null</c> when none exists.</returns>
    public async Task<RagDocumentRecord?> GetDocumentByTextHashAsync(string textHash, string? sourceUrl, CancellationToken ct)
    {
        var query = _db.RagDocuments
            .AsNoTracking()
            .Where(x => x.TextHash == textHash);

        if (!string.IsNullOrWhiteSpace(sourceUrl))
        {
            query = query.Where(x => x.SourceUrl == sourceUrl);
        }

        var entity = await query
            .OrderByDescending(x => x.CreatedAt)
            .FirstOrDefaultAsync(ct);

        return entity is null ? null : ToRecord(entity);
    }

    /// <summary>Loads a document by its identifier.</summary>
    /// <param name="id">Document identifier.</param>
    /// <param name="ct">Cancellation token.</param>
    /// <returns>The document, or <c>null</c> when no document has this id.</returns>
    public async Task<RagDocumentRecord?> GetDocumentByIdAsync(string id, CancellationToken ct)
    {
        var entity = await _db.RagDocuments
            .AsNoTracking()
            .FirstOrDefaultAsync(x => x.Id == id, ct);

        return entity is null ? null : ToRecord(entity);
    }

    /// <summary>
    /// Inserts a document together with its chunks in a single <c>SaveChangesAsync</c> call.
    /// Does nothing (apart from logging) when a document with the same id already exists.
    /// </summary>
    /// <param name="document">Document to persist.</param>
    /// <param name="chunks">Chunks belonging to the document; embeddings are serialized to bytes.</param>
    /// <param name="ct">Cancellation token.</param>
    /// <exception cref="ArgumentNullException"><paramref name="document"/> or <paramref name="chunks"/> is null.</exception>
    public async Task SaveDocumentAsync(RagDocumentRecord document, IReadOnlyList<RagChunkRecord> chunks, CancellationToken ct)
    {
        ArgumentNullException.ThrowIfNull(document);
        ArgumentNullException.ThrowIfNull(chunks);

        // NOTE(review): the existence check and the insert are separate round trips (not atomic);
        // concurrent callers with the same id could both pass the check — confirm callers cannot race.
        var exists = await _db.RagDocuments.AnyAsync(x => x.Id == document.Id, ct);
        if (exists)
        {
            _logger.LogInformation("RAG document already exists. DocumentId={DocumentId}", document.Id);
            return;
        }

        var entity = new RagDocumentEntity
        {
            Id = document.Id,
            DocumentType = document.DocumentType,
            Title = document.Title,
            SourceUrl = document.SourceUrl,
            RawText = document.Text,
            TextHash = document.TextHash,
            TypeConfidence = document.TypeConfidence,
            MetadataJson = document.MetadataJson,
            CreatedAt = document.CreatedAt.UtcDateTime,
            Chunks = chunks.Select(chunk => new RagChunkEntity
            {
                Id = chunk.Id,
                DocumentId = chunk.DocumentId,
                ChunkIndex = chunk.ChunkIndex,
                Text = chunk.Text,
                Embedding = VectorSerializer.ToBytes(chunk.Embedding)
            }).ToList()
        };

        _db.RagDocuments.Add(entity);
        await _db.SaveChangesAsync(ct);
    }

    /// <summary>
    /// Scores stored chunks against <paramref name="queryEmbedding"/> by cosine similarity and returns
    /// the best candidates, highest score first.
    /// </summary>
    /// <param name="queryEmbedding">Embedding of the query.</param>
    /// <param name="targetTypes">
    /// Optional document-type filter. Blank entries are ignored, and matching is case-insensitive.
    /// A null or effectively empty list means no filter.
    /// </param>
    /// <param name="topK">Requested result size; up to <c>4 * topK</c> candidates are returned.</param>
    /// <param name="ct">Cancellation token.</param>
    /// <returns>Candidate chunks with their documents and scores; empty when <paramref name="topK"/> &lt;= 0.</returns>
    public async Task<IReadOnlyList<SearchCandidateChunk>> SearchChunksAsync(
        float[] queryEmbedding,
        IReadOnlyList<string>? targetTypes,
        int topK,
        CancellationToken ct)
    {
        // A non-positive topK always produced an empty result; skip loading every chunk in that case.
        if (topK <= 0)
        {
            return System.Array.Empty<SearchCandidateChunk>();
        }

        // Over-fetch 4x topK, saturating instead of overflowing for very large topK.
        // NOTE(review): presumably so callers can rerank/dedupe before trimming to topK — confirm.
        var candidateCount = topK > int.MaxValue / 4 ? int.MaxValue : topK * 4;

        var types = targetTypes?
            .Where(x => !string.IsNullOrWhiteSpace(x))
            .Select(x => x.Trim().ToLowerInvariant())
            .Distinct()
            .ToArray() ?? System.Array.Empty<string>();

        var query = _db.RagChunks
            .AsNoTracking()
            .Include(x => x.Document)
            .AsQueryable();

        if (types.Length > 0)
        {
            // Stored types are lowered in the query; requested types were lowered above.
            query = query.Where(x => x.Document != null && types.Contains(x.Document.DocumentType.ToLower()));
        }

        // Brute-force search: all matching chunks are loaded and scored in memory.
        var rows = await query.ToListAsync(ct);

        return rows
            .Where(x => x.Document is not null)
            .Select(x =>
            {
                // Deserialize once and reuse the vector for both the record and the score.
                // Assumes CosineSimilarity does not modify its arguments.
                var embedding = VectorSerializer.FromBytes(x.Embedding);
                return new SearchCandidateChunk
                {
                    Document = ToRecord(x.Document!),
                    Chunk = new RagChunkRecord
                    {
                        Id = x.Id,
                        DocumentId = x.DocumentId,
                        ChunkIndex = x.ChunkIndex,
                        Text = x.Text,
                        Embedding = embedding
                    },
                    Score = VectorSerializer.CosineSimilarity(queryEmbedding, embedding)
                };
            })
            .OrderByDescending(x => x.Score)
            .Take(candidateCount)
            .ToList();
    }

    /// <summary>Looks up a cached embedding vector.</summary>
    /// <param name="cacheKey">Cache key of the embedding.</param>
    /// <param name="ct">Cancellation token.</param>
    /// <returns>The cached vector, or <c>null</c> on a cache miss.</returns>
    public async Task<float[]?> GetEmbeddingAsync(string cacheKey, CancellationToken ct)
    {
        var entry = await _db.RagEmbeddingCache
            .AsNoTracking()
            .FirstOrDefaultAsync(x => x.CacheKey == cacheKey, ct);

        return entry is null ? null : VectorSerializer.FromBytes(entry.Vector);
    }

    /// <summary>
    /// Stores an embedding in the cache unless an entry with the same key already exists.
    /// The existence check and the insert are separate round trips (not atomic).
    /// </summary>
    /// <param name="cacheKey">Cache key of the embedding.</param>
    /// <param name="model">Model name stored alongside the vector.</param>
    /// <param name="textHash">Hash of the embedded text.</param>
    /// <param name="vector">Embedding vector; serialized to bytes.</param>
    /// <param name="ct">Cancellation token.</param>
    public async Task SaveEmbeddingAsync(string cacheKey, string model, string textHash, float[] vector, CancellationToken ct)
    {
        var exists = await _db.RagEmbeddingCache.AnyAsync(x => x.CacheKey == cacheKey, ct);
        if (exists)
        {
            return;
        }

        _db.RagEmbeddingCache.Add(new RagEmbeddingCacheEntity
        {
            CacheKey = cacheKey,
            Model = model,
            TextHash = textHash,
            Vector = VectorSerializer.ToBytes(vector),
            CreatedAt = DateTime.UtcNow
        });

        await _db.SaveChangesAsync(ct);
    }

    /// <summary>Looks up a cached chat-completion response text.</summary>
    /// <param name="cacheKey">Cache key of the completion.</param>
    /// <param name="ct">Cancellation token.</param>
    /// <returns>The cached response text, or <c>null</c> on a cache miss.</returns>
    public async Task<string?> GetChatCompletionAsync(string cacheKey, CancellationToken ct)
    {
        return await _db.RagChatCompletionCache
            .AsNoTracking()
            .Where(x => x.CacheKey == cacheKey)
            .Select(x => x.ResponseText)
            .FirstOrDefaultAsync(ct);
    }

    /// <summary>
    /// Stores a chat-completion response unless an entry with the same key already exists.
    /// The existence check and the insert are separate round trips (not atomic).
    /// </summary>
    /// <param name="cacheKey">Cache key of the completion.</param>
    /// <param name="model">Model name stored with the response.</param>
    /// <param name="temperature">Sampling temperature stored with the response.</param>
    /// <param name="responseText">Response text to cache.</param>
    /// <param name="ct">Cancellation token.</param>
    public async Task SaveChatCompletionAsync(string cacheKey, string model, decimal temperature, string responseText, CancellationToken ct)
    {
        var exists = await _db.RagChatCompletionCache.AnyAsync(x => x.CacheKey == cacheKey, ct);
        if (exists)
        {
            return;
        }

        _db.RagChatCompletionCache.Add(new RagChatCompletionCacheEntity
        {
            CacheKey = cacheKey,
            Model = model,
            Temperature = temperature,
            ResponseText = responseText,
            CreatedAt = DateTime.UtcNow
        });

        await _db.SaveChangesAsync(ct);
    }

    /// <summary>
    /// Maps a document entity to its record form. The stored <c>CreatedAt</c> is treated as UTC.
    /// </summary>
    private static RagDocumentRecord ToRecord(RagDocumentEntity entity) => new()
    {
        Id = entity.Id,
        DocumentType = entity.DocumentType,
        Title = entity.Title,
        SourceUrl = entity.SourceUrl,
        Text = entity.RawText,
        TextHash = entity.TextHash,
        TypeConfidence = entity.TypeConfidence,
        MetadataJson = entity.MetadataJson,
        // Stored DateTime carries no kind information; force UTC before wrapping in DateTimeOffset.
        CreatedAt = new DateTimeOffset(DateTime.SpecifyKind(entity.CreatedAt, DateTimeKind.Utc))
    };
}