using Microsoft.Data.SqlClient;
using Api.Services.Contracts;
using Api.Services.Contracts.Models;

namespace Api.Services;

/// <summary>
/// SQL Server-backed <see cref="IRagRepository"/> that stores RAG documents, their embedded chunks,
/// and embedding / chat-completion caches using plain ADO.NET commands.
/// </summary>
/// <remarks>
/// Every operation opens its own <see cref="SqlConnection"/> from the "RagDb" connection string.
/// Vector similarity for search is computed in process, not in the database.
/// </remarks>
public sealed class SqlRagRepository : IRagRepository
{
    /// <summary>SQL Server database names (sysname) are limited to 128 characters.</summary>
    private const int MaxDatabaseNameLength = 128;

    private readonly string _connectionString;

    /// <summary>Creates the repository from the "RagDb" connection string.</summary>
    /// <param name="configuration">Application configuration providing the "RagDb" connection string.</param>
    /// <exception cref="InvalidOperationException">The connection string is missing or blank.</exception>
    public SqlRagRepository(IConfiguration configuration)
    {
        var connectionString = configuration.GetConnectionString("RagDb");

        // Reject blank values too. Otherwise the first OpenAsync fails later with a less specific error.
        _connectionString = string.IsNullOrWhiteSpace(connectionString)
            ? throw new InvalidOperationException("Connection string 'RagDb' is missing.")
            : connectionString;
    }

    /// <summary>
    /// Creates the target database if it does not exist. Then runs <c>Database/schema.sql</c>
    /// (relative to the application base directory) one batch at a time.
    /// </summary>
    /// <param name="ct">Cancellation token for file and database I/O.</param>
    public async Task InitializeAsync(CancellationToken ct)
    {
        await EnsureDatabaseExistsAsync(ct);

        var script = await File.ReadAllTextAsync(
            Path.Combine(AppContext.BaseDirectory, "Database", "schema.sql"), ct);

        await using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync(ct);

        // GO is a client-side batch separator, not T-SQL, so each batch is sent as its own command.
        foreach (var batch in SplitSqlBatches(script))
        {
            await using var command = new SqlCommand(batch, connection);
            await command.ExecuteNonQueryAsync(ct);
        }
    }

    /// <summary>
    /// Returns the most recently created document with the given text hash, or <see langword="null"/>.
    /// </summary>
    /// <param name="textHash">Hash of the document text to match exactly.</param>
    /// <param name="sourceUrl">
    /// When non-null, only documents with exactly this <c>SourceUrl</c> match.
    /// When <see langword="null"/>, documents from any source match.
    /// </param>
    /// <param name="ct">Cancellation token.</param>
    public async Task<RagDocumentRecord?> GetDocumentByTextHashAsync(string textHash, string? sourceUrl, CancellationToken ct)
    {
        const string sql = """
            SELECT TOP 1 Id, DocumentType, Title, SourceUrl, RawText, TextHash, TypeConfidence, MetadataJson, CreatedAt
            FROM RagDocuments
            WHERE TextHash = @TextHash AND (@SourceUrl IS NULL OR SourceUrl = @SourceUrl)
            ORDER BY CreatedAt DESC
            """;

        await using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync(ct);
        await using var command = new SqlCommand(sql, connection);
        command.Parameters.AddWithValue("@TextHash", textHash);
        command.Parameters.AddWithValue("@SourceUrl", (object?)sourceUrl ?? DBNull.Value);

        await using var reader = await command.ExecuteReaderAsync(ct);
        return await reader.ReadAsync(ct) ? ReadDocument(reader) : null;
    }

    /// <summary>Returns the document with the given id, or <see langword="null"/> if none exists.</summary>
    /// <param name="id">Document id.</param>
    /// <param name="ct">Cancellation token.</param>
    public async Task<RagDocumentRecord?> GetDocumentByIdAsync(string id, CancellationToken ct)
    {
        const string sql = """
            SELECT Id, DocumentType, Title, SourceUrl, RawText, TextHash, TypeConfidence, MetadataJson, CreatedAt
            FROM RagDocuments
            WHERE Id = @Id
            """;

        await using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync(ct);
        await using var command = new SqlCommand(sql, connection);
        command.Parameters.AddWithValue("@Id", id);

        await using var reader = await command.ExecuteReaderAsync(ct);
        return await reader.ReadAsync(ct) ? ReadDocument(reader) : null;
    }

