using System.Reflection;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
using Azure.Identity;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Diagnostics;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.HttpOverrides;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Serilog;
using Swashbuckle.AspNetCore.SwaggerGen;
using Swashbuckle.AspNetCore.Annotations;

namespace StartupHelpers;

/// <summary>
/// Shared startup helpers: environment/configuration loading, JSON Serilog setup, Azure Key Vault,
/// Swagger, forwarded headers behind a reverse proxy, CORS, startup diagnostics and a few middlewares.
/// </summary>
public static class StartupExtensions
{
    /// <summary>
    /// Loads environment variables using <c>DotNetEnv.Env.Load()</c> with its default lookup
    /// (presumably a <c>.env</c> file near the working directory — see DotNetEnv docs).
    /// </summary>
    public static void LoadDotEnvFile()
    {
        DotNetEnv.Env.Load();
    }

    /// <summary>
    /// Returns a human-readable version string for <paramref name="assembly"/>.
    /// </summary>
    /// <param name="assembly">The assembly whose version is reported.</param>
    /// <returns>
    /// The <see cref="AssemblyInformationalVersionAttribute.InformationalVersion"/> if present,
    /// otherwise the assembly's <see cref="AssemblyName.Version"/>, otherwise <c>"unknown"</c>.
    /// </returns>
    public static string GetApplicationVersion(Assembly assembly)
    {
        return assembly.GetCustomAttribute<AssemblyInformationalVersionAttribute>()?.InformationalVersion
            ?? assembly.GetName().Version?.ToString()
            ?? "unknown";
    }

    /// <summary>
    /// Configures Serilog for a web application. The pipeline reads settings from configuration and
    /// services, adds log-context, machine and environment enrichers plus <c>Service</c>/<c>AppVersion</c>
    /// properties, and writes JSON to the console.
    /// </summary>
    /// <param name="builder">The web application builder.</param>
    /// <param name="serviceName">Value of the <c>Service</c> log property.</param>
    /// <param name="appVersion">Value of the <c>AppVersion</c> log property.</param>
    public static void ConfigureJsonSerilog(this WebApplicationBuilder builder, string serviceName, string appVersion)
    {
        builder.Host.UseSerilog((context, services, configuration) =>
            ApplyJsonLoggerConfiguration(configuration, context.Configuration, services, serviceName, appVersion));
    }

    /// <summary>
    /// Configures Serilog for a generic (non-web) host. It uses the same pipeline as the
    /// <see cref="WebApplicationBuilder"/> overload.
    /// </summary>
    /// <param name="builder">The host application builder.</param>
    /// <param name="serviceName">Value of the <c>Service</c> log property.</param>
    /// <param name="appVersion">Value of the <c>AppVersion</c> log property.</param>
    public static void ConfigureJsonSerilog(this HostApplicationBuilder builder, string serviceName, string appVersion)
    {
        builder.Services.AddSerilog((services, configuration) =>
            ApplyJsonLoggerConfiguration(configuration, builder.Configuration, services, serviceName, appVersion));
    }

    /// <summary>
    /// Single definition of the JSON console logging pipeline shared by both
    /// <c>ConfigureJsonSerilog</c> overloads, so they cannot drift apart.
    /// </summary>
    private static void ApplyJsonLoggerConfiguration(
        LoggerConfiguration loggerConfiguration,
        IConfiguration appConfiguration,
        IServiceProvider services,
        string serviceName,
        string appVersion)
    {
        loggerConfiguration
            .ReadFrom.Configuration(appConfiguration)
            .ReadFrom.Services(services)
            .Enrich.FromLogContext()
            .Enrich.WithMachineName()
            .Enrich.WithEnvironmentName()
            .Enrich.WithProperty("Service", serviceName)
            .Enrich.WithProperty("AppVersion", appVersion)
            .WriteTo.Console(new Serilog.Formatting.Json.JsonFormatter());
    }