    /// <summary>
    /// Inserts a document and all of its chunks in a single transaction.
    /// Either everything is stored or nothing is.
    /// </summary>
    /// <param name="document">
    /// Document to insert. <c>CreatedAt</c> is stored as its UTC <see cref="DateTime"/>.
    /// </param>
    /// <param name="chunks">
    /// Chunks to insert. Each row's <c>DocumentId</c> is taken from <paramref name="document"/>,
    /// not from the chunk.
    /// </param>
    /// <param name="ct">Cancellation token for the inserts and the commit.</param>
    public async Task SaveDocumentAsync(RagDocumentRecord document, IReadOnlyList<RagChunkRecord> chunks, CancellationToken ct)
    {
        await using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync(ct);
        await using var tx = (SqlTransaction)await connection.BeginTransactionAsync(ct);

        try
        {
            const string insertDoc = """
                INSERT INTO RagDocuments (Id, DocumentType, Title, SourceUrl, RawText, TextHash, TypeConfidence, MetadataJson, CreatedAt)
                VALUES (@Id, @DocumentType, @Title, @SourceUrl, @RawText, @TextHash, @TypeConfidence, @MetadataJson, @CreatedAt)
                """;

            await using (var command = new SqlCommand(insertDoc, connection, tx))
            {
                command.Parameters.AddWithValue("@Id", document.Id);
                command.Parameters.AddWithValue("@DocumentType", document.DocumentType);
                command.Parameters.AddWithValue("@Title", document.Title);
                command.Parameters.AddWithValue("@SourceUrl", (object?)document.SourceUrl ?? DBNull.Value);
                command.Parameters.AddWithValue("@RawText", document.Text);
                command.Parameters.AddWithValue("@TextHash", document.TextHash);
                command.Parameters.AddWithValue("@TypeConfidence", document.TypeConfidence);
                command.Parameters.AddWithValue("@MetadataJson", document.MetadataJson);
                command.Parameters.AddWithValue("@CreatedAt", document.CreatedAt.UtcDateTime);
                await command.ExecuteNonQueryAsync(ct);
            }

            const string insertChunk = """
                INSERT INTO RagChunks (Id, DocumentId, ChunkIndex, Text, Embedding)
                VALUES (@Id, @DocumentId, @ChunkIndex, @Text, @Embedding)
                """;

            foreach (var chunk in chunks)
            {
                await using var command = new SqlCommand(insertChunk, connection, tx);
                command.Parameters.AddWithValue("@Id", chunk.Id);
                command.Parameters.AddWithValue("@DocumentId", document.Id);
                command.Parameters.AddWithValue("@ChunkIndex", chunk.ChunkIndex);
                command.Parameters.AddWithValue("@Text", chunk.Text);
                command.Parameters.AddWithValue("@Embedding", VectorSerializer.ToBytes(chunk.Embedding));
                await command.ExecuteNonQueryAsync(ct);
            }

            await tx.CommitAsync(ct);
        }
        catch
        {
            // Roll back without the caller's token. If `ct` caused the failure, passing it here would
            // cancel the rollback immediately and replace the original exception.
            try
            {
                await tx.RollbackAsync(CancellationToken.None);
            }
            catch (Exception rollbackError) when (rollbackError is SqlException or InvalidOperationException)
            {
                // The transaction may already be doomed or the connection broken. Rethrow the original
                // exception below; it explains the failure better than the rollback error.
            }

            throw;
        }
    }