    /// <summary>
    /// Adds Azure Key Vault as a configuration source when <c>KeyVault:Enabled</c> is true and
    /// <c>KeyVault:VaultUri</c> is non-blank, authenticating with <see cref="DefaultAzureCredential"/>.
    /// Failures (including an unparsable URI) are logged as warnings and startup continues with the
    /// other configuration sources.
    /// </summary>
    /// <param name="builder">The web application builder whose configuration is extended.</param>
    public static void AddAzureKeyVaultIfConfigured(this WebApplicationBuilder builder)
    {
        var keyVaultUri = builder.Configuration["KeyVault:VaultUri"];
        var keyVaultEnabled = builder.Configuration.GetValue<bool>("KeyVault:Enabled");

        if (!keyVaultEnabled || string.IsNullOrWhiteSpace(keyVaultUri))
        {
            // NOTE(review): uses the static Serilog logger; depending on when this runs relative to
            // logger bootstrap, these messages may not reach the configured sinks — confirm.
            Log.Information("Azure Key Vault is disabled or not configured");
            return;
        }

        Log.Information("Loading configuration from Azure Key Vault: {VaultUri}", keyVaultUri);

        try
        {
            builder.Configuration.AddAzureKeyVault(new Uri(keyVaultUri), new DefaultAzureCredential());
            Log.Information("Azure Key Vault configuration loaded successfully");
        }
        catch (Exception ex)
        {
            // Deliberate best-effort: Key Vault is optional, so log and fall back to other sources.
            Log.Warning(ex, "Failed to load Azure Key Vault configuration. Continuing with other configuration sources.");
        }
    }

    /// <summary>
    /// Registers the endpoint API explorer and Swagger generation. The XML documentation file
    /// <c>{AssemblyName}.xml</c> is included if it exists next to the application binaries.
    /// </summary>
    /// <param name="services">The service collection.</param>
    /// <param name="assembly">Assembly whose name determines the XML documentation file name.</param>
    /// <param name="fallbackName">File name stem used when the assembly has no name.</param>
    /// <param name="enableAnnotations">Whether to enable Swashbuckle annotation attributes.</param>
    public static void AddSwaggerWithXmlComments(this IServiceCollection services, Assembly assembly, string fallbackName, bool enableAnnotations = true)
    {
        services.AddEndpointsApiExplorer();
        services.AddSwaggerGen(options =>
        {
            var xmlFile = (assembly.GetName().Name ?? fallbackName) + ".xml";
            var xmlPath = Path.Combine(AppContext.BaseDirectory, xmlFile);

            // The XML file only exists when GenerateDocumentationFile is on; skip silently otherwise.
            if (File.Exists(xmlPath))
            {
                options.IncludeXmlComments(xmlPath);
            }

            if (enableAnnotations)
            {
                options.EnableAnnotations();
            }
        });
    }

    /// <summary>
    /// Configures <see cref="ForwardedHeadersOptions"/> for a single reverse proxy (Caddy).
    /// The client IP is taken from <c>X-Real-IP</c> and the scheme from <c>X-Forwarded-Proto</c>.
    /// </summary>
    /// <param name="services">The service collection.</param>
    public static void ConfigureCaddyForwardedHeaders(this IServiceCollection services)
    {
        services.Configure<ForwardedHeadersOptions>(options =>
        {
            options.ForwardedHeaders = ForwardedHeaders.XForwardedFor | ForwardedHeaders.XForwardedProto;
            options.ForwardedForHeaderName = "X-Real-IP";

            // Clearing the trusted lists accepts forwarded headers from ANY immediate peer.
            // This is only safe if the app is reachable exclusively through the proxy.
            options.KnownIPNetworks.Clear();
            options.KnownProxies.Clear();
            options.ForwardLimit = 1;
        });
    }

    /// <summary>
    /// Registers a CORS policy built from the <c>Cors:AllowedOrigins</c> string array.
    /// The policy allows only <c>POST</c>/<c>OPTIONS</c> with a <c>Content-Type</c> header and caches
    /// preflights for one hour. When no origins are configured, the policy allows no origins.
    /// </summary>
    /// <param name="services">The service collection.</param>
    /// <param name="configuration">Configuration containing <c>Cors:AllowedOrigins</c>.</param>
    /// <param name="policyName">Name of the registered policy.</param>
    public static void AddFrontendCorsFromConfiguration(this IServiceCollection services, IConfiguration configuration, string policyName = "FrontendOnly")
    {
        var allowedOrigins = configuration.GetSection("Cors:AllowedOrigins").Get<string[]>() ?? Array.Empty<string>();

        services.AddCors(options =>
        {
            options.AddPolicy(policyName, policy =>
            {
                if (allowedOrigins.Length > 0)
                {
                    policy.WithOrigins(allowedOrigins)
                        .WithMethods("POST", "OPTIONS")
                        .WithHeaders("Content-Type")
                        .SetPreflightMaxAge(TimeSpan.FromHours(1));
                }
            });
        });
    }