    /// <summary>
    /// Scores stored chunks by cosine similarity to <paramref name="queryEmbedding"/>
    /// and returns the highest-scoring candidates, best first.
    /// </summary>
    /// <param name="queryEmbedding">Embedding of the query text.</param>
    /// <param name="targetTypes">
    /// Document types to include, matched after trimming and lower-casing both sides.
    /// Blank entries are ignored. <see langword="null"/> or empty means all types.
    /// </param>
    /// <param name="topK">
    /// Base result size. Up to <c>topK * 4</c> candidates are returned, presumably so callers can
    /// re-rank or de-duplicate before keeping <paramref name="topK"/> (confirm with callers).
    /// </param>
    /// <param name="ct">Cancellation token.</param>
    /// <remarks>
    /// Every matching chunk row, including its embedding, is loaded and scored in memory.
    /// Cost therefore grows linearly with the number of stored chunks.
    /// </remarks>
    public async Task<IReadOnlyList<SearchCandidateChunk>> SearchChunksAsync(
        float[] queryEmbedding, IReadOnlyList<string>? targetTypes, int topK, CancellationToken ct)
    {
        var types = targetTypes?
            .Where(x => !string.IsNullOrWhiteSpace(x))
            .Select(x => x.Trim().ToLowerInvariant())
            .Distinct()
            .ToArray() ?? [];

        var sql = """
            SELECT d.Id, d.DocumentType, d.Title, d.SourceUrl, d.RawText, d.TextHash, d.TypeConfidence, d.MetadataJson, d.CreatedAt,
                   c.Id, c.DocumentId, c.ChunkIndex, c.Text, c.Embedding
            FROM RagChunks c
            INNER JOIN RagDocuments d ON d.Id = c.DocumentId
            """;

        if (types.Length > 0)
        {
            // Only generated parameter names (@Type0, @Type1, ...) are concatenated; values stay parameterized.
            sql += " WHERE LOWER(d.DocumentType) IN (" + string.Join(',', types.Select((_, i) => $"@Type{i}")) + ")";
        }

        await using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync(ct);
        await using var command = new SqlCommand(sql, connection);
        for (var i = 0; i < types.Length; i++)
        {
            command.Parameters.AddWithValue($"@Type{i}", types[i]);
        }

        await using var reader = await command.ExecuteReaderAsync(ct);
        var candidates = new List<SearchCandidateChunk>();
        while (await reader.ReadAsync(ct))
        {
            // Columns 0-8 are the document and columns 9-13 the chunk, in SELECT order.
            var doc = ReadDocument(reader, 0);
            var chunk = new RagChunkRecord
            {
                Id = reader.GetString(9),
                DocumentId = reader.GetString(10),
                ChunkIndex = reader.GetInt32(11),
                Text = reader.GetString(12),
                Embedding = VectorSerializer.FromBytes((byte[])reader[13])
            };

            candidates.Add(new SearchCandidateChunk
            {
                Document = doc,
                Chunk = chunk,
                Score = VectorSerializer.CosineSimilarity(queryEmbedding, chunk.Embedding)
            });
        }

        return candidates
            .OrderByDescending(x => x.Score)
            .Take(Math.Max(topK * 4, topK))
            .ToList();
    }

    /// <summary>Returns the cached embedding for <paramref name="cacheKey"/>, or <see langword="null"/> if not cached.</summary>
    /// <param name="cacheKey">Cache key to look up.</param>
    /// <param name="ct">Cancellation token.</param>
    public async Task<float[]?> GetEmbeddingAsync(string cacheKey, CancellationToken ct)
    {
        const string sql = "SELECT Vector FROM RagEmbeddingCache WHERE CacheKey = @CacheKey";

        await using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync(ct);
        await using var command = new SqlCommand(sql, connection);
        command.Parameters.AddWithValue("@CacheKey", cacheKey);

        var value = await command.ExecuteScalarAsync(ct);
        return value is byte[] bytes ? VectorSerializer.FromBytes(bytes) : null;
    }

    /// <summary>
    /// Stores an embedding under <paramref name="cacheKey"/> unless an entry with that key already exists.
    /// An existing entry is left unchanged.
    /// </summary>
    /// <param name="cacheKey">Cache key.</param>
    /// <param name="model">Model name stored alongside the vector.</param>
    /// <param name="textHash">Hash of the embedded text.</param>
    /// <param name="vector">Embedding to store.</param>
    /// <param name="ct">Cancellation token.</param>
    public async Task SaveEmbeddingAsync(string cacheKey, string model, string textHash, float[] vector, CancellationToken ct)
    {
        // One statement with UPDLOCK/HOLDLOCK closes the check-then-insert race that a separate
        // IF NOT EXISTS + INSERT has when two writers store the same key concurrently.
        const string sql = """
            INSERT INTO RagEmbeddingCache (CacheKey, Model, TextHash, Vector, CreatedAt)
            SELECT @CacheKey, @Model, @TextHash, @Vector, SYSUTCDATETIME()
            WHERE NOT EXISTS (
                SELECT 1 FROM RagEmbeddingCache WITH (UPDLOCK, HOLDLOCK) WHERE CacheKey = @CacheKey);
            """;

        await using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync(ct);
        await using var command = new SqlCommand(sql, connection);
        command.Parameters.AddWithValue("@CacheKey", cacheKey);
        command.Parameters.AddWithValue("@Model", model);
        command.Parameters.AddWithValue("@TextHash", textHash);
        command.Parameters.AddWithValue("@Vector", VectorSerializer.ToBytes(vector));
        await command.ExecuteNonQueryAsync(ct);
    }

    /// <summary>
    /// Returns the cached chat-completion text for <paramref name="cacheKey"/>,
    /// or <see langword="null"/> if not cached.
    /// </summary>
    /// <param name="cacheKey">Cache key to look up.</param>
    /// <param name="ct">Cancellation token.</param>
    public async Task<string?> GetChatCompletionAsync(string cacheKey, CancellationToken ct)
    {
        const string sql = "SELECT ResponseText FROM RagChatCompletionCache WHERE CacheKey = @CacheKey";

        await using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync(ct);
        await using var command = new SqlCommand(sql, connection);
        command.Parameters.AddWithValue("@CacheKey", cacheKey);

        return await command.ExecuteScalarAsync(ct) as string;
    }

    /// <summary>
    /// Stores a chat-completion response under <paramref name="cacheKey"/> unless an entry with that
    /// key already exists. An existing entry is left unchanged.
    /// </summary>
    /// <param name="cacheKey">Cache key.</param>
    /// <param name="model">Model name stored alongside the response.</param>
    /// <param name="temperature">Sampling temperature stored alongside the response.</param>
    /// <param name="responseText">Completion text to store.</param>
    /// <param name="ct">Cancellation token.</param>
    public async Task SaveChatCompletionAsync(string cacheKey, string model, decimal temperature, string responseText, CancellationToken ct)
    {
        // Single-statement conditional insert for the same race-free reason as SaveEmbeddingAsync.
        const string sql = """
            INSERT INTO RagChatCompletionCache (CacheKey, Model, Temperature, ResponseText, CreatedAt)
            SELECT @CacheKey, @Model, @Temperature, @ResponseText, SYSUTCDATETIME()
            WHERE NOT EXISTS (
                SELECT 1 FROM RagChatCompletionCache WITH (UPDLOCK, HOLDLOCK) WHERE CacheKey = @CacheKey);
            """;

        await using var connection = new SqlConnection(_connectionString);
        await connection.OpenAsync(ct);
        await using var command = new SqlCommand(sql, connection);
        command.Parameters.AddWithValue("@CacheKey", cacheKey);
        command.Parameters.AddWithValue("@Model", model);
        command.Parameters.AddWithValue("@Temperature", temperature);
        command.Parameters.AddWithValue("@ResponseText", responseText);
        await command.ExecuteNonQueryAsync(ct);
    }