    /// <summary>
    /// Logs the service name and environment for a web application. When
    /// <c>LogEnvironmentOnStartup</c> is true (the default), it also dumps environment settings via
    /// <c>EnvironmentDiagnostics.LogEnvironmentSettings</c>.
    /// </summary>
    /// <param name="app">The built web application.</param>
    /// <param name="serviceName">Logger category and value of the <c>Service</c> placeholder.</param>
    public static void LogStartupDiagnostics(this WebApplication app, string serviceName)
    {
        var logger = app.Services.GetRequiredService<ILoggerFactory>().CreateLogger(serviceName);
        logger.LogInformation("{Service} starting up...", serviceName);
        logger.LogInformation("Environment: {Environment}", app.Environment.EnvironmentName);

        var logEnvironmentOnStartup = app.Configuration.GetValue<bool>("LogEnvironmentOnStartup", defaultValue: true);
        if (logEnvironmentOnStartup)
        {
            EnvironmentDiagnostics.LogEnvironmentSettings(logger, app.Configuration, app.Environment);
        }
    }

    /// <summary>
    /// Generic-host counterpart of <see cref="LogStartupDiagnostics"/>. It resolves the environment and
    /// configuration from the host's service provider.
    /// </summary>
    /// <param name="host">The built host.</param>
    /// <param name="serviceName">Logger category and value of the <c>Service</c> placeholder.</param>
    public static void LogHostStartupDiagnostics(this IHost host, string serviceName)
    {
        var logger = host.Services.GetRequiredService<ILoggerFactory>().CreateLogger(serviceName);
        logger.LogInformation("{Service} starting up...", serviceName);

        var environment = host.Services.GetRequiredService<IHostEnvironment>();
        logger.LogInformation("Environment: {Environment}", environment.EnvironmentName);

        var configuration = host.Services.GetRequiredService<IConfiguration>();
        var logEnvironmentOnStartup = configuration.GetValue<bool>("LogEnvironmentOnStartup", defaultValue: true);
        if (logEnvironmentOnStartup)
        {
            EnvironmentDiagnostics.LogEnvironmentSettings(logger, configuration, environment);
        }
    }

    /// <summary>
    /// Adds Serilog request logging with a compact message template. Each entry is enriched with
    /// host, scheme, remote IP and user agent.
    /// </summary>
    /// <param name="app">The web application.</param>
    /// <param name="includeProxyHeaders">
    /// When true, the raw <c>X-Real-IP</c> and <c>X-Forwarded-For</c> headers are also logged.
    /// </param>
    public static void UseDefaultSerilogRequestLogging(this WebApplication app, bool includeProxyHeaders = false)
    {
        app.UseSerilogRequestLogging(options =>
        {
            options.MessageTemplate = "HTTP {RequestMethod} {RequestPath} responded {StatusCode} in {Elapsed:0.0000} ms";
            options.EnrichDiagnosticContext = (diagnosticContext, httpContext) =>
            {
                diagnosticContext.Set("RequestHost", httpContext.Request.Host.Value);
                diagnosticContext.Set("RequestScheme", httpContext.Request.Scheme);
                diagnosticContext.Set("RemoteIP", httpContext.Connection.RemoteIpAddress?.ToString());
                diagnosticContext.Set("UserAgent", httpContext.Request.Headers.UserAgent.ToString());

                if (includeProxyHeaders)
                {
                    diagnosticContext.Set("XRealIP", httpContext.Request.Headers["X-Real-IP"].ToString());
                    diagnosticContext.Set("XForwardedFor", httpContext.Request.Headers["X-Forwarded-For"].ToString());
                }
            };
        });
    }

    /// <summary>
    /// Installs an exception handler that logs the unhandled exception. It then responds with status 500
    /// and a generic JSON body, so exception details are not leaked to clients.
    /// </summary>
    /// <param name="app">The web application.</param>
    /// <param name="serviceName">Logger category and value of the <c>Service</c> placeholder.</param>
    public static void UseJsonExceptionHandler(this WebApplication app, string serviceName)
    {
        app.UseExceptionHandler(errorApp =>
        {
            errorApp.Run(async context =>
            {
                var feature = context.Features.Get<IExceptionHandlerFeature>();
                var logger = context.RequestServices.GetRequiredService<ILoggerFactory>().CreateLogger(serviceName);

                if (feature?.Error is not null)
                {
                    logger.LogError(feature.Error, "Unhandled exception in {Service}", serviceName);
                }

                context.Response.StatusCode = StatusCodes.Status500InternalServerError;
                context.Response.ContentType = "application/json";
                await context.Response.WriteAsJsonAsync(new { error = "Unexpected server error." });
            });
        });
    }

    /// <summary>
    /// Adds middleware that, when <c>{sectionName}:RequireApiKey</c> is true, requires the
    /// <c>X-Internal-Api-Key</c> request header to equal <c>{sectionName}:ApiKey</c>. Requests that fail
    /// get 401 with a JSON error body. Both settings are read on every request. If the key is required
    /// but not configured, every request is rejected (fail closed).
    /// </summary>
    /// <param name="app">The web application.</param>
    /// <param name="sectionName">Configuration section holding <c>RequireApiKey</c> and <c>ApiKey</c>.</param>
    public static void UseInternalApiKeyProtection(this WebApplication app, string sectionName = "InternalApi")
    {
        app.Use(async (context, next) =>
        {
            var configuration = context.RequestServices.GetRequiredService<IConfiguration>();
            var requireApiKey = configuration.GetValue<bool>($"{sectionName}:RequireApiKey");

            if (requireApiKey)
            {
                var configuredApiKey = configuration[$"{sectionName}:ApiKey"];
                var headerApiKey = context.Request.Headers["X-Internal-Api-Key"].ToString();

                if (string.IsNullOrWhiteSpace(configuredApiKey) || !ApiKeyMatches(configuredApiKey, headerApiKey))
                {
                    var logger = context.RequestServices.GetRequiredService<ILoggerFactory>().CreateLogger("InternalApiKey");
                    logger.LogWarning(
                        "Rejected unauthorized internal API call. Path={Path}, RemoteIP={RemoteIP}",
                        context.Request.Path,
                        context.Connection.RemoteIpAddress?.ToString());

                    context.Response.StatusCode = StatusCodes.Status401Unauthorized;
                    await context.Response.WriteAsJsonAsync(new { error = "Unauthorized internal API call." });
                    return;
                }
            }

            await next();
        });
    }

    /// <summary>
    /// Compares a caller-supplied key against the configured secret in constant time.
    /// A plain <c>==</c> on strings exits at the first differing character, which leaks information
    /// through response timing. Both inputs are hashed first (over their exact UTF-16 code units), so
    /// the comparison always runs over equal-length buffers and the secret's length is not revealed
    /// either.
    /// </summary>
    /// <param name="expected">The configured secret.</param>
    /// <param name="presented">The value supplied by the caller.</param>
    /// <returns><c>true</c> when the two strings are equal (up to SHA-256 collision).</returns>
    private static bool ApiKeyMatches(string expected, string presented)
    {
        Span<byte> expectedHash = stackalloc byte[SHA256.HashSizeInBytes];
        Span<byte> presentedHash = stackalloc byte[SHA256.HashSizeInBytes];

        SHA256.HashData(MemoryMarshal.AsBytes(expected.AsSpan()), expectedHash);
        SHA256.HashData(MemoryMarshal.AsBytes(presented.AsSpan()), presentedHash);

        return CryptographicOperations.FixedTimeEquals(expectedHash, presentedHash);
    }

    /// <summary>
    /// In the Development environment only, it enables the Swagger JSON endpoint and the Swagger UI
    /// at <c>/swagger</c>. In other environments it does nothing.
    /// </summary>
    /// <param name="app">The web application.</param>
    /// <param name="documentTitle">Browser title for the Swagger UI page.</param>
    /// <param name="endpointName">Display name for the v1 document in the UI.</param>
    public static void UseSwaggerInDevelopment(this WebApplication app, string documentTitle, string endpointName)
    {
        if (!app.Environment.IsDevelopment())
        {
            return;
        }

        app.UseSwagger();
        app.UseSwaggerUI(options =>
        {
            options.DocumentTitle = documentTitle;
            options.SwaggerEndpoint("/swagger/v1/swagger.json", $"{endpointName} v1");
            options.RoutePrefix = "swagger";
        });
    }
}