    /// <summary>
    /// Materializes a <see cref="RagDocumentRecord"/> from nine consecutive columns starting at
    /// <paramref name="offset"/>, in the order Id, DocumentType, Title, SourceUrl, RawText,
    /// TextHash, TypeConfidence, MetadataJson, CreatedAt.
    /// </summary>
    /// <remarks>
    /// <c>CreatedAt</c> is read as a <see cref="DateTime"/> and treated as UTC.
    /// This matches how <see cref="SaveDocumentAsync"/> writes it.
    /// </remarks>
    private static RagDocumentRecord ReadDocument(SqlDataReader reader, int offset = 0) => new()
    {
        Id = reader.GetString(offset),
        DocumentType = reader.GetString(offset + 1),
        Title = reader.GetString(offset + 2),
        SourceUrl = reader.IsDBNull(offset + 3) ? null : reader.GetString(offset + 3),
        Text = reader.GetString(offset + 4),
        TextHash = reader.GetString(offset + 5),
        // Convert.ToDouble accepts whichever numeric CLR type the column maps to.
        TypeConfidence = Convert.ToDouble(reader.GetValue(offset + 6)),
        MetadataJson = reader.GetString(offset + 7),
        CreatedAt = new DateTimeOffset(reader.GetDateTime(offset + 8), TimeSpan.Zero)
    };

    /// <summary>
    /// Connects to <c>master</c> and creates the connection string's database if it does not exist.
    /// Does nothing when the connection string names no database.
    /// </summary>
    /// <param name="ct">Cancellation token.</param>
    /// <exception cref="InvalidOperationException">The database name exceeds 128 characters.</exception>
    private async Task EnsureDatabaseExistsAsync(CancellationToken ct)
    {
        var builder = new SqlConnectionStringBuilder(_connectionString);
        var databaseName = builder.InitialCatalog;
        if (string.IsNullOrWhiteSpace(databaseName))
        {
            return;
        }

        // QUOTENAME returns NULL for inputs longer than 128 characters, which would silently skip creation.
        if (databaseName.Length > MaxDatabaseNameLength)
        {
            throw new InvalidOperationException(
                $"Database name in connection string 'RagDb' exceeds {MaxDatabaseNameLength} characters.");
        }

        builder.InitialCatalog = "master";
        await using var connection = new SqlConnection(builder.ConnectionString);
        await connection.OpenAsync(ct);

        // CREATE DATABASE cannot take a parameter, so the statement is built server-side. QUOTENAME
        // escapes the identifier, so no client-side string splicing is needed. The previous version
        // escaped ']' but not a "'" that could break out of EXEC('...').
        const string sql = """
            IF DB_ID(@DatabaseName) IS NULL
            BEGIN
                DECLARE @createSql nvarchar(max) = N'CREATE DATABASE ' + QUOTENAME(@DatabaseName);
                EXEC sys.sp_executesql @createSql;
            END
            """;

        await using var command = new SqlCommand(sql, connection);
        command.Parameters.AddWithValue("@DatabaseName", databaseName);
        await command.ExecuteNonQueryAsync(ct);
    }

    /// <summary>
    /// Splits a T-SQL script into batches on lines that consist solely of <c>GO</c>.
    /// Matching is case-insensitive and ignores surrounding whitespace. Each batch is trimmed,
    /// and empty batches are dropped.
    /// </summary>
    /// <remarks>
    /// Splitting on the substring "GO" would also cut inside identifiers, keywords, string literals
    /// or comments that merely contain those letters (for example <c>CATEGORY</c>).
    /// <c>GO</c> with a repeat count (for example <c>GO 5</c>) is not supported.
    /// </remarks>
    private static IEnumerable<string> SplitSqlBatches(string script)
    {
        var pending = new List<string>();

        foreach (var line in script.Split('\n'))
        {
            // Trim also strips the '\r' left over from CRLF line endings.
            if (!line.Trim().Equals("GO", StringComparison.OrdinalIgnoreCase))
            {
                pending.Add(line);
                continue;
            }

            var batch = string.Join('\n', pending).Trim();
            if (batch.Length > 0)
            {
                yield return batch;
            }

            pending.Clear();
        }

        var tail = string.Join('\n', pending).Trim();
        if (tail.Length > 0)
        {
            yield return tail;
        }
    }
}