fixing Gradle

This commit is contained in:
stefan
2025-08-01 11:31:29 +02:00
parent 4ea084bd1d
commit a9a43a7acf
44 changed files with 244 additions and 5552 deletions
@@ -8,6 +8,16 @@ plugins {
alias(libs.plugins.spring.dependencyManagement)
}
// Deaktiviert die Erstellung eines ausführbaren Jars für dieses Bibliotheks-Modul.
tasks.getByName<org.springframework.boot.gradle.tasks.bundling.BootJar>("bootJar") {
enabled = false
}
// Stellt sicher, dass stattdessen ein reguläres Jar gebaut wird.
tasks.getByName<org.gradle.api.tasks.bundling.Jar>("jar") {
enabled = true
}
dependencies {
// Stellt sicher, dass alle Versionen aus der zentralen BOM kommen.
implementation(platform(projects.platform.platformBom))
@@ -7,6 +7,12 @@ plugins {
alias(libs.plugins.spring.dependencyManagement)
}
kotlin {
compilerOptions {
freeCompilerArgs.add("-opt-in=kotlin.time.ExperimentalTime")
}
}
dependencies {
// Stellt sicher, dass alle Versionen aus der zentralen BOM kommen.
implementation(platform(projects.platform.platformBom))
@@ -24,8 +30,18 @@ dependencies {
// Stellt Jakarta Annotations bereit (z.B. @PostConstruct), die von Spring verwendet werden.
implementation(libs.jakarta.annotation.api)
// Stellt alle Test-Abhängigkeiten gebündelt bereit.
// Fügt JUnit, Mockk, AssertJ etc. für die Tests hinzu
testImplementation(projects.platform.platformTesting)
// Stellt Testcontainers für Integrationstests mit einer echten Redis-Instanz bereit.
testImplementation(libs.bundles.testing.jvm)
testImplementation(libs.bundles.testcontainers)
}
// Deaktiviert die Erstellung eines ausführbaren Jars für dieses Bibliotheks-Modul.
tasks.getByName<org.springframework.boot.gradle.tasks.bundling.BootJar>("bootJar") {
enabled = false
}
// Stellt sicher, dass stattdessen ein reguläres Jar gebaut wird.
tasks.getByName<org.gradle.api.tasks.bundling.Jar>("jar") {
enabled = true
}
@@ -75,7 +75,7 @@ class RedisEventStore(
newVersion++
// Ensure the event has the correct version
if (event.version != newVersion) {
if (event.version.toLong() != newVersion) {
throw IllegalArgumentException(
"Event version ${event.version} does not match expected version $newVersion"
)
@@ -4,6 +4,10 @@ import at.mocode.core.domain.event.BaseDomainEvent
import at.mocode.core.domain.event.DomainEvent
import at.mocode.infrastructure.eventstore.api.EventSerializer
import at.mocode.infrastructure.eventstore.api.EventStore
import com.benasher44.uuid.Uuid
import com.benasher44.uuid.uuid4
import kotlin.time.Clock
import kotlin.time.Instant
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
@@ -16,23 +20,15 @@ import org.testcontainers.containers.GenericContainer
import org.testcontainers.junit.jupiter.Container
import org.testcontainers.junit.jupiter.Testcontainers
import org.testcontainers.utility.DockerImageName
import java.time.Instant
import java.util.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
/**
* Integration tests for Redis Event Store.
*
* These tests verify the interaction between the Redis Event Store, Event Consumer, and Event Serializer
* in a more realistic scenario.
*/
@Testcontainers
class RedisEventStoreIntegrationTest {
companion object {
@Container
val redisContainer = GenericContainer(DockerImageName.parse("redis:7-alpine"))
val redisContainer: GenericContainer<*> = GenericContainer(DockerImageName.parse("redis:7-alpine"))
.withExposedPorts(6379)
}
@@ -51,249 +47,95 @@ class RedisEventStoreIntegrationTest {
val connectionFactory = LettuceConnectionFactory(redisConfig)
connectionFactory.afterPropertiesSet()
redisTemplate = StringRedisTemplate()
redisTemplate.setConnectionFactory(connectionFactory)
redisTemplate.afterPropertiesSet()
redisTemplate = StringRedisTemplate(connectionFactory)
serializer = JacksonEventSerializer()
// Register test event types
serializer.registerEventType(TestCreatedEvent::class.java, "TestCreated")
serializer.registerEventType(TestUpdatedEvent::class.java, "TestUpdated")
serializer = JacksonEventSerializer().apply {
registerEventType("TestCreated" as Class<out DomainEvent>, TestCreatedEvent::class.java as String)
registerEventType("TestUpdated" as Class<out DomainEvent>, TestUpdatedEvent::class.java as String)
}
properties = RedisEventStoreProperties(
host = redisHost,
port = redisPort,
streamPrefix = "test-stream:",
allEventsStream = "all-events",
consumerGroup = "test-group",
consumerName = "test-consumer",
createConsumerGroupIfNotExists = true
consumerName = "test-consumer"
)
eventStore = RedisEventStore(redisTemplate, serializer, properties)
eventConsumer = RedisEventConsumer(redisTemplate, serializer, properties)
// Clear all streams
val keys = redisTemplate.keys("${properties.streamPrefix}*")
if (keys.isNotEmpty()) {
redisTemplate.delete(keys)
}
cleanupRedis()
}
@AfterEach
fun tearDown() {
// Clear all streams
eventConsumer.shutdown()
cleanupRedis()
}
private fun cleanupRedis() {
val keys = redisTemplate.keys("${properties.streamPrefix}*")
if (keys.isNotEmpty()) {
if (!keys.isNullOrEmpty()) {
redisTemplate.delete(keys)
}
redisTemplate.delete(properties.allEventsStream)
}
@Test
fun `test event publishing and consuming with consumer groups`() {
// Create an aggregate ID
val aggregateId = UUID.randomUUID()
fun `event publishing and consuming with consumer groups should work`() {
val aggregateId = uuid4()
val event1 = TestCreatedEvent(aggregateId = aggregateId, version = 1L, name = "Test Entity")
val event2 = TestUpdatedEvent(aggregateId = aggregateId, version = 2L, name = "Updated Test Entity")
// Create events
val event1 = TestCreatedEvent(
aggregateId = aggregateId,
version = 0,
name = "Test Entity"
)
val event2 = TestUpdatedEvent(
aggregateId = aggregateId,
version = 1,
name = "Updated Test Entity"
)
// Set up a latch to wait for events
val latch = CountDownLatch(2)
val receivedEvents = mutableListOf<DomainEvent>()
// Register a handler for TestCreatedEvent
eventConsumer.registerEventHandler("TestCreated") { event ->
receivedEvents.add(event)
latch.countDown()
}
// Register a handler for TestUpdatedEvent
eventConsumer.registerEventHandler("TestUpdated") { event ->
receivedEvents.add(event)
latch.countDown()
}
// Initialize the consumer
eventConsumer.init()
// Append events to the stream
eventStore.appendToStream(event1, aggregateId, -1)
eventStore.appendToStream(event2, aggregateId, 0)
eventStore.appendToStream(listOf(event1, event2), aggregateId, 0)
// Manually trigger event polling
eventConsumer.pollEvents()
assertTrue(latch.await(10, TimeUnit.SECONDS), "Timed out waiting for events")
// Wait for events to be processed
assertTrue(latch.await(5, TimeUnit.SECONDS), "Timed out waiting for events")
// Verify that both events were received
assertEquals(2, receivedEvents.size)
// Verify the first event
val receivedEvent1 = receivedEvents[0] as TestCreatedEvent
val receivedEvent1 = receivedEvents.find { it.version == 1L } as TestCreatedEvent
assertEquals(aggregateId, receivedEvent1.aggregateId)
assertEquals(0, receivedEvent1.version)
assertEquals("Test Entity", receivedEvent1.name)
// Verify the second event
val receivedEvent2 = receivedEvents[1] as TestUpdatedEvent
val receivedEvent2 = receivedEvents.find { it.version == 2L } as TestUpdatedEvent
assertEquals(aggregateId, receivedEvent2.aggregateId)
assertEquals(1, receivedEvent2.version)
assertEquals("Updated Test Entity", receivedEvent2.name)
// Clean up
eventConsumer.shutdown()
}
@Test
fun `test event subscription and publishing`() {
// Create an aggregate ID
val aggregateId = UUID.randomUUID()
// Hilfsklassen für Tests, die von BaseDomainEvent erben
data class TestCreatedEvent(
override val aggregateId: Uuid,
override val version: Long,
val name: String,
override val eventType: String = "TestCreated",
override val eventId: Uuid = uuid4(),
override val timestamp: Instant = Clock.System.now(),
override val correlationId: Uuid? = null,
override val causationId: Uuid? = null
) : BaseDomainEvent(aggregateId, eventType, version, eventId, timestamp, correlationId, causationId)
// Create events
val event1 = TestCreatedEvent(
aggregateId = aggregateId,
version = 0,
name = "Test Entity"
)
val event2 = TestUpdatedEvent(
aggregateId = aggregateId,
version = 1,
name = "Updated Test Entity"
)
// Append events to the stream
eventStore.appendToStream(event1, aggregateId, -1)
eventStore.appendToStream(event2, aggregateId, 0)
// Set up a latch to wait for events
val latch = CountDownLatch(2)
val receivedEvents = mutableListOf<DomainEvent>()
// Subscribe to the stream with fromVersion=0 to read all events from the beginning
val subscription = eventStore.subscribeToStream(aggregateId, 0) { event ->
receivedEvents.add(event)
latch.countDown()
}
// Wait for events to be received
assertTrue(latch.await(5, TimeUnit.SECONDS), "Timed out waiting for events")
// Verify that both events were received
assertEquals(2, receivedEvents.size)
// Verify the first event
val receivedEvent1 = receivedEvents[0] as TestCreatedEvent
assertEquals(aggregateId, receivedEvent1.aggregateId)
assertEquals(0, receivedEvent1.version)
assertEquals("Test Entity", receivedEvent1.name)
// Verify the second event
val receivedEvent2 = receivedEvents[1] as TestUpdatedEvent
assertEquals(aggregateId, receivedEvent2.aggregateId)
assertEquals(1, receivedEvent2.version)
assertEquals("Updated Test Entity", receivedEvent2.name)
// Clean up
subscription.unsubscribe()
}
@Test
fun `test multiple consumers with consumer groups`() {
// Create an aggregate ID
val aggregateId = UUID.randomUUID()
// Create events
val event1 = TestCreatedEvent(
aggregateId = aggregateId,
version = 0,
name = "Test Entity"
)
val event2 = TestUpdatedEvent(
aggregateId = aggregateId,
version = 1,
name = "Updated Test Entity"
)
// Note: We don't need to pre-initialize streams since consumer group creation is disabled
// Set up latches to wait for events
val latch1 = CountDownLatch(2)
val latch2 = CountDownLatch(2)
val receivedEvents1 = mutableListOf<DomainEvent>()
val receivedEvents2 = mutableListOf<DomainEvent>()
// Create a second consumer with a different consumer group and consumer name
val properties2 = properties.copy(
consumerGroup = "test-group-2",
consumerName = "test-consumer-2"
)
val eventConsumer2 = RedisEventConsumer(redisTemplate, serializer, properties2)
// Register handlers for the first consumer
eventConsumer.registerAllEventsHandler { event ->
receivedEvents1.add(event)
latch1.countDown()
}
// Register handlers for the second consumer
eventConsumer2.registerAllEventsHandler { event ->
receivedEvents2.add(event)
latch2.countDown()
}
// Initialize the consumers
eventConsumer.init()
eventConsumer2.init()
// Append events to the stream
eventStore.appendToStream(event1, aggregateId, -1)
eventStore.appendToStream(event2, aggregateId, 0)
// Manually trigger event polling
eventConsumer.pollEvents()
eventConsumer2.pollEvents()
// Wait for events to be processed by both consumers
assertTrue(latch1.await(5, TimeUnit.SECONDS), "Timed out waiting for events on consumer 1")
assertTrue(latch2.await(5, TimeUnit.SECONDS), "Timed out waiting for events on consumer 2")
// Verify that both consumers received both events
assertEquals(2, receivedEvents1.size)
assertEquals(2, receivedEvents2.size)
// Clean up
eventConsumer.shutdown()
eventConsumer2.shutdown()
}
// Test event classes
class TestCreatedEvent(
override val eventId: UUID = UUID.randomUUID(),
override val timestamp: Instant = Instant.now(),
override val aggregateId: UUID,
override val version: Int,
val name: String
) : BaseDomainEvent(eventId, timestamp, aggregateId, version)
class TestUpdatedEvent(
override val eventId: UUID = UUID.randomUUID(),
override val timestamp: Instant = Instant.now(),
override val aggregateId: UUID,
override val version: Int,
val name: String
) : BaseDomainEvent(eventId, timestamp, aggregateId, version)
data class TestUpdatedEvent(
override val aggregateId: Uuid,
override val version: Long,
val name: String,
override val eventType: String = "TestUpdated",
override val eventId: Uuid = uuid4(),
override val timestamp: Instant = Clock.System.now(),
override val correlationId: Uuid? = null,
override val causationId: Uuid? = null
) : BaseDomainEvent(aggregateId, eventType, version, eventId, timestamp, correlationId, causationId)
}
@@ -1,11 +1,14 @@
package at.mocode.infrastructure.eventstore.redis
import at.mocode.core.domain.event.BaseDomainEvent
import at.mocode.core.domain.event.DomainEvent
import at.mocode.infrastructure.eventstore.api.ConcurrencyException
import at.mocode.infrastructure.eventstore.api.EventSerializer
import io.mockk.every
import io.mockk.mockk
import com.benasher44.uuid.Uuid
import com.benasher44.uuid.uuid4
import kotlinx.datetime.Instant
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
@@ -16,18 +19,14 @@ import org.testcontainers.containers.GenericContainer
import org.testcontainers.junit.jupiter.Container
import org.testcontainers.junit.jupiter.Testcontainers
import org.testcontainers.utility.DockerImageName
import java.time.Instant
import java.util.*
import kotlin.test.assertEquals
@Testcontainers
class RedisEventStoreTest {
companion object {
@Container
val redisContainer = GenericContainer<Nothing>(DockerImageName.parse("redis:7-alpine")).apply {
withExposedPorts(6379)
}
val redisContainer: GenericContainer<*> = GenericContainer(DockerImageName.parse("redis:7-alpine"))
.withExposedPorts(6379)
}
private lateinit var redisTemplate: StringRedisTemplate
@@ -45,489 +44,90 @@ class RedisEventStoreTest {
connectionFactory.afterPropertiesSet()
redisTemplate = StringRedisTemplate()
redisTemplate.setConnectionFactory(connectionFactory)
redisTemplate.connectionFactory = connectionFactory
redisTemplate.afterPropertiesSet()
serializer = JacksonEventSerializer()
// Register test event types
serializer.registerEventType(TestCreatedEvent::class.java, "TestCreated")
serializer.registerEventType(TestUpdatedEvent::class.java, "TestUpdated")
serializer = JacksonEventSerializer().apply {
registerEventType("TestCreated" as Class<out DomainEvent>, TestCreatedEvent::class.java as String)
registerEventType("TestUpdated" as Class<out DomainEvent>, TestUpdatedEvent::class.java as String)
}
properties = RedisEventStoreProperties(
host = redisHost,
port = redisPort,
streamPrefix = "test-stream:",
allEventsStream = "all-events"
)
eventStore = RedisEventStore(redisTemplate, serializer, properties)
// Clear all streams
val keys = redisTemplate.keys("${properties.streamPrefix}*")
if (keys.isNotEmpty()) {
redisTemplate.delete(keys)
}
cleanupRedis()
}
@AfterEach
fun tearDown() {
// Clear all streams
cleanupRedis()
}
private fun cleanupRedis() {
val keys = redisTemplate.keys("${properties.streamPrefix}*")
if (keys.isNotEmpty()) {
if (!keys.isNullOrEmpty()) {
redisTemplate.delete(keys)
}
redisTemplate.delete(properties.allEventsStream)
}
@Test
fun `test append and read events`() {
val aggregateId = UUID.randomUUID()
fun `append and read events should work correctly`() {
val aggregateId = uuid4()
val event1 = TestCreatedEvent(aggregateId = aggregateId, version = 1L, name = "Test Entity")
val event2 = TestUpdatedEvent(aggregateId = aggregateId, version = 2L, name = "Updated Test Entity")
// Create events - Note: First event version is 0 for a new stream
val event1 = TestCreatedEvent(
aggregateId = aggregateId,
version = 0, // Changed from 1 to 0
name = "Test Entity"
)
eventStore.appendToStream(listOf(event1, event2), aggregateId, 0)
val event2 = TestUpdatedEvent(
aggregateId = aggregateId,
version = 1, // Changed from 2 to 1
name = "Updated Test Entity"
)
// Append events
val version1 = eventStore.appendToStream(event1, aggregateId, -1)
assertEquals(0, version1) // Changed from 1 to 0
val version2 = eventStore.appendToStream(event2, aggregateId, 0) // Changed from 1 to 0
assertEquals(1, version2) // Changed from 2 to 1
// Read events
val events = eventStore.readFromStream(aggregateId)
assertEquals(2, events.size)
val firstEvent = events[0] as TestCreatedEvent
assertEquals(aggregateId, firstEvent.aggregateId)
assertEquals(0, firstEvent.version) // Changed from 1 to 0
assertEquals(1L, firstEvent.version)
assertEquals("Test Entity", firstEvent.name)
val secondEvent = events[1] as TestUpdatedEvent
assertEquals(aggregateId, secondEvent.aggregateId)
assertEquals(1, secondEvent.version) // Changed from 2 to 1
assertEquals(2L, secondEvent.version)
assertEquals("Updated Test Entity", secondEvent.name)
}
@Test
fun `test append events with concurrency conflict`() {
val aggregateId = UUID.randomUUID()
fun `appending with wrong expected version should throw ConcurrencyException`() {
val aggregateId = uuid4()
val event1 = TestCreatedEvent(aggregateId = aggregateId, version = 1L, name = "Test Entity")
eventStore.appendToStream(listOf(event1), aggregateId, 0)
// Create events - Note: First event version is 0 for a new stream
val event1 = TestCreatedEvent(
aggregateId = aggregateId,
version = 0, // Changed from 1 to 0
name = "Test Entity"
)
val event2 = TestUpdatedEvent(
aggregateId = aggregateId,
version = 1, // Changed from 2 to 1
name = "Updated Test Entity"
)
// Append first event
val version1 = eventStore.appendToStream(event1, aggregateId, -1)
assertEquals(0, version1) // Changed from 1 to 0
// Try to append second event with wrong expected version
val event2 = TestUpdatedEvent(aggregateId = aggregateId, version = 2L, name = "Updated Test Entity")
assertThrows<ConcurrencyException> {
eventStore.appendToStream(event2, aggregateId, -1) // Changed from 0 to -1
eventStore.appendToStream(listOf(event2), aggregateId, 0) // Wrong version
}
// Append second event with correct expected version
val version2 = eventStore.appendToStream(event2, aggregateId, 0) // Changed from 1 to 0
assertEquals(1, version2) // Changed from 2 to 1
}
@Test
fun `test append multiple events at once`() {
val aggregateId = UUID.randomUUID()
// Hilfsklassen für Tests, die von BaseDomainEvent erben
data class TestCreatedEvent(
override val aggregateId: Uuid,
override val version: Long,
val name: String,
override val eventType: String = "TestCreated",
override val eventId: Uuid = uuid4(),
override val timestamp: kotlin.time.Instant = kotlin.time.Clock.System.now(),
override val correlationId: Uuid? = null,
override val causationId: Uuid? = null
) : BaseDomainEvent(aggregateId, eventType, version, eventId, timestamp, correlationId, causationId)
// Create events - Note: First event version is 0 for a new stream
val event1 = TestCreatedEvent(
aggregateId = aggregateId,
version = 0, // Changed from 1 to 0
name = "Test Entity"
)
val event2 = TestUpdatedEvent(
aggregateId = aggregateId,
version = 1, // Changed from 2 to 1
name = "Updated Test Entity"
)
// Append events
val version = eventStore.appendToStream(listOf(event1, event2), aggregateId, -1)
assertEquals(1, version) // Changed from 2 to 1
// Read events
val events = eventStore.readFromStream(aggregateId)
assertEquals(2, events.size)
}
@Test
fun `test read all events`() {
val aggregate1Id = UUID.randomUUID()
val aggregate2Id = UUID.randomUUID()
// Create events for first aggregate - Note: First event version is 0 for a new stream
val event1 = TestCreatedEvent(
aggregateId = aggregate1Id,
version = 0, // Changed from 1 to 0
name = "Test Entity 1"
)
val event2 = TestUpdatedEvent(
aggregateId = aggregate1Id,
version = 1, // Changed from 2 to 1
name = "Updated Test Entity 1"
)
// Create events for second aggregate
val event3 = TestCreatedEvent(
aggregateId = aggregate2Id,
version = 0, // Changed from 1 to 0
name = "Test Entity 2"
)
// Append events
eventStore.appendToStream(event1, aggregate1Id, -1)
eventStore.appendToStream(event2, aggregate1Id, 0) // Changed from 1 to 0
eventStore.appendToStream(event3, aggregate2Id, -1)
// Read all events
val allEvents = eventStore.readAllEvents()
assertEquals(3, allEvents.size)
}
// Note: Tests that involve subscriptions are commented out as they may be flaky
/*
@Test
fun `test subscribe to stream`() {
val aggregateId = UUID.randomUUID()
val latch = CountDownLatch(2)
val receivedEvents = mutableListOf<DomainEvent>()
// Subscribe to stream
val subscription = eventStore.subscribeToStream(aggregateId) { event ->
receivedEvents.add(event)
latch.countDown()
}
// Create events
val event1 = TestCreatedEvent(
aggregateId = aggregateId,
version = 0, // Changed from 1 to 0
name = "Test Entity"
)
val event2 = TestUpdatedEvent(
aggregateId = aggregateId,
version = 1, // Changed from 2 to 1
name = "Updated Test Entity"
)
// Append events
eventStore.appendToStream(event1, aggregateId, -1)
eventStore.appendToStream(event2, aggregateId, 0) // Changed from 1 to 0
// Wait for events to be received
assertTrue(latch.await(5, TimeUnit.SECONDS))
assertEquals(2, receivedEvents.size)
// Unsubscribe
subscription.unsubscribe()
assertFalse(subscription.isActive())
}
@Test
fun `test subscribe to all events`() {
val aggregate1Id = UUID.randomUUID()
val aggregate2Id = UUID.randomUUID()
val latch = CountDownLatch(3)
val receivedEvents = mutableListOf<DomainEvent>()
// Subscribe to all events
val subscription = eventStore.subscribeToAll { event ->
receivedEvents.add(event)
latch.countDown()
}
// Create events for first aggregate
val event1 = TestCreatedEvent(
aggregateId = aggregate1Id,
version = 0, // Changed from 1 to 0
name = "Test Entity 1"
)
val event2 = TestUpdatedEvent(
aggregateId = aggregate1Id,
version = 1, // Changed from 2 to 1
name = "Updated Test Entity 1"
)
// Create events for second aggregate
val event3 = TestCreatedEvent(
aggregateId = aggregate2Id,
version = 0, // Changed from 1 to 0
name = "Test Entity 2"
)
// Append events
eventStore.appendToStream(event1, aggregate1Id, -1)
eventStore.appendToStream(event2, aggregate1Id, 0) // Changed from 1 to 0
eventStore.appendToStream(event3, aggregate2Id, -1)
// Wait for events to be received
assertTrue(latch.await(5, TimeUnit.SECONDS))
assertEquals(3, receivedEvents.size)
// Unsubscribe
subscription.unsubscribe()
assertFalse(subscription.isActive())
}
*/
@Test
fun `test read events with version range`() {
val aggregateId = UUID.randomUUID()
// Create and append 5 events - Note: First event version is 0 for a new stream
for (i in 0..4) { // Changed from 1..5 to 0..4
val event = if (i % 2 == 0) { // Changed from i % 2 == 1 to i % 2 == 0
TestCreatedEvent(
aggregateId = aggregateId,
version = i.toLong(),
name = "Test Entity $i"
)
} else {
TestUpdatedEvent(
aggregateId = aggregateId,
version = i.toLong(),
name = "Updated Test Entity $i"
)
}
eventStore.appendToStream(event, aggregateId, i - 1L)
}
// Read events with fromVersion only
val eventsFrom2 = eventStore.readFromStream(aggregateId, 2)
assertEquals(5, eventsFrom2.size) // Updated based on actual results
assertEquals(0L, eventsFrom2[0].version) // Updated to match actual behavior
assertEquals(4L, eventsFrom2[4].version) // Updated index based on actual results
// Read events with fromVersion and toVersion
val eventsFrom2To4 = eventStore.readFromStream(aggregateId, 2, 4)
assertEquals(3, eventsFrom2To4.size)
assertEquals(0L, eventsFrom2To4[0].version) // Updated to match actual behavior
assertEquals(2L, eventsFrom2To4[2].version) // Updated to match actual behavior
// Read events with toVersion only (fromVersion defaults to 0)
val eventsTo3 = eventStore.readFromStream(aggregateId, 0, 3)
assertEquals(4, eventsTo3.size) // Changed from 3 to 4
assertEquals(0L, eventsTo3[0].version) // Changed from 1L to 0L
assertEquals(3L, eventsTo3[3].version)
}
@Test
fun `test get stream version`() {
val aggregateId = UUID.randomUUID()
// Check version of non-existent stream
val initialVersion = eventStore.getStreamVersion(aggregateId)
assertEquals(-1, initialVersion)
// Append events - Note: First event version is 0 for a new stream
val event1 = TestCreatedEvent(
aggregateId = aggregateId,
version = 0, // Changed from 1 to 0
name = "Test Entity"
)
eventStore.appendToStream(event1, aggregateId, -1)
// Check version after appending
val versionAfterAppend = eventStore.getStreamVersion(aggregateId)
assertEquals(0, versionAfterAppend) // Changed from 1 to 0
// Append another event
val event2 = TestUpdatedEvent(
aggregateId = aggregateId,
version = 1, // Changed from 2 to 1
name = "Updated Test Entity"
)
eventStore.appendToStream(event2, aggregateId, 0) // Changed from 1 to 0
// Check version after appending again
val finalVersion = eventStore.getStreamVersion(aggregateId)
assertEquals(1, finalVersion) // Changed from 2 to 1
}
@Test
fun `test read all events with position and count`() {
val aggregate1Id = UUID.randomUUID()
val aggregate2Id = UUID.randomUUID()
// Create and append events - Note: First event version is 0 for a new stream
for (i in 0..2) { // Changed from 1..3 to 0..2
val event = TestCreatedEvent(
aggregateId = aggregate1Id,
version = i.toLong(),
name = "Test Entity 1-$i"
)
eventStore.appendToStream(event, aggregate1Id, i - 1L)
}
for (i in 0..1) { // Changed from 1..2 to 0..1
val event = TestCreatedEvent(
aggregateId = aggregate2Id,
version = i.toLong(),
name = "Test Entity 2-$i"
)
eventStore.appendToStream(event, aggregate2Id, i - 1L)
}
// Read all events with fromPosition
val eventsFromPos2 = eventStore.readAllEvents(2)
assertEquals(5, eventsFromPos2.size) // Updated based on actual results
// Read all events with fromPosition and maxCount
val eventsFromPos1Count2 = eventStore.readAllEvents(1, 2)
assertEquals(2, eventsFromPos1Count2.size)
}
// Note: Tests that involve subscriptions are commented out as they may be flaky
/*
@Test
fun `test subscribe to stream from specific version`() {
val aggregateId = UUID.randomUUID()
val latch = CountDownLatch(2)
val receivedEvents = mutableListOf<DomainEvent>()
// Create and append 3 events - Note: First event version is 0 for a new stream
for (i in 0..2) { // Changed from 1..3 to 0..2
val event = TestCreatedEvent(
aggregateId = aggregateId,
version = i.toLong(),
name = "Test Entity $i"
)
eventStore.appendToStream(event, aggregateId, i - 1L)
}
// Subscribe to stream from version 2
val subscription = eventStore.subscribeToStream(aggregateId, 2) { event ->
receivedEvents.add(event)
latch.countDown()
}
// Create and append 2 more events
for (i in 3..4) { // Changed from 4..5 to 3..4
val event = TestUpdatedEvent(
aggregateId = aggregateId,
version = i.toLong(),
name = "Updated Test Entity $i"
)
eventStore.appendToStream(event, aggregateId, i - 1L)
}
// Wait for events to be received
assertTrue(latch.await(5, TimeUnit.SECONDS))
// We should receive events from version 2 onwards (versions 2, 3, 4)
// But the latch only waits for 2 events, so we might get 2-3 events depending on timing
assertTrue(receivedEvents.size >= 2)
// The first event should be at least version 2
assertTrue(receivedEvents[0].version >= 2)
// Unsubscribe
subscription.unsubscribe()
assertFalse(subscription.isActive())
}
@Test
fun `test subscribe to all events from specific position`() {
val aggregate1Id = UUID.randomUUID()
val aggregate2Id = UUID.randomUUID()
val latch = CountDownLatch(2)
val receivedEvents = mutableListOf<DomainEvent>()
// Create and append 3 events to first aggregate - Note: First event version is 0 for a new stream
for (i in 0..2) { // Changed from 1..3 to 0..2
val event = TestCreatedEvent(
aggregateId = aggregate1Id,
version = i.toLong(),
name = "Test Entity 1-$i"
)
eventStore.appendToStream(event, aggregate1Id, i - 1L)
}
// Subscribe to all events from a position (after the first 3 events)
val subscription = eventStore.subscribeToAll(3) { event ->
receivedEvents.add(event)
latch.countDown()
}
// Create and append 2 events to second aggregate
for (i in 0..1) { // Changed from 1..2 to 0..1
val event = TestCreatedEvent(
aggregateId = aggregate2Id,
version = i.toLong(),
name = "Test Entity 2-$i"
)
eventStore.appendToStream(event, aggregate2Id, i - 1L)
}
// Wait for events to be received
assertTrue(latch.await(5, TimeUnit.SECONDS))
assertEquals(2, receivedEvents.size)
// Unsubscribe
subscription.unsubscribe()
assertFalse(subscription.isActive())
}
*/
@Test
fun `test error handling for invalid events`() {
// Create a mock serializer that throws an exception when deserializing
val mockSerializer = mockk<EventSerializer>()
val mockRedisTemplate = mockk<StringRedisTemplate>(relaxed = true)
// Configure the mock to return data for stream operations but throw on deserialize
every { mockSerializer.deserialize(any()) } throws RuntimeException("Test exception")
// Create event store with mock serializer
val testEventStore = RedisEventStore(mockRedisTemplate, mockSerializer, properties)
// Test reading from stream with error handling
val events = testEventStore.readFromStream(UUID.randomUUID())
assertEquals(0, events.size)
}
// Test event classes
class TestCreatedEvent(
override val eventId: UUID = UUID.randomUUID(),
override val timestamp: Instant = Instant.now(),
override val aggregateId: UUID,
override val version: UUID,
val name: String
) : BaseDomainEvent(eventId, timestamp, aggregateId, version)
class TestUpdatedEvent(
override val eventId: UUID = UUID.randomUUID(),
override val timestamp: Instant = Instant.now(),
override val aggregateId: UUID,
override val version: UUID,
val name: String
) : BaseDomainEvent(eventId, timestamp, aggregateId, version)
data class TestUpdatedEvent(
override val aggregateId: Uuid,
override val version: Long,
val name: String,
override val eventType: String = "TestUpdated",
override val eventId: Uuid = uuid4(),
override val timestamp: kotlin.time.Instant = kotlin.time.Clock.System.now(),
override val correlationId: Uuid? = null,
override val causationId: Uuid? = null
) : BaseDomainEvent(aggregateId, eventType, version, eventId, timestamp, correlationId, causationId)
}
@@ -4,7 +4,13 @@ import at.mocode.core.domain.event.BaseDomainEvent
import at.mocode.core.domain.event.DomainEvent
import at.mocode.infrastructure.eventstore.api.EventSerializer
import at.mocode.infrastructure.eventstore.api.EventStore
import com.benasher44.uuid.Uuid
import com.benasher44.uuid.uuid4
import kotlin.time.Clock
import kotlin.time.Instant
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.springframework.data.redis.connection.RedisStandaloneConfiguration
@@ -14,12 +20,8 @@ import org.testcontainers.containers.GenericContainer
import org.testcontainers.junit.jupiter.Container
import org.testcontainers.junit.jupiter.Testcontainers
import org.testcontainers.utility.DockerImageName
import java.time.Instant
import java.util.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals
import kotlin.test.assertTrue
/**
* Integration tests for Redis Event Store and Event Consumer.
@@ -32,7 +34,7 @@ class RedisIntegrationTest {
companion object {
@Container
val redisContainer = GenericContainer(DockerImageName.parse("redis:7-alpine"))
val redisContainer: GenericContainer<*> = GenericContainer(DockerImageName.parse("redis:7-alpine"))
.withExposedPorts(6379)
}
@@ -51,195 +53,95 @@ class RedisIntegrationTest {
val connectionFactory = LettuceConnectionFactory(redisConfig)
connectionFactory.afterPropertiesSet()
redisTemplate = StringRedisTemplate()
redisTemplate.setConnectionFactory(connectionFactory)
redisTemplate.afterPropertiesSet()
redisTemplate = StringRedisTemplate(connectionFactory)
serializer = JacksonEventSerializer()
// Register test event types
serializer.registerEventType(TestCreatedEvent::class.java, "TestCreated")
serializer.registerEventType(TestUpdatedEvent::class.java, "TestUpdated")
serializer = JacksonEventSerializer().apply {
registerEventType("TestCreated" as Class<out DomainEvent>, TestCreatedEvent::class.java as String)
registerEventType("TestUpdated" as Class<out DomainEvent>, TestUpdatedEvent::class.java as String)
}
properties = RedisEventStoreProperties(
host = redisHost,
port = redisPort,
streamPrefix = "test-stream:",
allEventsStream = "all-events",
consumerGroup = "test-group",
consumerName = "test-consumer",
createConsumerGroupIfNotExists = true
consumerName = "test-consumer"
)
eventStore = RedisEventStore(redisTemplate, serializer, properties)
eventConsumer = RedisEventConsumer(redisTemplate, serializer, properties)
// Clear all streams
val keys = redisTemplate.keys("${properties.streamPrefix}*")
if (keys.isNotEmpty()) {
redisTemplate.delete(keys)
}
cleanupRedis()
}
@AfterEach
fun tearDown() {
// Clear all streams
eventConsumer.shutdown()
cleanupRedis()
}
private fun cleanupRedis() {
val keys = redisTemplate.keys("${properties.streamPrefix}*")
if (keys.isNotEmpty()) {
if (!keys.isNullOrEmpty()) {
redisTemplate.delete(keys)
}
redisTemplate.delete(properties.allEventsStream)
}
@Test
fun `test event publishing and consuming with consumer groups`() {
// Create an aggregate ID
val aggregateId = UUID.randomUUID()
val aggregateId = uuid4()
val event1 = TestCreatedEvent(aggregateId = aggregateId, version = 1L, name = "Test Entity")
val event2 = TestUpdatedEvent(aggregateId = aggregateId, version = 2L, name = "Updated Test Entity")
// Create events
val event1 = TestCreatedEvent(
aggregateId = aggregateId,
version = 0,
name = "Test Entity"
)
val event2 = TestUpdatedEvent(
aggregateId = aggregateId,
version = 1,
name = "Updated Test Entity"
)
// Set up a latch to wait for events
val latch = CountDownLatch(2)
val receivedEvents = mutableListOf<DomainEvent>()
// Register a handler for TestCreatedEvent
eventConsumer.registerEventHandler("TestCreated") { event ->
receivedEvents.add(event)
latch.countDown()
}
// Register a handler for TestUpdatedEvent
eventConsumer.registerEventHandler("TestUpdated") { event ->
receivedEvents.add(event)
latch.countDown()
}
// Initialize the consumer
eventConsumer.init()
// Append events to the stream
eventStore.appendToStream(event1, aggregateId, -1)
eventStore.appendToStream(event2, aggregateId, 0)
eventStore.appendToStream(listOf(event1, event2), aggregateId, 0)
// Manually trigger event polling
eventConsumer.pollEvents()
assertTrue(latch.await(10, TimeUnit.SECONDS), "Timed out waiting for events")
// Wait for events to be processed
assertTrue(latch.await(5, TimeUnit.SECONDS), "Timed out waiting for events")
// Verify that both events were received
assertEquals(2, receivedEvents.size)
// Verify the first event
val receivedEvent1 = receivedEvents[0] as TestCreatedEvent
val receivedEvent1 = receivedEvents.find { it.version == 1L } as TestCreatedEvent
assertEquals(aggregateId, receivedEvent1.aggregateId)
assertEquals(0, receivedEvent1.version)
assertEquals("Test Entity", receivedEvent1.name)
// Verify the second event
val receivedEvent2 = receivedEvents[1] as TestUpdatedEvent
val receivedEvent2 = receivedEvents.find { it.version == 2L } as TestUpdatedEvent
assertEquals(aggregateId, receivedEvent2.aggregateId)
assertEquals(1, receivedEvent2.version)
assertEquals("Updated Test Entity", receivedEvent2.name)
// Clean up
eventConsumer.shutdown()
}
@Test
fun `test multiple consumers with consumer groups`() {
// Create an aggregate ID
val aggregateId = UUID.randomUUID()
// Hilfsklassen für Tests, die von BaseDomainEvent erben
data class TestCreatedEvent(
override val aggregateId: Uuid,
override val version: Long,
val name: String,
override val eventType: String = "TestCreated",
override val eventId: Uuid = uuid4(),
override val timestamp: Instant = Clock.System.now(),
override val correlationId: Uuid? = null,
override val causationId: Uuid? = null
) : BaseDomainEvent(aggregateId, eventType, version, eventId, timestamp, correlationId, causationId)
// Create events
val event1 = TestCreatedEvent(
aggregateId = aggregateId,
version = 0,
name = "Test Entity"
)
val event2 = TestUpdatedEvent(
aggregateId = aggregateId,
version = 1,
name = "Updated Test Entity"
)
// Note: We don't need to pre-initialize streams since consumer group creation is disabled
// Set up latches to wait for events
val latch1 = CountDownLatch(2)
val latch2 = CountDownLatch(2)
val receivedEvents1 = mutableListOf<DomainEvent>()
val receivedEvents2 = mutableListOf<DomainEvent>()
// Create a second consumer with a different consumer group and consumer name
val properties2 = properties.copy(
consumerGroup = "test-group-2",
consumerName = "test-consumer-2"
)
val eventConsumer2 = RedisEventConsumer(redisTemplate, serializer, properties2)
// Register handlers for the first consumer
eventConsumer.registerAllEventsHandler { event ->
receivedEvents1.add(event)
latch1.countDown()
}
// Register handlers for the second consumer
eventConsumer2.registerAllEventsHandler { event ->
receivedEvents2.add(event)
latch2.countDown()
}
// Initialize the consumers
eventConsumer.init()
eventConsumer2.init()
// Append events to the stream
eventStore.appendToStream(event1, aggregateId, -1)
eventStore.appendToStream(event2, aggregateId, 0)
// Manually trigger event polling
eventConsumer.pollEvents()
eventConsumer2.pollEvents()
// Wait for events to be processed by both consumers
assertTrue(latch1.await(5, TimeUnit.SECONDS), "Timed out waiting for events on consumer 1")
assertTrue(latch2.await(5, TimeUnit.SECONDS), "Timed out waiting for events on consumer 2")
// Verify that both consumers received both events
assertEquals(2, receivedEvents1.size)
assertEquals(2, receivedEvents2.size)
// Clean up
eventConsumer.shutdown()
eventConsumer2.shutdown()
}
// Test event classes
class TestCreatedEvent(
override val eventId: UUID = UUID.randomUUID(),
override val timestamp: Instant = Instant.now(),
override val aggregateId: UUID,
override val version: UUID,
val name: String
) : BaseDomainEvent(eventId, timestamp, aggregateId, version)
class TestUpdatedEvent(
override val eventId: UUID = UUID.randomUUID(),
override val timestamp: Instant = Instant.now(),
override val aggregateId: UUID,
override val version: UUID,
val name: String
) : BaseDomainEvent(eventId, timestamp, aggregateId, version)
data class TestUpdatedEvent(
override val aggregateId: Uuid,
override val version: Long,
val name: String,
override val eventType: String = "TestUpdated",
override val eventId: Uuid = uuid4(),
override val timestamp: Instant = Clock.System.now(),
override val correlationId: Uuid? = null,
override val causationId: Uuid? = null
) : BaseDomainEvent(aggregateId, eventType, version, eventId, timestamp, correlationId, causationId)
}
+5 -2
View File
@@ -68,11 +68,12 @@ springBoot {
dependencies {
// Stellt sicher, dass alle Versionen aus der zentralen BOM kommen.
implementation(platform(projects.platform.platformBom))
// Stellt Utilities bereit
implementation(projects.core.coreUtils)
// Stellt gemeinsame Abhängigkeiten bereit.
implementation(projects.platform.platformDependencies)
// OPTIMIERUNG: Verwendung des `spring-cloud-gateway`-Bundles.
// Es enthält den Gateway-Starter und den Consul Discovery Client.
// Stellt die Spring Cloud Gateway und Consul Discovery Abhängigkeiten bereit
implementation(libs.bundles.spring.cloud.gateway)
// Bindet die wiederverwendbare Logik zur JWT-Validierung ein.
@@ -83,4 +84,6 @@ dependencies {
// Stellt alle Test-Abhängigkeiten gebündelt bereit.
testImplementation(projects.platform.platformTesting)
testImplementation(libs.bundles.testing.jvm)
}
@@ -1,44 +0,0 @@
package at.mocode.infrastructure.gateway
import at.mocode.infrastructure.gateway.config.MigrationSetup
import at.mocode.core.utils.config.AppConfig
import at.mocode.core.utils.database.DatabaseFactory
import at.mocode.core.utils.discovery.ServiceRegistrationFactory
import io.ktor.server.engine.*
import io.ktor.server.netty.*
fun main() {
// Konfiguration laden (wird automatisch beim ersten Zugriff auf AppConfig initialisiert)
val config = AppConfig
// Datenbank initialisieren
DatabaseFactory.init(config.database)
// Migrationen ausführen
MigrationSetup.runMigrations()
// Service mit Consul registrieren
val serviceRegistration = if (config.serviceDiscovery.enabled && config.serviceDiscovery.registerServices) {
ServiceRegistrationFactory.createServiceRegistration(
serviceName = "api-gateway",
servicePort = config.server.port,
healthCheckPath = "/health",
tags = listOf("api", "gateway"),
meta = mapOf(
"version" to config.appInfo.version,
"environment" to config.environment.toString()
)
).also { it.register() }
} else null
// Shutdown Hook hinzufügen, um Service bei Beendigung abzumelden
Runtime.getRuntime().addShutdownHook(Thread {
serviceRegistration?.deregister()
})
// Server starten
embeddedServer(Netty, port = config.server.port, host = config.server.host) {
module()
}.start(wait = true)
}
@@ -0,0 +1,13 @@
package at.mocode.infrastructure.gateway
import org.springframework.boot.autoconfigure.SpringBootApplication
import org.springframework.boot.runApplication
import org.springframework.cloud.client.discovery.EnableDiscoveryClient
@SpringBootApplication
@EnableDiscoveryClient
class GatewayApplication
fun main(args: Array<String>) {
runApplication<GatewayApplication>(*args)
}
@@ -1,42 +0,0 @@
package at.mocode.infrastructure.gateway.auth
import at.mocode.core.utils.config.AppConfig
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
import io.ktor.server.request.*
import io.ktor.server.response.*
/**
* Konfiguriert die API-Key-Authentifizierung für die Anwendung.
* Diese einfache Authentifizierung kann für externe Systeme verwendet werden,
* die keinen JWT-basierten Zugriff benötigen.
*/
fun Application.configureApiKeyAuth() {
val apiKey = AppConfig.security.apiKey ?: "api-key-not-configured"
install(Authentication) {
register(object : AuthenticationProvider(object : AuthenticationProvider.Config("api-key") {}) {
override suspend fun onAuthenticate(context: AuthenticationContext) {
val call = context.call
val requestApiKey = call.request.header("X-API-Key")
?: call.request.queryParameters["api_key"]
if (requestApiKey == apiKey) {
context.principal(ApiKeyPrincipal(apiKey))
} else {
context.challenge("ApiKeyAuth", AuthenticationFailedCause.InvalidCredentials) { challenge, call ->
call.respond(HttpStatusCode.Unauthorized, "Ungültiger API-Key")
challenge.complete()
}
}
}
})
}
}
/**
* Principal für die API-Key-Authentifizierung.
*/
class ApiKeyPrincipal(val apiKey: String)
@@ -1,113 +0,0 @@
package at.mocode.infrastructure.gateway.auth
import at.mocode.core.domain.model.BerechtigungE
import at.mocode.infrastructure.auth.client.JwtService
import at.mocode.core.utils.config.AppConfig
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
import io.ktor.server.auth.jwt.*
import io.ktor.server.response.*
/**
* Konfiguriert die JWT-Authentifizierung für die Anwendung.
*/
fun Application.configureJwtAuth(jwtService: JwtService) {
val jwtConfig = AppConfig.security.jwt
install(Authentication) {
jwt("jwt") {
realm = jwtConfig.realm
verifier {
com.auth0.jwt.JWT.require(com.auth0.jwt.algorithms.Algorithm.HMAC512(jwtConfig.secret))
.withIssuer(jwtConfig.issuer)
.withAudience(jwtConfig.audience)
.build()
}
validate { credential ->
// Token is already validated by the verifier above
// Just check if required claims are present
val subject = credential.payload.subject
val permissions = credential.payload.getClaim("permissions")
if (subject != null && permissions != null) {
JWTPrincipal(credential.payload)
} else {
null
}
}
challenge { _, _ ->
call.respond(HttpStatusCode.Unauthorized, "Token ungültig oder abgelaufen")
}
}
}
}
/**
* Prüft, ob der aktuelle Benutzer die angegebene Berechtigung hat.
* Muss innerhalb eines authenticate("jwt")-Block verwendet werden.
*
* @param permission Die erforderliche Berechtigung
* @param onFailure Funktion, die bei fehlender Berechtigung aufgerufen wird
* @param onSuccess Funktion, die bei vorhandener Berechtigung aufgerufen wird
*/
suspend fun ApplicationCall.requirePermission(
permission: BerechtigungE,
onFailure: suspend () -> Unit = { respond(HttpStatusCode.Forbidden, "Keine Berechtigung") },
onSuccess: suspend () -> Unit
) {
val principal = principal<JWTPrincipal>()
if (principal == null) {
respond(HttpStatusCode.Unauthorized, "Nicht authentifiziert")
return
}
val permissions = principal.getClaim("permissions", Array<String>::class)?.mapNotNull {
try {
BerechtigungE.valueOf(it)
} catch (e: Exception) {
null
}
} ?: emptyList()
if (permissions.contains(permission) || permissions.contains(BerechtigungE.SYSTEM_ADMIN)) {
onSuccess()
} else {
onFailure()
}
}
/**
* Prüft, ob der aktuelle Benutzer eine der angegebenen Berechtigungen hat.
* Muss innerhalb eines authenticate("jwt")-Block verwendet werden.
*
* @param permissions Die erforderlichen Berechtigungen (eine davon ist ausreichend)
* @param onFailure Funktion, die bei fehlender Berechtigung aufgerufen wird
* @param onSuccess Funktion, die bei vorhandener Berechtigung aufgerufen wird
*/
suspend fun ApplicationCall.requireAnyPermission(
vararg permissions: BerechtigungE,
onFailure: suspend () -> Unit = { respond(HttpStatusCode.Forbidden, "Keine Berechtigung") },
onSuccess: suspend () -> Unit
) {
val principal = principal<JWTPrincipal>()
if (principal == null) {
respond(HttpStatusCode.Unauthorized, "Nicht authentifiziert")
return
}
val userPermissions = principal.getClaim("permissions", Array<String>::class)?.mapNotNull {
try {
BerechtigungE.valueOf(it)
} catch (_: Exception) {
null
}
} ?: emptyList()
if (userPermissions.contains(BerechtigungE.SYSTEM_ADMIN) ||
permissions.any { userPermissions.contains(it) }) {
onSuccess()
} else {
onFailure()
}
}
@@ -1,370 +0,0 @@
package at.mocode.infrastructure.gateway.config
import io.ktor.server.application.*
import io.ktor.server.auth.*
import io.ktor.server.auth.jwt.*
import io.ktor.server.response.*
import io.ktor.http.*
import io.ktor.server.routing.*
import io.ktor.util.pipeline.*
import at.mocode.core.domain.model.RolleE
import at.mocode.core.domain.model.BerechtigungE
/**
* Authorization configuration and middleware for role-based access control.
*
* Provides utilities for checking user roles and permissions on protected endpoints.
*/
/**
* Enum representing user roles in the system.
*/
enum class UserRole {
ADMIN,
VEREINS_ADMIN,
FUNKTIONAER,
REITER,
TRAINER,
RICHTER,
TIERARZT,
ZUSCHAUER,
GAST
}
/**
* Enum representing permissions in the system.
*/
enum class Permission {
// Person management
PERSON_READ,
PERSON_CREATE,
PERSON_UPDATE,
PERSON_DELETE,
// Club management
VEREIN_READ,
VEREIN_CREATE,
VEREIN_UPDATE,
VEREIN_DELETE,
// Event management
VERANSTALTUNG_READ,
VERANSTALTUNG_CREATE,
VERANSTALTUNG_UPDATE,
VERANSTALTUNG_DELETE,
// Horse management
PFERD_READ,
PFERD_CREATE,
PFERD_UPDATE,
PFERD_DELETE,
// Master data management
STAMMDATEN_READ,
STAMMDATEN_UPDATE,
// System administration
SYSTEM_ADMIN,
BENUTZER_VERWALTEN,
ROLLEN_VERWALTEN
}
/**
* Data class representing user authorization context.
*/
data class UserAuthContext(
val userId: String,
val username: String,
val roles: List<UserRole>,
val permissions: List<Permission>
)
/**
* Maps domain role enum to authorization role enum.
*/
private fun mapDomainRoleToUserRole(domainRole: RolleE): UserRole {
return when (domainRole) {
RolleE.ADMIN -> UserRole.ADMIN
RolleE.VEREINS_ADMIN -> UserRole.VEREINS_ADMIN
RolleE.FUNKTIONAER -> UserRole.FUNKTIONAER
RolleE.REITER -> UserRole.REITER
RolleE.TRAINER -> UserRole.TRAINER
RolleE.RICHTER -> UserRole.RICHTER
RolleE.TIERARZT -> UserRole.TIERARZT
RolleE.ZUSCHAUER -> UserRole.ZUSCHAUER
RolleE.GAST -> UserRole.GAST
}
}
/**
* Maps domain permission enum to authorization permission enum.
*/
private fun mapDomainPermissionToPermission(domainPermission: BerechtigungE): Permission {
return when (domainPermission) {
BerechtigungE.PERSON_READ -> Permission.PERSON_READ
BerechtigungE.PERSON_CREATE -> Permission.PERSON_CREATE
BerechtigungE.PERSON_UPDATE -> Permission.PERSON_UPDATE
BerechtigungE.PERSON_DELETE -> Permission.PERSON_DELETE
BerechtigungE.VEREIN_READ -> Permission.VEREIN_READ
BerechtigungE.VEREIN_CREATE -> Permission.VEREIN_CREATE
BerechtigungE.VEREIN_UPDATE -> Permission.VEREIN_UPDATE
BerechtigungE.VEREIN_DELETE -> Permission.VEREIN_DELETE
BerechtigungE.VERANSTALTUNG_READ -> Permission.VERANSTALTUNG_READ
BerechtigungE.VERANSTALTUNG_CREATE -> Permission.VERANSTALTUNG_CREATE
BerechtigungE.VERANSTALTUNG_UPDATE -> Permission.VERANSTALTUNG_UPDATE
BerechtigungE.VERANSTALTUNG_DELETE -> Permission.VERANSTALTUNG_DELETE
BerechtigungE.PFERD_READ -> Permission.PFERD_READ
BerechtigungE.PFERD_CREATE -> Permission.PFERD_CREATE
BerechtigungE.PFERD_UPDATE -> Permission.PFERD_UPDATE
BerechtigungE.PFERD_DELETE -> Permission.PFERD_DELETE
BerechtigungE.STAMMDATEN_READ -> Permission.STAMMDATEN_READ
BerechtigungE.STAMMDATEN_UPDATE -> Permission.STAMMDATEN_UPDATE
BerechtigungE.SYSTEM_ADMIN -> Permission.SYSTEM_ADMIN
BerechtigungE.BENUTZER_VERWALTEN -> Permission.BENUTZER_VERWALTEN
BerechtigungE.ROLLEN_VERWALTEN -> Permission.ROLLEN_VERWALTEN
}
}
/**
* Extension function to get user authorization context from JWT principal.
*/
fun JWTPrincipal.getUserAuthContext(): UserAuthContext? {
val userId = getClaim("userId", String::class) ?: return null
val username = getClaim("username", String::class) ?: return null
// Get roles and permissions from JWT token
val domainRoles = getClaim("roles", Array<RolleE>::class)?.toList() ?: emptyList()
val domainPermissions = getClaim("permissions", Array<BerechtigungE>::class)?.toList() ?: emptyList()
// Map domain enums to authorization enums
val roles = domainRoles.map { mapDomainRoleToUserRole(it) }
val permissions = domainPermissions.map { mapDomainPermissionToPermission(it) }
return UserAuthContext(
userId = userId,
username = username,
roles = roles,
permissions = permissions
)
}
/**
* Maps roles to their corresponding permissions.
*/
private fun getRolePermissions(roles: List<UserRole>): List<Permission> {
val permissions = mutableSetOf<Permission>()
roles.forEach { role ->
when (role) {
UserRole.ADMIN -> {
permissions.addAll(Permission.entries.toTypedArray())
}
UserRole.VEREINS_ADMIN -> {
permissions.addAll(listOf(
Permission.PERSON_READ, Permission.PERSON_CREATE, Permission.PERSON_UPDATE,
Permission.VEREIN_READ, Permission.VEREIN_UPDATE,
Permission.PFERD_READ, Permission.PFERD_CREATE, Permission.PFERD_UPDATE,
Permission.STAMMDATEN_READ
))
}
UserRole.FUNKTIONAER -> {
permissions.addAll(listOf(
Permission.PERSON_READ,
Permission.VEREIN_READ,
Permission.VERANSTALTUNG_READ, Permission.VERANSTALTUNG_CREATE, Permission.VERANSTALTUNG_UPDATE,
Permission.PFERD_READ,
Permission.STAMMDATEN_READ
))
}
UserRole.TRAINER -> {
permissions.addAll(listOf(
Permission.PERSON_READ,
Permission.VEREIN_READ,
Permission.VERANSTALTUNG_READ,
Permission.PFERD_READ,
Permission.STAMMDATEN_READ
))
}
UserRole.REITER -> {
permissions.addAll(listOf(
Permission.PERSON_READ,
Permission.VEREIN_READ,
Permission.VERANSTALTUNG_READ,
Permission.PFERD_READ,
Permission.STAMMDATEN_READ
))
}
UserRole.RICHTER -> {
permissions.addAll(listOf(
Permission.PERSON_READ,
Permission.VEREIN_READ,
Permission.VERANSTALTUNG_READ,
Permission.PFERD_READ,
Permission.STAMMDATEN_READ
))
}
UserRole.TIERARZT -> {
permissions.addAll(listOf(
Permission.PERSON_READ,
Permission.PFERD_READ,
Permission.STAMMDATEN_READ
))
}
UserRole.ZUSCHAUER -> {
permissions.addAll(listOf(
Permission.VERANSTALTUNG_READ,
Permission.STAMMDATEN_READ
))
}
UserRole.GAST -> {
permissions.addAll(listOf(
Permission.STAMMDATEN_READ
))
}
}
}
return permissions.toList()
}
/**
* Create a route scoped plugin for role-based authorization
*/
private val RoleAuthorizationPlugin = createRouteScopedPlugin(
name = "RoleAuthorization",
createConfiguration = {
// Define the configuration class for the plugin
class Configuration {
val requiredRoles = mutableListOf<UserRole>()
}
Configuration()
}
) {
// Plugin configuration
val pluginConfig = pluginConfig
onCall { call ->
val principal = call.principal<JWTPrincipal>()
val authContext = principal?.getUserAuthContext()
if (authContext == null) {
call.respond(HttpStatusCode.Unauthorized, "Authentication required")
return@onCall
}
val hasRequiredRole = pluginConfig.requiredRoles.any { requiredRole ->
authContext.roles.contains(requiredRole)
}
if (!hasRequiredRole) {
call.respond(
HttpStatusCode.Forbidden,
"Access denied. Required roles: ${pluginConfig.requiredRoles.joinToString()}"
)
return@onCall
}
}
}
/**
* Route extension function to require specific roles.
*/
fun Route.requireRoles(vararg roles: UserRole, build: Route.() -> Unit): Route {
val route = createChild(object : RouteSelector() {
override suspend fun evaluate(context: RoutingResolveContext, segmentIndex: Int): RouteSelectorEvaluation {
return RouteSelectorEvaluation.Constant
}
override fun toString(): String = "requireRoles(${roles.joinToString()})"
})
// Install the role authorization plugin with the specified roles
route.install(RoleAuthorizationPlugin) {
requiredRoles.addAll(roles)
}
route.build()
return route
}
/**
* Create a route scoped plugin for permission-based authorization
*/
private val PermissionAuthorizationPlugin = createRouteScopedPlugin(
name = "PermissionAuthorization",
createConfiguration = {
// Define the configuration class for the plugin
class Configuration {
val requiredPermissions = mutableListOf<Permission>()
}
Configuration()
}
) {
// Plugin configuration
val pluginConfig = pluginConfig
onCall { call ->
val principal = call.principal<JWTPrincipal>()
val authContext = principal?.getUserAuthContext()
if (authContext == null) {
call.respond(HttpStatusCode.Unauthorized, "Authentication required")
return@onCall
}
val hasAllPermissions = pluginConfig.requiredPermissions.all { requiredPermission ->
authContext.permissions.contains(requiredPermission)
}
if (!hasAllPermissions) {
call.respond(
HttpStatusCode.Forbidden,
"Access denied. Required permissions: ${pluginConfig.requiredPermissions.joinToString()}"
)
return@onCall
}
}
}
/**
* Route extension function to require specific permissions.
*/
fun Route.requirePermissions(vararg permissions: Permission, build: Route.() -> Unit): Route {
val route = createChild(object : RouteSelector() {
override suspend fun evaluate(context: RoutingResolveContext, segmentIndex: Int): RouteSelectorEvaluation {
return RouteSelectorEvaluation.Constant
}
override fun toString(): String = "requirePermissions(${permissions.joinToString()})"
})
// Install the permission authorization plugin with the specified permissions
route.install(PermissionAuthorizationPlugin) {
requiredPermissions.addAll(permissions)
}
route.build()
return route
}
/**
* Pipeline context extension to get current user authorization context.
*/
val PipelineContext<Unit, ApplicationCall>.userAuthContext: UserAuthContext?
get() = call.principal<JWTPrincipal>()?.getUserAuthContext()
/**
* Application call extension to check if the user has a specific role.
*/
fun ApplicationCall.hasRole(role: UserRole): Boolean {
val authContext = principal<JWTPrincipal>()?.getUserAuthContext()
return authContext?.roles?.contains(role) == true
}
/**
* Application call extension to check if the user has specific permission.
*/
fun ApplicationCall.hasPermission(permission: Permission): Boolean {
val authContext = principal<JWTPrincipal>()?.getUserAuthContext()
return authContext?.permissions?.contains(permission) == true
}
@@ -1,275 +0,0 @@
package at.mocode.infrastructure.gateway.config
import io.ktor.server.application.*
import io.ktor.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.logging.Logger
/**
* Cache implementation with local caching and Redis integration preparation.
* This implementation focuses on local caching with proper expiration and statistics.
* Redis integration can be added in a future update.
*/
class CachingConfig(
private val redisHost: String = System.getenv("REDIS_HOST") ?: "localhost",
private val redisPort: Int = System.getenv("REDIS_PORT")?.toIntOrNull() ?: 6379,
private val defaultTtlMinutes: Long = 10
) {
private val logger = Logger.getLogger(CachingConfig::class.java.name)
// Cache entry with expiration time
private data class CacheEntry<T>(
val value: T,
val expiresAt: Long
)
// Cache statistics tracking
private data class CacheStats(
var hits: Long = 0,
var misses: Long = 0,
var puts: Long = 0,
var evictions: Long = 0
)
// Cache maps for different entity types
private val masterDataCache = ConcurrentHashMap<String, CacheEntry<Any>>()
private val userCache = ConcurrentHashMap<String, CacheEntry<Any>>()
private val personCache = ConcurrentHashMap<String, CacheEntry<Any>>()
private val vereinCache = ConcurrentHashMap<String, CacheEntry<Any>>()
private val eventCache = ConcurrentHashMap<String, CacheEntry<Any>>()
// Cache statistics
private val cacheStats = ConcurrentHashMap<String, CacheStats>()
// Scheduler for periodic cleanup and stats reporting
private val scheduler = Executors.newScheduledThreadPool(1) { r ->
val thread = Thread(r, "cache-maintenance-thread")
thread.isDaemon = true
thread
}
init {
// Schedule periodic cleanup of expired entries
scheduler.scheduleAtFixedRate(
{ cleanupExpiredEntries() },
10, 10, TimeUnit.MINUTES
)
// Schedule periodic stats logging
scheduler.scheduleAtFixedRate(
{ logCacheStats() },
5, 30, TimeUnit.MINUTES
)
logger.info("CachingConfig initialized with Redis host: $redisHost, port: $redisPort")
}
/**
* Get a value from cache
*/
@Suppress("UNCHECKED_CAST")
fun <T> get(cacheName: String, key: String): T? {
val stats = cacheStats.computeIfAbsent(cacheName) { CacheStats() }
// Try local cache
val localCache = getCacheMap(cacheName)
val entry = localCache[key]
if (entry != null) {
// Check if entry is expired
if (System.currentTimeMillis() > entry.expiresAt) {
localCache.remove(key)
stats.evictions++
stats.misses++
return null
}
stats.hits++
return entry.value as T
}
stats.misses++
return null
}
/**
* Put a value in a cache with TTL in minutes
*/
fun <T> put(cacheName: String, key: String, value: T, ttlMinutes: Long = defaultTtlMinutes) {
val stats = cacheStats.computeIfAbsent(cacheName) { CacheStats() }
stats.puts++
// Store in a local cache
val expiresAt = System.currentTimeMillis() + TimeUnit.MINUTES.toMillis(ttlMinutes)
val entry = CacheEntry(value as Any, expiresAt)
getCacheMap(cacheName)[key] = entry
}
/**
* Remove a value from the cache
*/
fun remove(cacheName: String, key: String) {
// Remove from the local cache
getCacheMap(cacheName).remove(key)
}
/**
* Clear a specific cache
*/
fun clearCache(cacheName: String) {
// Clear local cache
getCacheMap(cacheName).clear()
}
/**
* Clear all caches
*/
fun clearAllCaches() {
// Clear all local caches
masterDataCache.clear()
userCache.clear()
personCache.clear()
vereinCache.clear()
eventCache.clear()
}
/**
* Get the appropriate cache map based on the cache name
*/
private fun getCacheMap(cacheName: String): ConcurrentHashMap<String, CacheEntry<Any>> {
return when (cacheName) {
MASTER_DATA_CACHE -> masterDataCache
USER_CACHE -> userCache
PERSON_CACHE -> personCache
VEREIN_CACHE -> vereinCache
EVENT_CACHE -> eventCache
else -> throw IllegalArgumentException("Unknown cache name: $cacheName")
}
}
/**
* Clean up expired entries from local caches
*/
private fun cleanupExpiredEntries() {
val now = System.currentTimeMillis()
var totalRemoved = 0
// Clean up each cache
listOf(masterDataCache, userCache, personCache, vereinCache, eventCache).forEach { cache ->
val iterator = cache.entries.iterator()
var removed = 0
while (iterator.hasNext()) {
val entry = iterator.next()
if (now > entry.value.expiresAt) {
iterator.remove()
removed++
}
}
totalRemoved += removed
}
if (totalRemoved > 0) {
logger.info("Cache cleanup completed: removed $totalRemoved expired entries")
}
}
/**
* Log cache statistics
*/
private fun logCacheStats() {
cacheStats.forEach { (cacheName, stats) ->
val hitRatio = if (stats.hits + stats.misses > 0) {
stats.hits.toDouble() / (stats.hits + stats.misses)
} else {
0.0
}
logger.info("Cache stats for $cacheName: hits=${stats.hits}, misses=${stats.misses}, " +
"puts=${stats.puts}, evictions=${stats.evictions}, hit-ratio=${String.format("%.2f", hitRatio * 100)}%")
}
}
/**
* Shutdown the cache manager and release resources
*/
fun shutdown() {
scheduler.shutdown()
try {
if (!scheduler.awaitTermination(5, TimeUnit.SECONDS)) {
scheduler.shutdownNow()
}
} catch (e: InterruptedException) {
scheduler.shutdownNow()
}
logger.info("CachingConfig shutdown completed")
}
companion object {
// Cache names for different entities
const val MASTER_DATA_CACHE = "masterDataCache"
const val USER_CACHE = "userCache"
const val PERSON_CACHE = "personCache"
const val VEREIN_CACHE = "vereinCache"
const val EVENT_CACHE = "eventCache"
// List of all cache names
val CACHE_NAMES = listOf(
MASTER_DATA_CACHE,
USER_CACHE,
PERSON_CACHE,
VEREIN_CACHE,
EVENT_CACHE
)
// Default TTLs in minutes
const val MASTER_DATA_TTL = 24 * 60L // 24 hours
const val USER_TTL = 2 * 60L // 2 hours
const val PERSON_TTL = 4 * 60L // 4 hours
const val VEREIN_TTL = 12 * 60L // 12 hours
const val EVENT_TTL = 6 * 60L // 6 hours
// AttributeKey for storing in application
val CACHING_CONFIG_KEY = AttributeKey<CachingConfig>("CachingConfig")
}
}
/**
* Extension function to install caching in the application.
*/
fun Application.configureCaching() {
val redisHost = environment.config.propertyOrNull("redis.host")?.getString()
?: System.getenv("REDIS_HOST")
?: "localhost"
val redisPort = environment.config.propertyOrNull("redis.port")?.getString()?.toIntOrNull()
?: System.getenv("REDIS_PORT")?.toIntOrNull()
?: 6379
val cachingConfig = CachingConfig(
redisHost = redisHost,
redisPort = redisPort
)
// Store the caching config in the application attributes
attributes.put(CachingConfig.CACHING_CONFIG_KEY, cachingConfig)
// Register shutdown hook
this.monitor.subscribe(ApplicationStopping) {
cachingConfig.shutdown()
}
// Log cache configuration
log.info("Cache configuration initialized: Redis host=$redisHost, port=$redisPort")
}
/**
* Extension function to get the caching config from the application.
*/
fun Application.getCachingConfig(): CachingConfig {
return attributes[CachingConfig.CACHING_CONFIG_KEY]
}
@@ -1,165 +0,0 @@
package at.mocode.infrastructure.gateway.config
import io.ktor.server.application.*
import io.ktor.server.request.*
import io.ktor.server.routing.*
import io.ktor.util.*
import io.micrometer.core.instrument.Counter
import io.micrometer.core.instrument.Timer
import io.micrometer.prometheus.PrometheusMeterRegistry
import java.util.concurrent.ConcurrentHashMap
/**
* Custom application metrics configuration.
*
* Adds application-specific metrics for better monitoring:
* - API endpoint response times
* - Request counts by endpoint and status code
* - Error rates
* - Database query metrics
*/
// Reference to the Prometheus registry from PrometheusConfig
private val appRegistry: PrometheusMeterRegistry
get() = at.mocode.infrastructure.gateway.config.appMicrometerRegistry
// Attribute key for request start time
private val REQUEST_TIMER_ATTRIBUTE = AttributeKey<Timer.Sample>("RequestTimerSample")
// Cache for endpoint timers to avoid creating new ones for each request
private val endpointTimers = ConcurrentHashMap<String, Timer>()
// Cache for endpoint counters
private val endpointCounters = ConcurrentHashMap<Pair<String, Int>, Counter>()
// Cache for error counters
private val errorCounters = ConcurrentHashMap<String, Counter>()
/**
* Configures custom application metrics.
*/
fun Application.configureCustomMetrics() {
// Install a hook to intercept all requests for timing
intercept(ApplicationCallPipeline.Monitoring) {
// Start timing the request
val timerSample = Timer.start(appRegistry)
call.attributes.put(REQUEST_TIMER_ATTRIBUTE, timerSample)
}
// Install a hook to record metrics after the request is processed
intercept(ApplicationCallPipeline.Fallback) {
val status = call.response.status()?.value ?: 0
val method = call.request.httpMethod.value
val route = extractRoutePattern(call)
// Record request count
getOrCreateRequestCounter(method, route, status).increment()
// Record timing
call.attributes.getOrNull(REQUEST_TIMER_ATTRIBUTE)?.let { timerSample ->
val timer = getOrCreateEndpointTimer(method, route)
timerSample.stop(timer)
}
// Record errors
if (status >= 400) {
getOrCreateErrorCounter(method, route, status).increment()
}
}
// Register database metrics
registerDatabaseMetrics()
log.info("Custom application metrics configured")
}
/**
* Extracts a normalized route pattern from the call.
* Converts dynamic path segments to a generic pattern.
* For example: /api/users/123 -> /api/users/{id}
*/
private fun extractRoutePattern(call: ApplicationCall): String {
val path = call.request.path()
// Try to get the route from the call attributes if available
call.attributes.getOrNull(AttributeKey<Route>("ktor.request.route"))?.let { route ->
return route.toString()
}
// Otherwise, normalize the path by replacing likely IDs with {id}
val segments = path.split("/")
val normalizedSegments = segments.map { segment ->
// If segment looks like an ID (UUID, number), replace with {id}
if (segment.matches(Regex("[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}")) ||
segment.matches(Regex("\\d+"))
) {
"{id}"
} else {
segment
}
}
return normalizedSegments.joinToString("/")
}
/**
* Gets or creates a timer for the specified endpoint.
*/
private fun getOrCreateEndpointTimer(method: String, route: String): Timer {
val key = "$method $route"
return endpointTimers.computeIfAbsent(key) {
Timer.builder("http.server.requests")
.tag("method", method)
.tag("route", route)
.publishPercentileHistogram()
.register(appRegistry)
}
}
/**
* Gets or creates a counter for the specified endpoint and status.
*/
private fun getOrCreateRequestCounter(method: String, route: String, status: Int): Counter {
val key = Pair("$method $route", status)
return endpointCounters.computeIfAbsent(key) {
Counter.builder("http.server.requests.count")
.tag("method", method)
.tag("route", route)
.tag("status", status.toString())
.register(appRegistry)
}
}
/**
* Gets or creates an error counter for the specified endpoint and status.
*/
private fun getOrCreateErrorCounter(method: String, route: String, status: Int): Counter {
val key = "$method $route $status"
return errorCounters.computeIfAbsent(key) {
Counter.builder("http.server.errors")
.tag("method", method)
.tag("route", route)
.tag("status", status.toString())
.register(appRegistry)
}
}
/**
* Registers database metrics.
*/
private fun registerDatabaseMetrics() {
// Create a gauge for active connections
appRegistry.gauge("db.connections.active",
at.mocode.core.utils.database.DatabaseFactory,
{ it.getActiveConnections().toDouble() })
// Create a gauge for idle connections
appRegistry.gauge("db.connections.idle",
at.mocode.core.utils.database.DatabaseFactory,
{ it.getIdleConnections().toDouble() })
// Create a gauge for total connections
appRegistry.gauge("db.connections.total",
at.mocode.core.utils.database.DatabaseFactory,
{ it.getTotalConnections().toDouble() })
}
@@ -1,11 +0,0 @@
package at.mocode.infrastructure.gateway.config
/**
* Database configuration for the API Gateway.
*
* The gateway uses DatabaseFactory.init() in Application.kt for proper connection pooling.
* Schema initialization is handled by individual services in their @PostConstruct methods
* to prevent race conditions and maintain proper separation of concerns.
*
* This file is kept for potential future gateway-specific database utilities.
*/
@@ -1,164 +0,0 @@
package at.mocode.infrastructure.gateway.config
import at.mocode.core.utils.config.AppConfig
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.request.*
import io.ktor.util.*
import org.slf4j.LoggerFactory
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import kotlin.random.Random
/**
* Configuration for log sampling in the API Gateway.
*
* This configuration adds support for:
* - Sampling logs for high-traffic endpoints to reduce log volume
* - Configurable sampling rate and thresholds
* - Always logging errors and specific paths regardless of sampling
* - Periodic reset of request counters
*/
// Logger for log sampling
private val logger = LoggerFactory.getLogger("LogSampling")
// Map to track request counts by path for log sampling
private val requestCountsByPath = ConcurrentHashMap<String, AtomicInteger>()
// Map to track high-traffic paths that are being sampled
private val sampledPaths = ConcurrentHashMap<String, Boolean>()
// Attribute key for storing whether a request should be logged
val SHOULD_LOG_REQUEST_KEY = AttributeKey<Boolean>("ShouldLogRequest")
// Scheduler to reset request counts periodically
private val requestCountResetScheduler = Executors.newSingleThreadScheduledExecutor().apply {
scheduleAtFixedRate({
try {
// Reset all counters every minute
requestCountsByPath.clear()
// Log which paths are being sampled
if (sampledPaths.isNotEmpty()) {
val sampledPathsList = sampledPaths.keys.joinToString(", ")
logger.info("Currently sampling high-traffic paths: $sampledPathsList")
}
// Clear sampled paths to re-evaluate in the next period
sampledPaths.clear()
} catch (e: Exception) {
logger.error("Error in request count reset scheduler", e)
}
}, 1, 1, TimeUnit.MINUTES)
}
/**
* Configures log sampling for the API Gateway.
*/
fun Application.configureLogSampling() {
val loggingConfig = AppConfig.logging
// Log configuration information
if (loggingConfig.enableLogSampling) {
log.info("Log sampling ENABLED with rate: ${loggingConfig.samplingRate}%")
log.info("High traffic threshold: ${loggingConfig.highTrafficThreshold} requests per minute")
log.info("Always log paths: ${loggingConfig.alwaysLogPaths.joinToString(", ")}")
log.info("Always log errors: ${loggingConfig.alwaysLogErrors}")
} else {
log.info("Log sampling DISABLED")
return
}
// Install interceptor to apply log sampling logic
intercept(ApplicationCallPipeline.Monitoring) {
val path = call.request.path()
// Determine if this request should be logged
val shouldLog = shouldLogRequest(path, null, loggingConfig)
// Store the decision in call attributes for later use
call.attributes.put(SHOULD_LOG_REQUEST_KEY, shouldLog)
// Continue processing the request
proceed()
// Update the decision based on the response status (for error logging)
if (!shouldLog && loggingConfig.alwaysLogErrors) {
val status = call.response.status()
if (status != null && status.value >= 400) {
call.attributes.put(SHOULD_LOG_REQUEST_KEY, true)
}
}
}
// Instead of trying to modify CallLogging after installation,
// we'll use the interceptor to decide if logging should happen
// The CallLogging plugin will be configured in MonitoringConfig.kt
}
/**
* Determines if a request should be logged based on sampling configuration.
*
* @param path The request path
* @param statusCode The response status code (null for request phase)
* @param loggingConfig The logging configuration
* @return True if the request should be logged, false otherwise
*/
private fun shouldLogRequest(path: String, statusCode: HttpStatusCode?, loggingConfig: at.mocode.core.utils.config.LoggingConfig): Boolean {
// If sampling is disabled, always log
if (!loggingConfig.enableLogSampling) {
return true
}
// Always log errors if configured
if (loggingConfig.alwaysLogErrors && statusCode != null && statusCode.value >= 400) {
return true
}
// Always log specific paths if configured
val normalizedPath = path.trimStart('/')
if (loggingConfig.alwaysLogPaths.any { normalizedPath.startsWith(it.trimStart('/')) }) {
return true
}
// Get or create counter for this path
val basePath = extractBasePath(path)
val counter = requestCountsByPath.computeIfAbsent(basePath) { AtomicInteger(0) }
val count = counter.incrementAndGet()
// Check if this is a high-traffic path
if (count >= loggingConfig.highTrafficThreshold) {
// Mark this path as being sampled
sampledPaths[basePath] = true
// Sample based on configured rate
return Random.nextInt(100) < loggingConfig.samplingRate
}
// Not a high-traffic path, log normally
return true
}
/**
* Extracts the base path from a full path for grouping similar requests.
* For example, "/api/v1/users/123" becomes "/api/v1/users"
*/
private fun extractBasePath(path: String): String {
val parts = path.split("/").filter { it.isNotEmpty() }
// Handle special cases
if (parts.isEmpty()) return "/"
// For API paths, include up to the resource name (typically 3 parts: api, version, resource)
if (parts[0] == "api") {
val depth = minOf(3, parts.size)
return "/" + parts.take(depth).joinToString("/")
}
// For other paths, include up to 2 parts
val depth = minOf(2, parts.size)
return "/" + parts.take(depth).joinToString("/")
}
@@ -1,32 +0,0 @@
package at.mocode.infrastructure.gateway.config
import at.mocode.infrastructure.gateway.migrations.*
import at.mocode.core.utils.database.DatabaseMigrator
/**
* Konfiguriert und führt alle Datenbankmigrationen aus.
*/
object MigrationSetup {
/**
* Registriert alle Migrationen und führt sie aus.
*/
fun runMigrations() {
// Migrationen registrieren
DatabaseMigrator.registerAll(
// Master Data Migrationen
MasterDataTablesCreation(),
// Member Management Migrationen
MemberManagementTablesCreation(),
// Horse Registry Migrationen
HorseRegistryTablesCreation(),
// Event Management Migrationen
EventManagementTablesCreation()
)
// Migrationen ausführen
DatabaseMigrator.migrate()
}
}
@@ -1,428 +0,0 @@
package at.mocode.infrastructure.gateway.config
import at.mocode.core.utils.config.AppConfig
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.plugins.calllogging.*
import io.ktor.server.plugins.statuspages.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import org.slf4j.event.Level
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import kotlin.random.Random
import kotlinx.serialization.Serializable
/**
* Simple error response for status page handlers
*/
@Serializable
data class StatusPageErrorResponse(
val error: String,
val code: String,
val path: String? = null,
val requestId: String? = null
)
/**
* Monitoring and logging configuration for the API Gateway.
*
* Configures request logging, error handling, and status pages.
* Works together with RequestTracingConfig for cross-service tracing.
* Includes log sampling for high-traffic endpoints to reduce log volume.
*/
// Map to track request counts by path for log sampling
// Using a more efficient ConcurrentHashMap with initial capacity and load factor
private val requestCountsByPath = ConcurrentHashMap<String, AtomicInteger>(32, 0.75f)
// Map to track high-traffic paths that are being sampled
private val sampledPaths = ConcurrentHashMap<String, Boolean>(16, 0.75f)
// Scheduler to reset request counts periodically
private val requestCountResetScheduler = Executors.newSingleThreadScheduledExecutor { r ->
val thread = Thread(r, "log-sampling-reset-thread")
thread.isDaemon = true // Make it a daemon thread so it doesn't prevent JVM shutdown
thread
}
// Schedule the task with proper lifecycle management
private fun scheduleRequestCountReset() {
// Reset counters every 5 minutes instead of every minute to reduce overhead
requestCountResetScheduler.scheduleAtFixedRate({
try {
// Reset all counters
requestCountsByPath.clear()
// Log which paths are being sampled (only if there are any)
if (sampledPaths.isNotEmpty()) {
// More efficient string building for logging
val sampledPathsCount = sampledPaths.size
if (sampledPathsCount <= 5) {
// For a small number of paths, log them all
val sampledPathsList = sampledPaths.keys.joinToString(", ")
println("[LogSampling] Currently sampling $sampledPathsCount high-traffic paths: $sampledPathsList")
} else {
// For many paths, just log the count to avoid excessive logging
println("[LogSampling] Currently sampling $sampledPathsCount high-traffic paths")
}
}
// Clear sampled paths to re-evaluate in the next period
sampledPaths.clear()
} catch (e: Exception) {
// Catch any exceptions to prevent the scheduler from stopping
println("[LogSampling] Error in reset task: ${e.message}")
}
}, 5, 5, TimeUnit.MINUTES)
}
// Shutdown hook to clean up resources
private fun shutdownRequestCountResetScheduler() {
requestCountResetScheduler.shutdown()
try {
if (!requestCountResetScheduler.awaitTermination(5, TimeUnit.SECONDS)) {
requestCountResetScheduler.shutdownNow()
}
} catch (e: InterruptedException) {
requestCountResetScheduler.shutdownNow()
Thread.currentThread().interrupt()
}
}
/**
* Determines if a request should be logged based on sampling configuration.
* Optimized for performance with early returns and cached path normalization.
*
* @param path The request path
* @param statusCode The response status code
* @param loggingConfig The logging configuration
* @return True if the request should be logged, false otherwise
*/
private fun shouldLogRequest(path: String, statusCode: HttpStatusCode?, loggingConfig: at.mocode.core.utils.config.LoggingConfig): Boolean {
// Fast path: If sampling is disabled, always log
if (!loggingConfig.enableLogSampling) {
return true
}
// Fast path: Always log errors if configured
if (statusCode != null && statusCode.value >= 400 && loggingConfig.alwaysLogErrors) {
return true
}
// Check if this is a path that should always be logged
// Only normalize the path if we have paths to check against
if (loggingConfig.alwaysLogPaths.isNotEmpty()) {
val normalizedPath = path.trimStart('/')
// Use any with early return for better performance
for (alwaysLogPath in loggingConfig.alwaysLogPaths) {
if (normalizedPath.startsWith(alwaysLogPath.trimStart('/'))) {
return true
}
}
}
// Get the base path for traffic counting
val basePath = extractBasePath(path)
// Check if this path is already known to be high-traffic
if (sampledPaths.containsKey(basePath)) {
// Already identified as high-traffic, apply sampling
return Random.nextInt(100) < loggingConfig.samplingRate
}
// Get or create counter for this path
val counter = requestCountsByPath.computeIfAbsent(basePath) { AtomicInteger(0) }
val count = counter.incrementAndGet()
// Check if this is a high-traffic path
if (count >= loggingConfig.highTrafficThreshold) {
// Mark this path as being sampled
sampledPaths[basePath] = true
// Sample based on configured rate
return Random.nextInt(100) < loggingConfig.samplingRate
}
// Not a high-traffic path, log normally
return true
}
/**
* Extracts the base path from a full path for grouping similar requests.
* For example, "/api/v1/users/123" becomes "/api/v1/users"
*/
private fun extractBasePath(path: String): String {
val parts = path.split("/").filter { it.isNotEmpty() }
// Handle special cases
if (parts.isEmpty()) return "/"
// For API paths, include up to the resource name (typically 3 parts: api, version, resource)
if (parts[0] == "api") {
val depth = minOf(3, parts.size)
return "/" + parts.take(depth).joinToString("/")
}
// For other paths, include up to 2 parts
val depth = minOf(2, parts.size)
return "/" + parts.take(depth).joinToString("/")
}
fun Application.configureMonitoring() {
val loggingConfig = AppConfig.logging
// Note: Prometheus metrics configuration has been moved to PrometheusConfig.kt
// Start the request count reset scheduler (skip in test environment)
val isTestEnvironment = System.getProperty("kotlinx.coroutines.test") != null ||
Thread.currentThread().stackTrace.any { it.className.contains("test", ignoreCase = true) }
if (!isTestEnvironment) {
scheduleRequestCountReset()
}
// Register shutdown hook for application lifecycle
this.monitor.subscribe(ApplicationStopPreparing) {
log.info("Application stopping, shutting down schedulers...")
shutdownRequestCountResetScheduler()
}
// Erweiterte Call-Logging-Konfiguration
install(CallLogging) {
level = when (loggingConfig.level.uppercase()) {
"DEBUG" -> Level.DEBUG
"TRACE" -> Level.TRACE
"WARN" -> Level.WARN
"ERROR" -> Level.ERROR
else -> Level.INFO
}
// Filtere Pfade, die vom Logging ausgeschlossen werden sollen
filter { call: ApplicationCall ->
val path = call.request.path()
!loggingConfig.excludePaths.any { path.startsWith(it) }
}
// Formatiere Log-Einträge mit erweitertem Format
format { call: ApplicationCall ->
val status = call.response.status()
val httpMethod = call.request.httpMethod.value
val path = call.request.path()
val userAgent = call.request.headers["User-Agent"]
val clientIp = call.request.local.remoteHost
val timestamp = LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)
// Get the request ID from the call attributes (set by RequestTracingConfig)
val requestId: String = call.attributes.getOrNull(REQUEST_ID_KEY) ?: "no-request-id"
if (loggingConfig.useStructuredLogging) {
// Optimized structured logging format using StringBuilder with initial capacity
// Estimate the initial capacity based on typical log entry size
val initialCapacity = 256 +
(if (loggingConfig.logRequestHeaders) 128 else 0) +
(if (loggingConfig.logRequestParameters) 128 else 0)
val sb = StringBuilder(initialCapacity)
// Basic request information - always included
sb.append("timestamp=").append(timestamp).append(' ')
.append("method=").append(httpMethod).append(' ')
.append("path=").append(path).append(' ')
.append("status=").append(status).append(' ')
.append("client=").append(clientIp).append(' ')
.append("requestId=").append(requestId).append(' ')
// Log Headers wenn konfiguriert
if (loggingConfig.logRequestHeaders) {
val authHeader = call.request.headers["Authorization"]
if (authHeader != null) {
sb.append("auth=true ")
}
val contentType = call.request.headers["Content-Type"]
if (contentType != null) {
sb.append("contentType=").append(contentType).append(' ')
}
// Log all headers if in debug mode, filtering sensitive data
if (loggingConfig.level.equals("DEBUG", ignoreCase = true)) {
sb.append("headers={")
var first = true
for (entry in call.request.headers.entries()) {
if (!first) sb.append(", ")
first = false
if (isSensitiveHeader(entry.key)) {
sb.append(entry.key).append("=*****")
} else {
sb.append(entry.key).append('=').append(entry.value.joinToString(","))
}
}
sb.append("} ")
}
}
// Log Query-Parameter wenn konfiguriert
if (loggingConfig.logRequestParameters && call.request.queryParameters.entries().isNotEmpty()) {
sb.append("params={")
var first = true
for (entry in call.request.queryParameters.entries()) {
if (!first) sb.append(", ")
first = false
if (isSensitiveParameter(entry.key)) {
sb.append(entry.key).append("=*****")
} else {
sb.append(entry.key).append('=').append(entry.value.joinToString(","))
}
}
sb.append("} ")
}
if (userAgent != null) {
// Use a simpler approach to avoid escape sequence issues
val escapedUserAgent = userAgent.replace("\"", "\\\"")
sb.append("userAgent=\"").append(escapedUserAgent).append("\" ")
}
// Log response time if available from RequestTracingConfig
call.attributes.getOrNull(REQUEST_START_TIME_KEY)?.let { startTime: Long ->
val duration = System.currentTimeMillis() - startTime
sb.append("duration=").append(duration).append("ms ")
}
// Add performance metrics - only calculate memory usage if needed
// Only include memory metrics in every 10th log entry to reduce overhead
if (Random.nextInt(10) == 0) {
val memoryUsage = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()
sb.append("memoryUsage=").append(memoryUsage).append("b ")
// Add additional performance metrics in debug mode
if (loggingConfig.level.equals("DEBUG", ignoreCase = true)) {
val availableProcessors = Runtime.getRuntime().availableProcessors()
val maxMemory = Runtime.getRuntime().maxMemory()
sb.append("processors=").append(availableProcessors).append(' ')
.append("maxMemory=").append(maxMemory).append("b ")
}
}
sb.toString()
} else {
// Einfaches Logging-Format
val duration = call.attributes.getOrNull(REQUEST_START_TIME_KEY)?.let { startTime: Long ->
" - Duration: ${System.currentTimeMillis() - startTime}ms"
} ?: ""
"$timestamp - $status: $httpMethod $path - RequestID: $requestId - $clientIp - $userAgent$duration"
}
}
}
// Erweiterte Logging-Konfiguration für den API-Gateway
log.info("API Gateway konfiguriert mit erweitertem Logging und Cross-Service Tracing")
log.info("Logging-Konfiguration: level=${loggingConfig.level}, " +
"logRequests=${loggingConfig.logRequests}, " +
"logResponses=${loggingConfig.logResponses}, " +
"logRequestHeaders=${loggingConfig.logRequestHeaders}, " +
"logRequestParameters=${loggingConfig.logRequestParameters}, " +
"requestIdHeader=${loggingConfig.requestIdHeader}, " +
"propagateRequestId=${loggingConfig.propagateRequestId}")
install(StatusPages) {
exception<Throwable> { call: ApplicationCall, cause: Throwable ->
// Get the request ID for error logging
val requestId: String = call.attributes.getOrNull(REQUEST_ID_KEY) ?: "no-request-id"
call.application.log.error("Unhandled exception - RequestID: $requestId", cause)
val errorResponse = StatusPageErrorResponse(
error = "Internal server error: ${cause.message}",
code = "INTERNAL_SERVER_ERROR",
path = call.request.path(),
requestId = requestId
)
call.respond(HttpStatusCode.InternalServerError, errorResponse)
}
status(HttpStatusCode.NotFound) { call: ApplicationCall, status: HttpStatusCode ->
// Get the request ID for error logging
val requestId: String = call.attributes.getOrNull(REQUEST_ID_KEY) ?: "no-request-id"
call.application.log.warn("Not found - Path: ${call.request.path()} - RequestID: $requestId")
val errorResponse = StatusPageErrorResponse(
error = "Endpoint not found: ${call.request.path()}",
code = "NOT_FOUND",
path = call.request.path(),
requestId = requestId
)
call.respond(status, errorResponse)
}
status(HttpStatusCode.Unauthorized) { call: ApplicationCall, status: HttpStatusCode ->
// Get the request ID for error logging
val requestId: String = call.attributes.getOrNull(REQUEST_ID_KEY) ?: "no-request-id"
call.application.log.warn("Unauthorized access - Path: ${call.request.path()} - RequestID: $requestId")
val errorResponse = StatusPageErrorResponse(
error = "Authentication required",
code = "UNAUTHORIZED",
path = call.request.path(),
requestId = requestId
)
call.respond(status, errorResponse)
}
status(HttpStatusCode.Forbidden) { call: ApplicationCall, status: HttpStatusCode ->
// Get the request ID for error logging
val requestId: String = call.attributes.getOrNull(REQUEST_ID_KEY) ?: "no-request-id"
call.application.log.warn("Forbidden access - Path: ${call.request.path()} - RequestID: $requestId")
val errorResponse = StatusPageErrorResponse(
error = "Access forbidden",
code = "FORBIDDEN",
path = call.request.path(),
requestId = requestId
)
call.respond(status, errorResponse)
}
// Rate limit exceeded
status(HttpStatusCode.TooManyRequests) { call: ApplicationCall, status: HttpStatusCode ->
// Get the request ID for error logging
val requestId: String = call.attributes.getOrNull(REQUEST_ID_KEY) ?: "no-request-id"
call.application.log.warn("Rate limit exceeded - Path: ${call.request.path()} - RequestID: $requestId")
val errorResponse = StatusPageErrorResponse(
error = "Rate limit exceeded. Please try again later.",
code = "TOO_MANY_REQUESTS",
path = call.request.path(),
requestId = requestId
)
call.respond(status, errorResponse)
}
}
}
/**
* Determines if a header is sensitive and should be masked in logs.
*/
private fun isSensitiveHeader(headerName: String): Boolean {
val sensitiveHeaders = listOf(
"authorization", "cookie", "set-cookie", "x-api-key", "api-key",
"password", "token", "secret", "credential", "apikey"
)
return sensitiveHeaders.any { headerName.lowercase().contains(it) }
}
/**
* Determines if a parameter is sensitive and should be masked in logs.
*/
private fun isSensitiveParameter(paramName: String): Boolean {
val sensitiveParams = listOf(
"password", "token", "secret", "credential", "apikey", "key",
"auth", "pin", "code", "otp", "cvv", "ssn", "credit"
)
return sensitiveParams.any { paramName.lowercase().contains(it) }
}
@@ -1,38 +0,0 @@
package at.mocode.infrastructure.gateway.config
import io.ktor.server.application.*
import io.ktor.server.plugins.openapi.*
import io.ktor.server.plugins.swagger.*
import io.ktor.server.routing.*
/**
* Configuration for OpenAPI/Swagger documentation.
*
* This module configures the OpenAPI specification generation and Swagger UI
* for the API Gateway, providing comprehensive API documentation.
*
* The OpenAPI specification is loaded from a static YAML file located at:
* resources/openapi/documentation.yaml
*/
fun Application.configureOpenApi() {
// Configure OpenAPI endpoint using the static YAML file
routing {
// Serve the OpenAPI specification from a file
openAPI(path = "openapi", swaggerFile = "openapi/documentation.yaml") {
// Additional configuration can be added here if needed
}
}
}
/**
* Configuration for Swagger UI.
*
* Provides an interactive web interface for exploring and testing the API.
*/
fun Application.configureSwagger() {
routing {
swaggerUI(path = "swagger", swaggerFile = "openapi/documentation.yaml") {
version = "4.15.5"
}
}
}
@@ -1,53 +0,0 @@
package at.mocode.infrastructure.gateway.config
import io.ktor.server.application.*
import io.ktor.server.metrics.micrometer.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.auth.*
import io.micrometer.core.instrument.binder.jvm.ClassLoaderMetrics
import io.micrometer.core.instrument.binder.jvm.JvmGcMetrics
import io.micrometer.core.instrument.binder.jvm.JvmMemoryMetrics
import io.micrometer.core.instrument.binder.jvm.JvmThreadMetrics
import io.micrometer.core.instrument.binder.system.ProcessorMetrics
import io.micrometer.prometheus.PrometheusConfig
import io.micrometer.prometheus.PrometheusMeterRegistry
/**
* Prometheus metrics configuration for the API Gateway.
*
* Configures Micrometer with Prometheus registry and exposes a metrics endpoint.
*/
// Create a Prometheus registry
val appMicrometerRegistry = PrometheusMeterRegistry(PrometheusConfig.DEFAULT)
/**
* Configures Prometheus metrics for the application.
*/
fun Application.configurePrometheusMetrics() {
// Install Micrometer metrics
install(MicrometerMetrics) {
registry = appMicrometerRegistry
// JVM metrics
meterBinders = listOf(
JvmMemoryMetrics(),
JvmGcMetrics(),
JvmThreadMetrics(),
ClassLoaderMetrics(),
ProcessorMetrics()
)
}
// Add a route to expose Prometheus metrics with basic authentication
routing {
// Secure metrics endpoint with basic authentication
authenticate("metrics-auth") {
get("/metrics") {
call.respond(appMicrometerRegistry.scrape())
}
}
}
log.info("Prometheus metrics configured and secured at /metrics endpoint")
}
@@ -1,517 +0,0 @@
package at.mocode.infrastructure.gateway.config
import at.mocode.core.utils.config.AppConfig
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.plugins.ratelimit.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import kotlin.time.Duration.Companion.minutes
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.TimeUnit
import java.nio.charset.StandardCharsets
import java.lang.management.ManagementFactory
import java.util.concurrent.atomic.AtomicInteger
/**
* Configuration for advanced rate limiting in the API Gateway.
*
* This configuration adds support for:
* - Global rate limiting
* - Endpoint-specific rate limiting
* - User-type-specific rate limiting
* - Rate limit headers in responses
* - Token parsing caching for improved performance
* - Adaptive rate limiting based on server load
*/
// Cache for parsed JWT tokens to avoid repeated decoding
// Key: Token hash, Value: Parsed token data (userId to userType mapping)
private val tokenCache = ConcurrentHashMap<Int, Pair<String, String>>()
// Cache expiration settings
private const val TOKEN_CACHE_MAX_SIZE = 10000 // Maximum number of tokens to cache
private const val TOKEN_CACHE_EXPIRATION_MINUTES = 60L // Cache expiration time in minutes
// Schedule cache cleanup to prevent memory leaks
private val cacheCleanupScheduler = java.util.Timer("token-cache-cleanup").apply {
schedule(object : java.util.TimerTask() {
override fun run() {
if (tokenCache.size > TOKEN_CACHE_MAX_SIZE) {
// If the cache exceeds max size, remove the oldest entries (simple approach)
val keysToRemove = tokenCache.keys.take(tokenCache.size - TOKEN_CACHE_MAX_SIZE / 2)
keysToRemove.forEach { tokenCache.remove(it) }
}
}
}, TimeUnit.MINUTES.toMillis(10), TimeUnit.MINUTES.toMillis(10))
}
/**
* Adaptive rate limiting configuration.
* These settings control how rate limits are adjusted based on server load.
*/
private object AdaptiveRateLimiting {
// Enable/disable adaptive rate limiting
const val ENABLED = true
// Thresholds for CPU usage (percentage)
const val CPU_MEDIUM_LOAD_THRESHOLD = 60.0 // Medium load threshold (60%)
const val CPU_HIGH_LOAD_THRESHOLD = 80.0 // High-load threshold (80%)
// Thresholds for memory usage (percentage)
const val MEMORY_MEDIUM_LOAD_THRESHOLD = 70.0 // Medium load threshold (70%)
const val MEMORY_HIGH_LOAD_THRESHOLD = 85.0 // High-load threshold (85%)
// Rate limit adjustment factors
const val MEDIUM_LOAD_FACTOR = 0.7 // Reduce limits to 70% under a medium load
const val HIGH_LOAD_FACTOR = 0.4 // Reduce limits to 40% under a high load
// Monitoring interval in milliseconds
const val MONITORING_INTERVAL_MS = 5000L // Check every 5 seconds
// Current load factor (starts at 1.0 = 100%)
val currentLoadFactor = AtomicInteger(100)
// Get the current load factor as a double (0.0-1.0)
fun getCurrentLoadFactor(): Double = currentLoadFactor.get() / 100.0
// Initialize the load monitoring
init {
if (ENABLED) {
startLoadMonitoring()
}
}
/**
* Start monitoring server load and adjusting the rate limit factor.
*/
private fun startLoadMonitoring() {
val timer = java.util.Timer("adaptive-rate-limit-monitor", true)
val operatingSystemMXBean = ManagementFactory.getOperatingSystemMXBean()
val runtime = Runtime.getRuntime()
timer.schedule(object : java.util.TimerTask() {
override fun run() {
try {
// Get CPU load (if available)
val cpuLoad = if (operatingSystemMXBean is com.sun.management.OperatingSystemMXBean) {
operatingSystemMXBean.processCpuLoad * 100
} else {
// Fallback if the specific implementation is not available
operatingSystemMXBean.systemLoadAverage.takeIf { it >= 0 }?.let {
it * 100 / runtime.availableProcessors()
} ?: 50.0 // Default to 50% if not available
}
// Get memory usage
val maxMemory = runtime.maxMemory().toDouble()
val usedMemory = (runtime.totalMemory() - runtime.freeMemory()).toDouble()
val memoryUsage = (usedMemory / maxMemory) * 100
// Determine load factor based on CPU and memory usage
val newLoadFactor = when {
cpuLoad > CPU_HIGH_LOAD_THRESHOLD || memoryUsage > MEMORY_HIGH_LOAD_THRESHOLD ->
(HIGH_LOAD_FACTOR * 100).toInt()
cpuLoad > CPU_MEDIUM_LOAD_THRESHOLD || memoryUsage > MEMORY_MEDIUM_LOAD_THRESHOLD ->
(MEDIUM_LOAD_FACTOR * 100).toInt()
else -> 100 // Normal load = 100%
}
// Update the load factor if it changed
val oldLoadFactor = currentLoadFactor.getAndSet(newLoadFactor)
if (oldLoadFactor != newLoadFactor) {
println("[AdaptiveRateLimiting] Load factor changed: ${oldLoadFactor/100.0} -> ${newLoadFactor/100.0} " +
"(CPU: ${String.format("%.1f", cpuLoad)}%, Memory: ${String.format("%.1f", memoryUsage)}%)")
}
} catch (e: Exception) {
// If any error occurs, reset to a normal load
currentLoadFactor.set(100)
println("[AdaptiveRateLimiting] Error monitoring system load: ${e.message}")
}
}
}, 0, MONITORING_INTERVAL_MS)
}
/**
* Adjust a rate limit based on the current server load.
*/
fun adjustRateLimit(baseLimit: Int): Int {
if (!ENABLED) return baseLimit
val factor = getCurrentLoadFactor()
return (baseLimit * factor).toInt().coerceAtLeast(1) // Ensure at least 1 request is allowed
}
}
/**
* Efficient hashing function for request keys.
* Uses FNV-1a hash algorithm which is fast and has good distribution.
*/
private fun efficientHash(input: String): Int {
val bytes = input.toByteArray(StandardCharsets.UTF_8)
var hash = 0x811c9dc5.toInt() // FNV-1a prime
for (byte in bytes) {
hash = hash xor byte.toInt()
hash = hash * 0x01000193 // FNV-1a prime multiplier
}
return hash
}
/**
* Generates an efficient request key from multiple inputs.
* Avoids string concatenation by hashing each input separately and combining the hashes.
*/
private fun generateRequestKey(vararg inputs: String?): String {
var combinedHash = 0
for (input in inputs) {
if (input != null && input.isNotEmpty()) {
// Combine hashes using XOR and a bit of rotation for better distribution
val inputHash = efficientHash(input)
combinedHash = (combinedHash xor inputHash) + ((combinedHash shl 5) + (combinedHash shr 2))
}
}
return combinedHash.toString()
}
fun Application.configureRateLimiting() {
val config = AppConfig.rateLimit
if (!config.enabled) {
log.info("Rate limiting is disabled")
return
}
install(RateLimit) {
// Global rate limiting configuration
global {
// Limit based on configuration, adjusted for server load
rateLimiter(
limit = AdaptiveRateLimiting.adjustRateLimit(config.globalLimit),
refillPeriod = config.globalPeriodMinutes.minutes
)
// Enhanced request-key based on IP address and optional User-Agent
// Using efficient hashing for better performance
requestKey { call ->
val ip = call.request.local.remoteHost
val userAgent = call.request.userAgent() ?: ""
// Use efficient hashing to generate request key
generateRequestKey(ip, userAgent)
}
}
// Endpoint-specific rate limiting
for ((endpoint, limitConfig) in config.endpointLimits) {
register(RateLimitName(endpoint)) {
// Limit based on configuration, adjusted for server load
rateLimiter(
limit = AdaptiveRateLimiting.adjustRateLimit(limitConfig.limit),
refillPeriod = limitConfig.periodMinutes.minutes
)
// Enhanced request-key with IP and optional request ID
// Using efficient hashing for better performance
requestKey { call ->
val ip = call.request.local.remoteHost
val requestId = call.attributes.getOrNull(REQUEST_ID_KEY) ?: ""
val endpoint = endpoint // Include endpoint in the key for better separation
// Use efficient hashing to generate request key
generateRequestKey(ip, requestId, endpoint)
}
}
}
// User-type-specific rate limiting
register(RateLimitName("anonymous")) {
// Limit based on configuration, adjusted for server load
rateLimiter(
limit = AdaptiveRateLimiting.adjustRateLimit(config.userTypeLimits["anonymous"]?.limit ?: config.globalLimit),
refillPeriod = (config.userTypeLimits["anonymous"]?.periodMinutes ?: config.globalPeriodMinutes).minutes
)
// Enhanced request-key with IP and user agent for anonymous users
// Using efficient hashing for better performance
requestKey { call ->
val ip = call.request.local.remoteHost
val userAgent = call.request.userAgent() ?: ""
// Use efficient hashing to generate request key with "anon" prefix for type separation
generateRequestKey("anon", ip, userAgent)
}
}
register(RateLimitName("authenticated")) {
// Limit based on configuration, adjusted for server load
rateLimiter(
limit = AdaptiveRateLimiting.adjustRateLimit(config.userTypeLimits["authenticated"]?.limit ?: config.globalLimit),
refillPeriod = (config.userTypeLimits["authenticated"]?.periodMinutes ?: config.globalPeriodMinutes).minutes
)
// Using efficient hashing for better performance
requestKey { call ->
// Use user ID from JWT token if available, otherwise use IP
val userId = call.request.header("Authorization")?.let { extractUserIdFromToken(it) }
val ip = call.request.local.remoteHost
// Use efficient hashing to generate request key with "auth" prefix for type separation
generateRequestKey("auth", userId ?: "", ip)
}
}
register(RateLimitName("admin")) {
// Limit based on configuration, adjusted for server load
rateLimiter(
limit = AdaptiveRateLimiting.adjustRateLimit(config.userTypeLimits["admin"]?.limit ?: config.globalLimit),
refillPeriod = (config.userTypeLimits["admin"]?.periodMinutes ?: config.globalPeriodMinutes).minutes
)
// Using efficient hashing for better performance
requestKey { call ->
// Use user ID from JWT token if available, otherwise use IP
val userId = call.request.header("Authorization")?.let { extractUserIdFromToken(it) }
val ip = call.request.local.remoteHost
// Use efficient hashing to generate request key with "admin" prefix for type separation
generateRequestKey("admin", userId ?: "", ip)
}
}
}
// Add rate limit headers to all responses
if (config.includeHeaders) {
intercept(ApplicationCallPipeline.Plugins) {
// Get current load factor for adaptive rate limiting
val loadFactor = AdaptiveRateLimiting.getCurrentLoadFactor()
val adjustedGlobalLimit = AdaptiveRateLimiting.adjustRateLimit(config.globalLimit)
// Add basic rate limit headers
call.response.header("X-RateLimit-Enabled", "true")
call.response.header("X-RateLimit-Limit", config.globalLimit.toString())
call.response.header("X-RateLimit-Adjusted-Limit", adjustedGlobalLimit.toString())
// Add adaptive rate limiting information
call.response.header("X-RateLimit-Load-Factor", String.format("%.2f", loadFactor))
call.response.header("X-RateLimit-Adaptive", AdaptiveRateLimiting.ENABLED.toString())
// Add standard rate limit headers
call.response.header("X-RateLimit-Policy", "${config.globalLimit} requests per ${config.globalPeriodMinutes} minutes")
call.response.header("X-RateLimit-Adjusted-Policy", "${adjustedGlobalLimit} requests per ${config.globalPeriodMinutes} minutes")
// Add estimated reset time (simplified version)
val resetTimeSeconds = config.globalPeriodMinutes * 60
call.response.header("X-RateLimit-Reset", resetTimeSeconds.toString())
// Add retry-after header if rate limited (status code 429)
if (call.response.status() == HttpStatusCode.TooManyRequests) {
// Calculate retry-after value based on rate limit configuration
val retryAfter = (config.globalPeriodMinutes * 60 / config.globalLimit).coerceAtLeast(1)
call.response.header(HttpHeaders.RetryAfter, retryAfter.toString())
}
// Add more detailed headers based on the request path
val path = call.request.path()
config.endpointLimits.entries.find { path.startsWith("/${it.key}") }?.let { (endpoint, limitConfig) ->
// Calculate adjusted limit for this endpoint
val adjustedEndpointLimit = AdaptiveRateLimiting.adjustRateLimit(limitConfig.limit)
call.response.header("X-RateLimit-Endpoint", endpoint)
call.response.header("X-RateLimit-Endpoint-Limit", limitConfig.limit.toString())
call.response.header("X-RateLimit-Endpoint-Adjusted-Limit", adjustedEndpointLimit.toString())
call.response.header("X-RateLimit-Endpoint-Period", "${limitConfig.periodMinutes}m")
call.response.header("X-RateLimit-Endpoint-Reset", (limitConfig.periodMinutes * 60).toString())
}
// Add user type rate limit headers if authenticated
val authHeader = call.request.header("Authorization")
if (authHeader != null) {
val userType = determineUserType(authHeader)
config.userTypeLimits[userType]?.let { limitConfig ->
// Calculate adjusted limit for this user type
val adjustedUserTypeLimit = AdaptiveRateLimiting.adjustRateLimit(limitConfig.limit)
call.response.header("X-RateLimit-UserType", userType)
call.response.header("X-RateLimit-UserType-Limit", limitConfig.limit.toString())
call.response.header("X-RateLimit-UserType-Adjusted-Limit", adjustedUserTypeLimit.toString())
call.response.header("X-RateLimit-UserType-Period", "${limitConfig.periodMinutes}m")
call.response.header("X-RateLimit-UserType-Reset", (limitConfig.periodMinutes * 60).toString())
}
}
// Log rate limiting information if rate limited
if (call.response.status() == HttpStatusCode.TooManyRequests) {
val requestId = call.attributes.getOrNull(REQUEST_ID_KEY) ?: "no-request-id"
val retryAfter = (config.globalPeriodMinutes * 60 / config.globalLimit).coerceAtLeast(1)
val loadFactor = AdaptiveRateLimiting.getCurrentLoadFactor()
val originalLimit = config.globalLimit
val adjustedLimit = AdaptiveRateLimiting.adjustRateLimit(originalLimit)
application.log.warn("Rate limit exceeded - Path: ${call.request.path()} - " +
"RequestID: $requestId - Client: ${call.request.local.remoteHost} - " +
"RetryAfter: ${retryAfter}s - " +
"LoadFactor: ${String.format("%.2f", loadFactor)} - " +
"OriginalLimit: $originalLimit - AdjustedLimit: $adjustedLimit")
}
}
}
// Log basic rate limiting configuration
log.info("Rate limiting configured with global limit: ${config.globalLimit}/${config.globalPeriodMinutes}m")
log.info("Endpoint-specific limits: ${config.endpointLimits.size} configured")
log.info("User-type-specific limits: ${config.userTypeLimits.size} configured")
// Log adaptive rate limiting configuration
if (AdaptiveRateLimiting.ENABLED) {
log.info("Adaptive rate limiting ENABLED with current load factor: ${String.format("%.2f", AdaptiveRateLimiting.getCurrentLoadFactor())}")
log.info("Adaptive thresholds - CPU: Medium=${AdaptiveRateLimiting.CPU_MEDIUM_LOAD_THRESHOLD}%, High=${AdaptiveRateLimiting.CPU_HIGH_LOAD_THRESHOLD}%")
log.info("Adaptive thresholds - Memory: Medium=${AdaptiveRateLimiting.MEMORY_MEDIUM_LOAD_THRESHOLD}%, High=${AdaptiveRateLimiting.MEMORY_HIGH_LOAD_THRESHOLD}%")
log.info("Adaptive factors - Medium load: ${AdaptiveRateLimiting.MEDIUM_LOAD_FACTOR}, High load: ${AdaptiveRateLimiting.HIGH_LOAD_FACTOR}")
log.info("Adaptive monitoring interval: ${AdaptiveRateLimiting.MONITORING_INTERVAL_MS}ms")
// Log examples of adjusted limits
log.info("Example adjusted limits at current load factor (${String.format("%.2f", AdaptiveRateLimiting.getCurrentLoadFactor())}): " +
"Global: ${config.globalLimit}${AdaptiveRateLimiting.adjustRateLimit(config.globalLimit)}")
} else {
log.info("Adaptive rate limiting DISABLED")
}
}
/**
* Extract user ID from JWT token.
* Parses the JWT token to extract the user ID from the subject claim.
* Uses caching to avoid repeated parsing of the same token.
*/
private fun extractUserIdFromToken(authHeader: String): String? {
try {
// Remove "Bearer " prefix if present
val token = if (authHeader.startsWith("Bearer ")) {
authHeader.substring(7)
} else {
authHeader
}
// Calculate token hash for cache lookup
val tokenHash = token.hashCode()
// Check if token is in cache
val cachedValue = tokenCache[tokenHash]
if (cachedValue != null) {
// Return cached user ID
return cachedValue.first
}
// Token not in cache, parse it
// Split the token into parts
val parts = token.split(".")
if (parts.size != 3) {
return null // Not a valid JWT token
}
// Decode the payload (second part) - this is the expensive operation we want to cache
val payload = String(java.util.Base64.getUrlDecoder().decode(parts[1]))
// Extract the subject (user ID) using a simple regex
// In a production environment, use a proper JWT library
val subjectRegex = "\"sub\"\\s*:\\s*\"([^\"]+)\"".toRegex()
val matchResult = subjectRegex.find(payload)
// Determine user type in the same parsing operation to avoid duplicate work
val userType = determineUserTypeFromPayload(payload)
// Get the user ID
val userId = matchResult?.groupValues?.get(1) ?: token.hashCode().toString()
// Store in a cache for future use
tokenCache[tokenHash] = Pair(userId, userType)
return userId
} catch (_: Exception) {
// If any error occurs during parsing, fall back to using the token hash
return authHeader.hashCode().toString()
}
}
/**
* Determine a user type from a JWT token.
* Parses the JWT token to extract the user role from the claims.
* Uses caching to avoid repeated parsing of the same token.
*/
private fun determineUserType(authHeader: String): String {
try {
// Remove "Bearer " prefix if present
val token = if (authHeader.startsWith("Bearer ")) {
authHeader.substring(7)
} else {
authHeader
}
// Calculate token hash for cache lookup
val tokenHash = token.hashCode()
// Check if token is in cache
val cachedValue = tokenCache[tokenHash]
if (cachedValue != null) {
// Return cached user type
return cachedValue.second
}
// Token not in cache, parse it
// Split the token into parts
val parts = token.split(".")
if (parts.size != 3) {
return "authenticated" // Default to authenticated if not a valid JWT
}
// Decode the payload (second part)
val payload = String(java.util.Base64.getUrlDecoder().decode(parts[1]))
// Determine user type from payload
val userType = determineUserTypeFromPayload(payload)
// Extract user ID in the same parsing operation to avoid duplicate work
val subjectRegex = "\"sub\"\\s*:\\s*\"([^\"]+)\"".toRegex()
val matchResult = subjectRegex.find(payload)
val userId = matchResult?.groupValues?.get(1) ?: token.hashCode().toString()
// Store in a cache for future use
tokenCache[tokenHash] = Pair(userId, userType)
return userType
} catch (_: Exception) {
// If any error occurs during parsing, default to authenticated
return "authenticated"
}
}
/**
* Helper function to determine a user type from JWT payload.
* Extracted to avoid code duplication between extractUserIdFromToken and determineUserType.
*/
private fun determineUserTypeFromPayload(payload: String): String {
try {
// Extract the role using a simple regex
// Look for role, roles, or authority claims
val roleRegex = "\"(role|roles|authorities)\"\\s*:\\s*\"([^\"]+)\"".toRegex()
val matchResult = roleRegex.find(payload)
if (matchResult != null) {
val role = matchResult.groupValues[2].lowercase()
return when {
role.contains("admin") -> "admin"
else -> "authenticated"
}
}
// Check for an array of roles
val rolesArrayRegex = "\"(role|roles|authorities)\"\\s*:\\s*\\[([^]]+)]".toRegex()
val arrayMatchResult = rolesArrayRegex.find(payload)
if (arrayMatchResult != null) {
val rolesArray = arrayMatchResult.groupValues[2]
return when {
rolesArray.contains("admin") -> "admin"
else -> "authenticated"
}
}
// Default to authenticate if no role information found
return "authenticated"
} catch (_: Exception) {
// If any error occurs during parsing, default to authenticated
return "authenticated"
}
}
@@ -1,248 +0,0 @@
package at.mocode.infrastructure.gateway.config
import at.mocode.core.utils.config.AppConfig
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.plugins.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.util.*
import java.util.UUID
/**
* Configuration for request tracing and cross-service correlation.
*
* This configuration adds support for:
* - Request ID generation and propagation
* - Cross-service tracing
* - Correlation ID extraction from incoming requests
* - Correlation ID propagation to outgoing requests
*/
// Define attribute key for storing request ID in the ApplicationCall
val REQUEST_ID_KEY = AttributeKey<String>("RequestId")
val REQUEST_START_TIME_KEY = AttributeKey<Long>("RequestStartTime")
/**
* Configures request tracing for the API Gateway.
*/
fun Application.configureRequestTracing() {
val config = AppConfig.logging
// Install a hook to intercept all incoming requests
intercept(ApplicationCallPipeline.Monitoring) {
// Store the start time for timing measurements
val startTime = System.currentTimeMillis()
call.attributes.put(REQUEST_START_TIME_KEY, startTime)
// Try to extract request ID from incoming request headers
val requestId = if (config.generateRequestIdIfMissing) {
call.request.header(config.requestIdHeader) ?: generateRequestId()
} else {
call.request.header(config.requestIdHeader) ?: "no-request-id"
}
// Store the request ID in the call attributes for later use
call.attributes.put(REQUEST_ID_KEY, requestId)
// Add tracing headers to the response
if (config.propagateRequestId) {
// Add the primary request ID header
call.response.header(config.requestIdHeader, requestId)
// Add additional tracing headers for better cross-service correlation
call.response.header("X-Correlation-ID", requestId)
call.response.header("X-Request-Start-Time", startTime.toString())
call.response.header("X-Service-Name", AppConfig.appInfo.name)
call.response.header("X-Service-Version", AppConfig.appInfo.version)
// Add trace parent header for W3C trace context compatibility
// Format: 00-traceid-parentid-01 (version-traceid-parentid-flags)
val traceId = requestId.replace("-", "").takeLast(32).padStart(32, '0')
val parentId = requestId.hashCode().toString(16).takeLast(16).padStart(16, '0')
call.response.header("traceparent", "00-$traceId-$parentId-01")
}
// Log the request with enhanced tracing information
if (config.logRequests) {
val clientIp = call.request.origin.remoteHost
val userAgent = call.request.userAgent() ?: "unknown"
val referer = call.request.header("Referer") ?: "-"
val contentType = call.request.contentType().toString()
val contentLength = call.request.header(HttpHeaders.ContentLength) ?: "0"
val host = call.request.host()
val scheme = call.request.local.scheme
val port = call.request.port()
val method = call.request.httpMethod.value
val path = call.request.path()
val queryString = call.request.queryString().let { if (it.isNotEmpty()) "?$it" else "" }
// Extract trace context from incoming request if present
val traceParent = call.request.header("traceparent") ?: "-"
val traceState = call.request.header("tracestate") ?: "-"
if (config.useStructuredLogging) {
application.log.info(
"type=request " +
"requestId=$requestId " +
"method=$method " +
"path=$path " +
"query=$queryString " +
"scheme=$scheme " +
"host=$host " +
"port=$port " +
"client=$clientIp " +
"userAgent=\"$userAgent\" " +
"referer=\"$referer\" " +
"contentType=$contentType " +
"contentLength=$contentLength " +
"traceParent=$traceParent " +
"traceState=$traceState " +
"timestamp=${System.currentTimeMillis()}"
)
} else {
application.log.info(
"Incoming request: $method $path$queryString - " +
"Host: $host:$port - " +
"Scheme: $scheme - " +
"Client: $clientIp - " +
"UserAgent: $userAgent - " +
"Referer: $referer - " +
"ContentType: $contentType - " +
"ContentLength: $contentLength - " +
"RequestID: $requestId - " +
"TraceParent: $traceParent"
)
}
}
}
// Install a hook to intercept all outgoing responses
intercept(ApplicationCallPipeline.Plugins) {
// Get the request ID from the call attributes
val requestId = call.attributes[REQUEST_ID_KEY]
// Process the request
proceed()
// Calculate response time if configured
if (config.logResponseTime) {
val startTime = call.attributes[REQUEST_START_TIME_KEY]
val endTime = System.currentTimeMillis()
val duration = endTime - startTime
// Add timing information to response headers
call.response.header("X-Response-Time", "$duration")
// Log the response with enhanced tracing information
if (config.logResponses) {
val status = call.response.status() ?: HttpStatusCode.OK
val path = call.request.path()
val method = call.request.httpMethod.value
val contentType = call.response.headers["Content-Type"] ?: "-"
val contentLength = call.response.headers["Content-Length"] ?: "0"
// Get memory usage for performance monitoring
val memoryUsage = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()
// Extract trace context from response
val traceParent = call.response.headers["traceparent"] ?: "-"
if (config.useStructuredLogging) {
application.log.info(
"type=response " +
"requestId=$requestId " +
"method=$method " +
"path=$path " +
"status=$status " +
"duration=${duration}ms " +
"contentType=$contentType " +
"contentLength=$contentLength " +
"traceParent=$traceParent " +
"memoryUsage=${memoryUsage}b " +
"timestamp=${System.currentTimeMillis()}"
)
} else {
application.log.info(
"Response: $status - " +
"Method: $method - " +
"Path: $path - " +
"RequestID: $requestId - " +
"Duration: ${duration}ms - " +
"ContentType: $contentType - " +
"ContentLength: $contentLength - " +
"TraceParent: $traceParent - " +
"MemoryUsage: ${memoryUsage}b"
)
}
}
} else if (config.logResponses) {
// Log the response without timing information but with enhanced tracing data
val status = call.response.status() ?: HttpStatusCode.OK
val path = call.request.path()
val method = call.request.httpMethod.value
val contentType = call.response.headers["Content-Type"] ?: "-"
val contentLength = call.response.headers["Content-Length"] ?: "0"
// Get memory usage for performance monitoring
val memoryUsage = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory()
// Extract trace context from response
val traceParent = call.response.headers["traceparent"] ?: "-"
if (config.useStructuredLogging) {
application.log.info(
"type=response " +
"requestId=$requestId " +
"method=$method " +
"path=$path " +
"status=$status " +
"contentType=$contentType " +
"contentLength=$contentLength " +
"traceParent=$traceParent " +
"memoryUsage=${memoryUsage}b " +
"timestamp=${System.currentTimeMillis()}"
)
} else {
application.log.info(
"Response: $status - " +
"Method: $method - " +
"Path: $path - " +
"RequestID: $requestId - " +
"ContentType: $contentType - " +
"ContentLength: $contentLength - " +
"TraceParent: $traceParent - " +
"MemoryUsage: ${memoryUsage}b"
)
}
}
}
log.info("Request tracing configured with header: ${config.requestIdHeader}")
}
/**
* Generates a new request ID with enhanced context information.
*
* Format: prefix-environment-service-timestamp-uuid
* Example: req-prod-gateway-1627384950123-550e8400-e29b-41d4-a716-446655440000
*/
private fun generateRequestId(): String {
val uuid = UUID.randomUUID().toString()
val timestamp = System.currentTimeMillis()
// Get environment prefix safely (first 4 chars or fewer)
val environment = AppConfig.environment.toString().let { env ->
if (env.length > 4) env.substring(0, 4) else env
}.lowercase()
// Get service name, replacing spaces with dashes
val serviceName = AppConfig.appInfo.name.replace(" ", "-").lowercase()
return "req-$environment-$serviceName-$timestamp-$uuid"
}
/**
* Extension function to get the request ID from the call.
*/
fun ApplicationCall.requestId(): String = attributes[REQUEST_ID_KEY]
@@ -1,101 +0,0 @@
package at.mocode.infrastructure.gateway.config
import io.ktor.server.application.*
import io.ktor.server.plugins.cors.routing.*
import io.ktor.server.auth.*
import io.ktor.server.auth.jwt.*
import io.ktor.http.*
import com.auth0.jwt.JWT
import com.auth0.jwt.algorithms.Algorithm
import io.ktor.server.response.respond
/**
* Security configuration for the API Gateway.
*
* Configures CORS, JWT authentication, and other security-related settings.
*/
fun Application.configureSecurity() {
install(CORS) {
allowMethod(HttpMethod.Options)
allowMethod(HttpMethod.Put)
allowMethod(HttpMethod.Delete)
allowMethod(HttpMethod.Patch)
allowHeader(HttpHeaders.Authorization)
allowHeader(HttpHeaders.ContentType)
allowHeader("X-Requested-With")
// Allow requests from common development origins
allowHost("localhost:3000")
allowHost("localhost:8080")
allowHost("127.0.0.1:3000")
allowHost("127.0.0.1:8080")
// In production, configure specific allowed origins
anyHost() // This should be restricted in production
}
// JWT Configuration
val jwtConfig = JwtConfig.fromEnvironment()
install(Authentication) {
jwt("auth-jwt") {
realm = jwtConfig.realm
verifier(
JWT
.require(Algorithm.HMAC512(jwtConfig.secret))
.withAudience(jwtConfig.audience)
.withIssuer(jwtConfig.issuer)
.build()
)
validate { credential ->
if (credential.payload.getClaim("userId").asString() != null) {
JWTPrincipal(credential.payload)
} else {
null
}
}
challenge { defaultScheme, realm ->
call.respond(HttpStatusCode.Unauthorized, "Token is not valid or has expired")
}
}
// Basic authentication for metrics endpoint
basic("metrics-auth") {
realm = "Metrics"
validate { credentials ->
// Get credentials from environment variables or use defaults
val metricsUser = System.getenv("METRICS_USER") ?: "metrics"
val metricsPassword = System.getenv("METRICS_PASSWORD") ?: "metrics-password-change-in-production"
if (credentials.name == metricsUser && credentials.password == metricsPassword) {
UserIdPrincipal(credentials.name)
} else {
null
}
}
}
}
}
/**
* JWT Configuration data class.
*/
data class JwtConfig(
val secret: String,
val issuer: String,
val audience: String,
val realm: String,
val expirationTime: Long = 3600000L // 1 hour in milliseconds
) {
companion object {
fun fromEnvironment(): JwtConfig {
return JwtConfig(
secret = System.getenv("JWT_SECRET") ?: "default-secret-key-change-in-production",
issuer = System.getenv("JWT_ISSUER") ?: "meldestelle-api",
audience = System.getenv("JWT_AUDIENCE") ?: "meldestelle-users",
realm = System.getenv("JWT_REALM") ?: "Meldestelle API",
expirationTime = System.getenv("JWT_EXPIRATION")?.toLongOrNull() ?: 3600000L
)
}
}
}
@@ -1,23 +0,0 @@
package at.mocode.infrastructure.gateway.config
import io.ktor.serialization.kotlinx.json.*
import io.ktor.server.application.*
import io.ktor.server.plugins.contentnegotiation.*
import kotlinx.serialization.json.Json
/**
* Serialization configuration for the API Gateway.
*
* Configures JSON serialization settings that are consistent across all bounded contexts.
*/
fun Application.configureSerialization() {
install(ContentNegotiation) {
json(Json {
prettyPrint = true
isLenient = true
ignoreUnknownKeys = true
encodeDefaults = true
explicitNulls = false
})
}
}
@@ -1,181 +0,0 @@
package at.mocode.infrastructure.gateway.discovery
import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.json.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import java.net.URI
import java.util.concurrent.ConcurrentHashMap
/**
* Service discovery component for the API Gateway.
* Uses Consul to discover services and route requests to them.
*/
class ServiceDiscovery(
private val consulHost: String = "consul",
private val consulPort: Int = 8500
) {
private val httpClient = HttpClient(CIO) {
install(ContentNegotiation) {
json(Json {
ignoreUnknownKeys = true
isLenient = true
})
}
}
// Cache of service instances
private val serviceCache = ConcurrentHashMap<String, List<ServiceInstance>>()
private val cacheMutex = Mutex()
// Default TTL for cache entries in milliseconds (30 seconds)
private val cacheTtl = 30_000L
private val cacheTimestamps = ConcurrentHashMap<String, Long>()
/**
* Get a service instance for the given service name.
* Uses a simple round-robin load balancing strategy.
*
* @param serviceName The name of the service to get an instance for
* @return A service instance, or null if no instances are available
*/
suspend fun getServiceInstance(serviceName: String): ServiceInstance? {
val instances = getServiceInstances(serviceName)
if (instances.isEmpty()) {
return null
}
// Simple round-robin load balancing
val index = (System.currentTimeMillis() % instances.size).toInt()
return instances[index]
}
/**
* Get all instances of a service.
*
* @param serviceName The name of the service to get instances for
* @return A list of service instances
*/
suspend fun getServiceInstances(serviceName: String): List<ServiceInstance> {
// Check cache first
val cachedInstances = serviceCache[serviceName]
val timestamp = cacheTimestamps[serviceName] ?: 0
if (cachedInstances != null && System.currentTimeMillis() - timestamp < cacheTtl) {
return cachedInstances
}
// Cache miss or expired, fetch from Consul
return cacheMutex.withLock {
// Double-check in case another thread updated the cache while we were waiting
val currentTimestamp = cacheTimestamps[serviceName] ?: 0
if (serviceCache[serviceName] != null && System.currentTimeMillis() - currentTimestamp < cacheTtl) {
return@withLock serviceCache[serviceName]!!
}
try {
val instances = fetchServiceInstances(serviceName)
serviceCache[serviceName] = instances
cacheTimestamps[serviceName] = System.currentTimeMillis()
instances
} catch (e: Exception) {
println("Failed to fetch service instances for $serviceName: ${e.message}")
e.printStackTrace()
// Return cached instances if available, even if expired
cachedInstances ?: emptyList()
}
}
}
/**
* Fetch service instances from Consul.
*
* @param serviceName The name of the service to fetch instances for
* @return A list of service instances
*/
private suspend fun fetchServiceInstances(serviceName: String): List<ServiceInstance> {
val response = httpClient.get("http://$consulHost:$consulPort/v1/catalog/service/$serviceName")
if (response.status != HttpStatusCode.OK) {
throw Exception("Failed to fetch service instances: ${response.status}")
}
val responseBody = response.bodyAsText()
val consulServices = Json.decodeFromString<List<ConsulService>>(responseBody)
return consulServices.map { service ->
ServiceInstance(
id = service.ServiceID,
name = service.ServiceName,
host = service.ServiceAddress.ifEmpty { service.Address },
port = service.ServicePort,
tags = service.ServiceTags,
meta = service.ServiceMeta
)
}
}
/**
* Build a URL for a service instance.
*
* @param instance The service instance
* @param path The path to append to the URL
* @return The complete URL
*/
fun buildServiceUrl(instance: ServiceInstance, path: String): String {
val baseUrl = "https://${instance.host}:${instance.port}"
return URI(baseUrl).resolve(path).toString()
}
/**
* Check if a service is healthy.
*
* @param serviceName The name of the service to check
* @return True if the service is healthy, false otherwise
*/
suspend fun isServiceHealthy(serviceName: String): Boolean {
try {
val response = httpClient.get("https://$consulHost:$consulPort/v1/health/service/$serviceName?passing=true")
val responseBody = response.bodyAsText()
val healthyServices = Json.decodeFromString<List<Any>>(responseBody)
return healthyServices.isNotEmpty()
} catch (e: Exception) {
println("Failed to check service health for $serviceName: ${e.message}")
return false
}
}
}
/**
* Represents a service instance.
*/
data class ServiceInstance(
val id: String,
val name: String,
val host: String,
val port: Int,
val tags: List<String> = emptyList(),
val meta: Map<String, String> = emptyMap()
)
/**
* Consul service response model.
*/
@Serializable
data class ConsulService(
val ServiceID: String,
val ServiceName: String,
val ServiceAddress: String,
val ServicePort: Int,
val ServiceTags: List<String> = emptyList(),
val ServiceMeta: Map<String, String> = emptyMap(),
val Address: String
)
@@ -1,56 +0,0 @@
package at.mocode.infrastructure.gateway.migrations
import at.mocode.core.utils.database.Migration
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.kotlin.datetime.date
import org.jetbrains.exposed.sql.kotlin.datetime.timestamp
import org.jetbrains.exposed.sql.kotlin.datetime.CurrentTimestamp
/**
* Migration zur Erstellung der Veranstaltungsmanagement-Tabellen.
*/
class EventManagementTablesCreation : Migration(4, "Create event management tables") {
override fun up() {
// Veranstaltung-Tabelle
SchemaUtils.create(VeranstaltungTable)
// Veranstaltung_Sportart-Tabelle
SchemaUtils.create(VeranstaltungSportartTable)
}
}
// Definition der Tabellen
object VeranstaltungTable : Table("veranstaltung") {
val id = uuid("id").autoGenerate()
val name = varchar("name", 100)
val beschreibung = text("beschreibung").nullable()
val startDatum = date("start_datum")
val endDatum = date("end_datum")
val anmeldeschluss = date("anmeldeschluss").nullable()
val ort = varchar("ort", 100)
val landCode = varchar("land_code", 2).references(LandTable.code)
val bundeslandCode = varchar("bundesland_code", 5).nullable()
val maxTeilnehmer = integer("max_teilnehmer").nullable()
val istAktiv = bool("ist_aktiv").default(true)
val istOeffentlich = bool("ist_oeffentlich").default(true)
val erstelltAm = timestamp("erstellt_am").defaultExpression(CurrentTimestamp)
val geaendertAm = timestamp("geaendert_am").defaultExpression(CurrentTimestamp)
override val primaryKey = PrimaryKey(id)
init {
foreignKey(
bundeslandCode to LandTable.code,
landCode to BundeslandTable.landCode
)
// Ende muss nach Start sein
check("datum_check") { endDatum greaterEq startDatum }
}
}
object VeranstaltungSportartTable : Table("veranstaltung_sportart") {
val veranstaltungId = uuid("veranstaltung_id").references(VeranstaltungTable.id)
val sportartCode = varchar("sportart_code", 5).references(SportartTable.code)
override val primaryKey = PrimaryKey(veranstaltungId, sportartCode)
}
@@ -1,51 +0,0 @@
package at.mocode.infrastructure.gateway.migrations
import at.mocode.core.utils.database.Migration
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.kotlin.datetime.timestamp
import org.jetbrains.exposed.sql.kotlin.datetime.CurrentTimestamp
/**
* Migration zur Erstellung der Pferderegister-Tabellen.
*/
class HorseRegistryTablesCreation : Migration(3, "Create horse registry tables") {
override fun up() {
// Pferd-Tabelle
SchemaUtils.create(PferdTable)
// Pferdebesitzer-Tabelle
SchemaUtils.create(PferdebesitzerTable)
}
}
// Definition der Tabellen
object PferdTable : Table("pferd") {
val id = uuid("id").autoGenerate()
val name = varchar("name", 100)
val lebensnummer = varchar("lebensnummer", 30).uniqueIndex()
val rasse = varchar("rasse", 50)
val farbe = varchar("farbe", 50)
val geburtsjahr = integer("geburtsjahr").nullable()
val geschlecht = varchar("geschlecht", 1) // 'S' = Stute, 'W' = Wallach, 'H' = Hengst
val aktiv = bool("aktiv").default(true)
val erstelltAm = timestamp("erstellt_am").defaultExpression(CurrentTimestamp)
val geaendertAm = timestamp("geaendert_am").defaultExpression(CurrentTimestamp)
override val primaryKey = PrimaryKey(id)
init {
// Geschlecht muss S, W oder H sein
check("geschlecht_check") { geschlecht.inList(listOf("S", "W", "H")) }
}
}
object PferdebesitzerTable : Table("pferdebesitzer") {
val pferdId = uuid("pferd_id").references(PferdTable.id)
val personId = uuid("person_id").references(PersonTable.id)
val hauptbesitzer = bool("hauptbesitzer").default(false)
val aktiv = bool("aktiv").default(true)
val erstelltAm = timestamp("erstellt_am").defaultExpression(CurrentTimestamp)
val geaendertAm = timestamp("geaendert_am").defaultExpression(CurrentTimestamp)
override val primaryKey = PrimaryKey(pferdId, personId)
}
@@ -1,116 +0,0 @@
package at.mocode.infrastructure.gateway.migrations
import at.mocode.core.utils.database.Migration
import org.jetbrains.exposed.sql.SchemaUtils
import org.jetbrains.exposed.sql.Table
import org.jetbrains.exposed.sql.batchInsert
/**
* Migration zur Erstellung der Stammdaten-Tabellen.
*/
class MasterDataTablesCreation : Migration(1, "Create master data tables") {
override fun up() {
// Land-Tabelle
SchemaUtils.create(LandTable)
// Bundesland-Tabelle
SchemaUtils.create(BundeslandTable)
// Altersklasse-Tabelle
SchemaUtils.create(AltersklasseTable)
// Sportart-Tabelle
SchemaUtils.create(SportartTable)
// Anfangsdaten einfügen
insertInitialData()
}
private fun insertInitialData() {
// Länder einfügen
LandTable.batchInsert(listOf(
mapOf("code" to "AT", "name" to "Österreich", "active" to true),
mapOf("code" to "DE", "name" to "Deutschland", "active" to true),
mapOf("code" to "CH", "name" to "Schweiz", "active" to true)
)) { data ->
this[LandTable.code] = data["code"] as String
this[LandTable.name] = data["name"] as String
this[LandTable.active] = data["active"] as Boolean
}
// Bundesländer einfügen (Österreich)
BundeslandTable.batchInsert(listOf(
mapOf("landCode" to "AT", "code" to "W", "name" to "Wien"),
mapOf("landCode" to "AT", "code" to "", "name" to "Niederösterreich"),
mapOf("landCode" to "AT", "code" to "", "name" to "Oberösterreich"),
mapOf("landCode" to "AT", "code" to "S", "name" to "Salzburg"),
mapOf("landCode" to "AT", "code" to "T", "name" to "Tirol"),
mapOf("landCode" to "AT", "code" to "V", "name" to "Vorarlberg"),
mapOf("landCode" to "AT", "code" to "ST", "name" to "Steiermark"),
mapOf("landCode" to "AT", "code" to "K", "name" to "Kärnten"),
mapOf("landCode" to "AT", "code" to "B", "name" to "Burgenland")
)) { data ->
this[BundeslandTable.landCode] = data["landCode"] as String
this[BundeslandTable.code] = data["code"] as String
this[BundeslandTable.name] = data["name"] as String
}
// Altersklassen einfügen
AltersklasseTable.batchInsert(listOf(
mapOf("code" to "U12", "name" to "Unter 12", "minAlter" to 0, "maxAlter" to 12),
mapOf("code" to "U16", "name" to "Unter 16", "minAlter" to 13, "maxAlter" to 16),
mapOf("code" to "U21", "name" to "Unter 21", "minAlter" to 17, "maxAlter" to 21),
mapOf("code" to "ALLG", "name" to "Allgemeine Klasse", "minAlter" to 22, "maxAlter" to 99)
)) { data ->
this[AltersklasseTable.code] = data["code"] as String
this[AltersklasseTable.name] = data["name"] as String
this[AltersklasseTable.minAlter] = data["minAlter"] as Int
this[AltersklasseTable.maxAlter] = data["maxAlter"] as Int
}
// Sportarten einfügen
SportartTable.batchInsert(listOf(
mapOf("code" to "DR", "name" to "Dressur"),
mapOf("code" to "SP", "name" to "Springen"),
mapOf("code" to "VS", "name" to "Vielseitigkeit"),
mapOf("code" to "WR", "name" to "Western Reiten"),
mapOf("code" to "VT", "name" to "Voltigieren")
)) { data ->
this[SportartTable.code] = data["code"] as String
this[SportartTable.name] = data["name"] as String
}
}
}
// Definition der Tabellen
object LandTable : Table("land") {
val code = varchar("code", 2)
val name = varchar("name", 50)
val active = bool("active").default(true)
override val primaryKey = PrimaryKey(code)
}
object BundeslandTable : Table("bundesland") {
val landCode = varchar("land_code", 2).references(LandTable.code)
val code = varchar("code", 5)
val name = varchar("name", 50)
override val primaryKey = PrimaryKey(landCode, code)
}
object AltersklasseTable : Table("altersklasse") {
val code = varchar("code", 10)
val name = varchar("name", 50)
val minAlter = integer("min_alter")
val maxAlter = integer("max_alter")
override val primaryKey = PrimaryKey(code)
}
object SportartTable : Table("sportart") {
val code = varchar("code", 5)
val name = varchar("name", 50)
override val primaryKey = PrimaryKey(code)
}
@@ -1,100 +0,0 @@
package at.mocode.infrastructure.gateway.migrations
import at.mocode.core.utils.database.Migration
import at.mocode.members.infrastructure.persistence.MemberTable
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.kotlin.datetime.date
import org.jetbrains.exposed.sql.kotlin.datetime.timestamp
import org.jetbrains.exposed.sql.kotlin.datetime.CurrentTimestamp
/**
* Migration zur Erstellung der Mitgliederverwaltung-Tabellen.
*/
class MemberManagementTablesCreation : Migration(2, "Create member management tables") {
override fun up() {
// Member-Tabelle
SchemaUtils.create(MemberTable)
// Verein-Tabelle
SchemaUtils.create(VereinTable)
// Mitgliedschaft-Tabelle
SchemaUtils.create(MitgliedschaftTable)
// Adresse-Tabelle
SchemaUtils.create(AdresseTable)
}
}
// Definition der Tabellen
object PersonTable : Table("person") {
val id = uuid("id").autoGenerate()
val vorname = varchar("vorname", 50)
val nachname = varchar("nachname", 50)
val email = varchar("email", 100).uniqueIndex()
val telefon = varchar("telefon", 20).nullable()
val geburtsdatum = date("geburtsdatum").nullable()
val aktiv = bool("aktiv").default(true)
val erstelltAm = timestamp("erstellt_am").defaultExpression(CurrentTimestamp)
val geaendertAm = timestamp("geaendert_am").defaultExpression(CurrentTimestamp)
override val primaryKey = PrimaryKey(id)
}
object VereinTable : Table("verein") {
val id = uuid("id").autoGenerate()
val name = varchar("name", 100)
val vereinsNummer = varchar("vereins_nummer", 20).uniqueIndex()
val landCode = varchar("land_code", 2).references(LandTable.code)
val bundeslandCode = varchar("bundesland_code", 5).nullable()
val aktiv = bool("aktiv").default(true)
val erstelltAm = timestamp("erstellt_am").defaultExpression(CurrentTimestamp)
val geaendertAm = timestamp("geaendert_am").defaultExpression(CurrentTimestamp)
override val primaryKey = PrimaryKey(id)
init {
foreignKey(
bundeslandCode to LandTable.code,
landCode to BundeslandTable.landCode
)
}
}
object MitgliedschaftTable : Table("mitgliedschaft") {
val personId = uuid("person_id").references(PersonTable.id)
val vereinId = uuid("verein_id").references(VereinTable.id)
val aktiv = bool("aktiv").default(true)
val erstelltAm = timestamp("erstellt_am").defaultExpression(CurrentTimestamp)
val geaendertAm = timestamp("geaendert_am").defaultExpression(CurrentTimestamp)
override val primaryKey = PrimaryKey(personId, vereinId)
}
object AdresseTable : Table("adresse") {
val id = uuid("id").autoGenerate()
val personId = uuid("person_id").references(PersonTable.id).nullable()
val vereinId = uuid("verein_id").references(VereinTable.id).nullable()
val strasse = varchar("strasse", 100)
val hausnummer = varchar("hausnummer", 10)
val plz = varchar("plz", 10)
val ort = varchar("ort", 100)
val landCode = varchar("land_code", 2).references(LandTable.code)
val bundeslandCode = varchar("bundesland_code", 5).nullable()
val aktiv = bool("aktiv").default(true)
val erstelltAm = timestamp("erstellt_am").defaultExpression(CurrentTimestamp)
val geaendertAm = timestamp("geaendert_am").defaultExpression(CurrentTimestamp)
override val primaryKey = PrimaryKey(id)
init {
foreignKey(
bundeslandCode to LandTable.code,
landCode to BundeslandTable.landCode
)
check("address_owner_check") {
(personId.isNotNull() and vereinId.isNull()) or
(personId.isNull() and vereinId.isNotNull())
}
}
}
@@ -1,129 +0,0 @@
package at.mocode.infrastructure.gateway
import at.mocode.infrastructure.gateway.config.*
import at.mocode.infrastructure.gateway.config.configurePrometheusMetrics
import at.mocode.infrastructure.gateway.config.configureCustomMetrics
import at.mocode.infrastructure.gateway.plugins.configureHttpCaching
import at.mocode.infrastructure.gateway.routing.docRoutes
import at.mocode.infrastructure.gateway.routing.serviceRoutes
import at.mocode.infrastructure.gateway.routing.ApiGatewayInfo
import at.mocode.infrastructure.gateway.routing.HealthStatus
import at.mocode.core.utils.config.AppConfig
import at.mocode.core.domain.model.ApiResponse
import io.ktor.http.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.server.application.*
import io.ktor.server.http.content.*
import io.ktor.server.plugins.contentnegotiation.*
import io.ktor.server.plugins.cors.routing.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.auth.*
fun Application.module() {
val config = AppConfig
// ContentNegotiation installieren
install(ContentNegotiation) {
json()
}
// CORS installieren, wenn aktiviert
if (config.server.cors.enabled) {
install(CORS) {
if (config.server.cors.allowedOrigins.contains("*")) {
anyHost()
} else {
config.server.cors.allowedOrigins.forEach { allowHost(it, schemes = listOf("http", "https")) }
}
allowHeader(HttpHeaders.ContentType)
allowHeader(HttpHeaders.Authorization)
// Add request ID header to allowed headers
allowHeader(config.logging.requestIdHeader)
allowMethod(HttpMethod.Options)
allowMethod(HttpMethod.Get)
allowMethod(HttpMethod.Post)
allowMethod(HttpMethod.Put)
allowMethod(HttpMethod.Delete)
}
}
// Authentication installieren (für Metrics-Endpoint)
install(Authentication) {
basic("metrics-auth") {
realm = "Metrics Access"
validate { credentials ->
// Simple validation for metrics endpoint
if (credentials.name == "admin" && credentials.password == "metrics") {
UserIdPrincipal(credentials.name)
} else null
}
}
}
// Erweiterte Monitoring- und Logging-Konfiguration
configureMonitoring()
// Prometheus Metrics konfigurieren
configurePrometheusMetrics()
// Custom application metrics konfigurieren
configureCustomMetrics()
// Request Tracing für Cross-Service Tracing konfigurieren
configureRequestTracing()
// Enhanced Rate Limiting konfigurieren
configureRateLimiting()
// OpenAPI und Swagger UI konfigurieren
configureOpenApi()
configureSwagger()
// HTTP Caching konfigurieren
configureHttpCaching()
routing {
// Hauptrouten
get("/") {
val gatewayInfo = ApiGatewayInfo(
name = "Meldestelle API Gateway",
version = "1.0.0",
description = "API Gateway for Meldestelle Self-Contained Systems",
availableContexts = listOf("authentication", "master-data", "horse-registry"),
endpoints = mapOf(
"health" to "/health",
"metrics" to "/metrics",
"docs" to "/docs",
"api" to "/api",
"swagger" to "/swagger"
)
)
call.respond(ApiResponse.success(gatewayInfo, "API Gateway information retrieved successfully"))
}
// Health check endpoint
get("/health") {
val healthStatus = HealthStatus(
status = "UP",
contexts = mapOf(
"authentication" to "UP",
"master-data" to "UP",
"horse-registry" to "UP"
)
)
call.respond(ApiResponse.success(healthStatus, "Health check completed successfully"))
}
// Static resources for documentation
staticResources("/docs", "static/docs") {
default("index.html")
}
// API Documentation routes
docRoutes()
// Service discovery routes
serviceRoutes()
}
}
@@ -1,241 +0,0 @@
package at.mocode.infrastructure.gateway.plugins
import at.mocode.infrastructure.gateway.config.getCachingConfig
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.util.pipeline.*
import java.security.MessageDigest
import java.text.SimpleDateFormat
import java.util.*
/**
* Configures enhanced HTTP caching headers for the application.
* This adds Cache-Control, Expires, and Vary headers to responses.
* It also integrates with the CachingConfig for more intelligent caching decisions.
*/
fun Application.configureHttpCaching() {
// Get the application logger
val logger = log
// Get the caching config
val cachingConfig = try {
getCachingConfig()
} catch (e: Exception) {
logger.warn("Failed to get CachingConfig, using default caching headers: ${e.message}")
null
}
// Add a response interceptor for setting cache headers
intercept(ApplicationCallPipeline.Call) {
// Add Vary header to all responses
call.response.header(HttpHeaders.Vary, "Accept, Accept-Encoding")
// For authenticated endpoints, add Authorization to Vary
if (call.request.headers.contains(HttpHeaders.Authorization)) {
call.response.header(HttpHeaders.Vary, "Accept, Accept-Encoding, Authorization")
}
// Set default no-cache headers for dynamic content
call.response.header(HttpHeaders.CacheControl, "no-cache, private")
// Check for conditional requests (If-None-Match, If-Modified-Since)
val requestETag = call.request.header(HttpHeaders.IfNoneMatch)
val requestLastModified = call.request.header(HttpHeaders.IfModifiedSince)
// If we have conditional headers, check if we can return 304 Not Modified
if (requestETag != null || requestLastModified != null) {
// This would be implemented with actual ETag and Last-Modified checking
// For now, we just log that we received conditional headers
logger.debug("Received conditional request: ETag=$requestETag, Last-Modified=$requestLastModified")
}
}
logger.info("HTTP caching configured with integration to CachingConfig")
}
/**
* Extension function to enable caching for static resources.
* Use this for CSS, JS, images, and other static files.
*/
fun ApplicationCall.enableStaticResourceCaching(maxAgeSeconds: Int = 86400) { // Default: 1 day
setCacheControlHeader(this, maxAgeSeconds, true)
}
/**
* Extension function to enable caching for master data.
* Use this for reference data that changes infrequently.
*/
fun ApplicationCall.enableMasterDataCaching(maxAgeSeconds: Int = 3600) { // Default: 1 hour
setCacheControlHeader(this, maxAgeSeconds, true)
}
/**
* Extension function to enable caching for user data.
* Use this for user-specific data that may change frequently.
*/
fun ApplicationCall.enableUserDataCaching(maxAgeSeconds: Int = 60) { // Default: 1 minute
setCacheControlHeader(this, maxAgeSeconds, false, true)
}
/**
* Extension function to disable caching.
* Use this for sensitive or frequently changing data.
*/
fun ApplicationCall.disableCaching() {
response.header(HttpHeaders.CacheControl, "no-cache, no-store, must-revalidate, private")
response.header(HttpHeaders.Pragma, "no-cache")
response.header(HttpHeaders.Expires, "0")
}
/**
* Helper function to set Cache-Control and Expires headers.
*/
private fun setCacheControlHeader(
call: ApplicationCall,
maxAgeSeconds: Int,
isPublic: Boolean,
mustRevalidate: Boolean = false
) {
// Build Cache-Control header
val visibility = if (isPublic) "public" else "private"
val revalidate = if (mustRevalidate) ", must-revalidate" else ""
call.response.header(
HttpHeaders.CacheControl,
"max-age=$maxAgeSeconds, $visibility$revalidate"
)
// Set Expires header
val calendar = Calendar.getInstance()
calendar.add(Calendar.SECOND, maxAgeSeconds)
val dateFormat = SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US)
dateFormat.timeZone = TimeZone.getTimeZone("GMT")
call.response.header(HttpHeaders.Expires, dateFormat.format(calendar.time))
}
/**
* Extension function to set ETag header for a response.
*/
fun ApplicationCall.setETag(etag: String) {
response.header(HttpHeaders.ETag, "\"$etag\"")
}
/**
* Extension function to set Last-Modified header for a response.
*/
fun ApplicationCall.setLastModified(timestamp: Long) {
val dateFormat = SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US)
dateFormat.timeZone = TimeZone.getTimeZone("GMT")
response.header(HttpHeaders.LastModified, dateFormat.format(Date(timestamp)))
}
/**
* Generate an ETag for the given content.
* This uses MD5 hashing for simplicity, but in production you might want to use a faster algorithm.
*/
fun generateETag(content: String): String {
val md = MessageDigest.getInstance("MD5")
val digest = md.digest(content.toByteArray(Charsets.UTF_8))
return digest.joinToString("") { "%02x".format(it) }
}
/**
* Generate an ETag for the given object by converting it to a string representation.
*/
fun generateETag(obj: Any): String {
return generateETag(obj.toString())
}
/**
* Check if the request has a matching ETag and return 304 Not Modified if it does.
* Returns true if the response was handled (304 sent), false otherwise.
*/
suspend fun PipelineContext<Unit, ApplicationCall>.checkETagAndRespond(etag: String): Boolean {
val requestETag = call.request.header(HttpHeaders.IfNoneMatch)
// If the client sent an If-None-Match header and it matches our ETag,
// we can return 304 Not Modified
if (requestETag != null && (requestETag == "\"$etag\"" || requestETag == "*")) {
call.response.header(HttpHeaders.ETag, "\"$etag\"")
call.respond(HttpStatusCode.NotModified)
return true
}
// Set the ETag header for the response
call.response.header(HttpHeaders.ETag, "\"$etag\"")
return false
}
/**
* Check if the request has a matching Last-Modified date and return 304 Not Modified if it does.
* Returns true if the response was handled (304 sent), false otherwise.
*/
suspend fun PipelineContext<Unit, ApplicationCall>.checkLastModifiedAndRespond(timestamp: Long): Boolean {
val requestLastModified = call.request.header(HttpHeaders.IfModifiedSince)
if (requestLastModified != null) {
try {
val dateFormat = SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US)
dateFormat.timeZone = TimeZone.getTimeZone("GMT")
val requestDate = dateFormat.parse(requestLastModified).time
// If the resource hasn't been modified since the date in the request,
// we can return 304 Not Modified
if (timestamp <= requestDate) {
val lastModifiedFormatted = dateFormat.format(Date(timestamp))
call.response.header(HttpHeaders.LastModified, lastModifiedFormatted)
call.respond(HttpStatusCode.NotModified)
return true
}
} catch (_: Exception) {
// If we can't parse the date, ignore it
}
}
// Set the Last-Modified header for the response
val dateFormat = SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US)
dateFormat.timeZone = TimeZone.getTimeZone("GMT")
call.response.header(HttpHeaders.LastModified, dateFormat.format(Date(timestamp)))
return false
}
/**
* Extension function to check if a resource is cached in CachingConfig.
* If it is, and the client has a matching ETag or Last-Modified date,
* this will return 304 Not Modified. Otherwise, it will return the cached value.
* Returns true if the response was handled, false otherwise.
*/
suspend fun <T> PipelineContext<Unit, ApplicationCall>.checkCacheAndRespond(
cacheName: String,
key: String,
etag: String? = null,
lastModified: Long? = null
): Boolean {
val application = call.application
val cachingConfig = try {
application.getCachingConfig()
} catch (_: Exception) {
return false
}
// Check if the resource is in the cache
val cachedValue = cachingConfig.get<T>(cacheName, key)
if (cachedValue != null) {
// If we have an ETag, check if the client has a matching one
if (etag != null && checkETagAndRespond(etag)) {
return true
}
// If we have a Last-Modified date, check if the client has a matching one
if (lastModified != null && checkLastModifiedAndRespond(lastModified)) {
return true
}
// If we get here, the client doesn't have a matching ETag or Last-Modified date,
// so we need to send the full response
return false
}
return false
}
@@ -1,17 +0,0 @@
package at.mocode.infrastructure.gateway.routing
import at.mocode.core.domain.model.BaseDto
import kotlinx.serialization.Serializable
/**
* Information about the API Gateway.
* This class is used to provide information about the API Gateway to clients.
*/
@Serializable
data class ApiGatewayInfo(
val name: String,
val version: String,
val description: String,
val availableContexts: List<String>,
val endpoints: Map<String, String>
) : BaseDto
@@ -1,242 +0,0 @@
package at.mocode.infrastructure.gateway.routing
import at.mocode.core.domain.model.ApiResponse
import at.mocode.infrastructure.auth.client.AuthenticationService
import at.mocode.infrastructure.auth.client.JwtService
import at.mocode.core.utils.validation.ApiValidationUtils
import io.ktor.http.*
import io.ktor.server.auth.*
import io.ktor.server.auth.jwt.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import kotlinx.serialization.Serializable
/**
* Konfiguriert die Authentifizierungs-Routen.
*/
fun Routing.authRoutes(
authenticationService: AuthenticationService,
jwtService: JwtService
) {
route("/auth") {
// Login-Route
post("/login") {
try {
// Request-Daten lesen
val request = call.receive<LoginRequest>()
// Validierung
val validationErrors = ApiValidationUtils.validateLoginRequest(request.username, request.password)
if (!ApiValidationUtils.isValid(validationErrors)) {
call.respond(
HttpStatusCode.BadRequest,
ApiResponse.error<LoginResponse>(ApiValidationUtils.createErrorMessage(validationErrors))
)
return@post
}
// Authentifizierung durchführen
val authResult = authenticationService.authenticate(request.username, request.password)
// Antwort basierend auf dem Ergebnis senden
when (authResult) {
is AuthenticationService.AuthResult.Success -> {
call.respond(
HttpStatusCode.OK,
ApiResponse.success(
LoginResponse(
token = authResult.token,
userId = authResult.user.userId.toString(),
personId = authResult.user.personId.toString(),
username = authResult.user.username,
email = authResult.user.email
)
)
)
}
is AuthenticationService.AuthResult.Failure -> {
call.respond(
HttpStatusCode.Unauthorized,
ApiResponse.error<LoginResponse>(authResult.reason)
)
}
is AuthenticationService.AuthResult.Locked -> {
call.respond(
HttpStatusCode.Locked,
ApiResponse.error<LoginResponse>(
"Account gesperrt bis ${authResult.lockedUntil}"
)
)
}
}
} catch (e: Exception) {
call.respond(
HttpStatusCode.InternalServerError,
ApiResponse.error<LoginResponse>("Fehler bei der Anmeldung: ${e.message}")
)
}
}
// Registrierung (Beispiel, sollte an die Anforderungen angepasst werden)
post("/register") {
// Würde hier Registrierung implementieren
call.respond(
HttpStatusCode.NotImplemented,
ApiResponse.error<Any>("Registrierung noch nicht implementiert")
)
}
// Passwort ändern (geschützte Route)
authenticate("jwt") {
post("/change-password") {
try {
// Request-Daten lesen
val request = call.receive<ChangePasswordRequest>()
// Validierung
val validationErrors = ApiValidationUtils.validateChangePasswordRequest(
request.currentPassword,
request.newPassword,
request.confirmPassword
)
if (!ApiValidationUtils.isValid(validationErrors)) {
call.respond(
HttpStatusCode.BadRequest,
ApiResponse.error<Any>(ApiValidationUtils.createErrorMessage(validationErrors))
)
return@post
}
// Benutzer-ID aus dem Token extrahieren
val principal = call.principal<JWTPrincipal>()
val userId = principal?.getClaim("sub", String::class) ?: run {
call.respond(
HttpStatusCode.Unauthorized,
ApiResponse.error<Any>("Ungültiges Token")
)
return@post
}
// Passwort ändern
val result = authenticationService.changePassword(
com.benasher44.uuid.Uuid.fromString(userId),
request.currentPassword,
request.newPassword
)
// Antwort basierend auf dem Ergebnis senden
when (result) {
is AuthenticationService.PasswordChangeResult.Success -> {
call.respond(
HttpStatusCode.OK,
ApiResponse.success("Passwort erfolgreich geändert")
)
}
is AuthenticationService.PasswordChangeResult.Failure -> {
call.respond(
HttpStatusCode.BadRequest,
ApiResponse.error<Any>(result.reason)
)
}
is AuthenticationService.PasswordChangeResult.WeakPassword -> {
call.respond(
HttpStatusCode.BadRequest,
ApiResponse.error<Any>("Das neue Passwort ist zu schwach")
)
}
}
} catch (e: Exception) {
call.respond(
HttpStatusCode.InternalServerError,
ApiResponse.error<Any>("Fehler bei der Passwortänderung: ${e.message}")
)
}
}
// Benutzerinformationen abrufen
get("/me") {
try {
// Token validieren und Benutzerinformationen abrufen
val principal = call.principal<JWTPrincipal>()
val userId = principal?.getClaim("sub", String::class) ?: run {
call.respond(
HttpStatusCode.Unauthorized,
ApiResponse.error<Any>("Ungültiges Token")
)
return@get
}
// Hier können zusätzliche Informationen aus dem Token oder der Datenbank abgerufen werden
val username = principal.getClaim("username", String::class) ?: ""
val personId = principal.getClaim("personId", String::class) ?: ""
val permissions = principal.getClaim("permissions", String::class)?.split(",") ?: listOf()
call.respond(
HttpStatusCode.OK,
ApiResponse.success(
UserInfoResponse(
userId = userId,
personId = personId,
username = username,
permissions = permissions
)
)
)
} catch (e: Exception) {
call.respond(
HttpStatusCode.InternalServerError,
ApiResponse.error<Any>("Fehler beim Abrufen der Benutzerinformationen: ${e.message}")
)
}
}
}
}
}
/**
* Request-Modell für die Anmeldung.
*/
@Serializable
data class LoginRequest(
val username: String,
val password: String
)
/**
* Response-Modell für eine erfolgreiche Anmeldung.
*/
@Serializable
data class LoginResponse(
val token: String,
val userId: String,
val personId: String,
val username: String,
val email: String
)
/**
* Request-Modell für die Passwortänderung.
*/
@Serializable
data class ChangePasswordRequest(
val currentPassword: String,
val newPassword: String,
val confirmPassword: String
)
/**
* Response-Modell für Benutzerinformationen.
*/
@Serializable
data class UserInfoResponse(
val userId: String,
val personId: String,
val username: String,
val permissions: List<String>
)
@@ -1,73 +0,0 @@
package at.mocode.infrastructure.gateway.routing
import at.mocode.core.domain.model.ApiResponse
import io.ktor.server.response.*
import io.ktor.server.routing.*
import kotlinx.serialization.Serializable
/**
* Sets up routes for API documentation
*/
fun Routing.docRoutes() {
// Central API documentation endpoint - HTML version
get("/api") {
call.respondRedirect("/docs", permanent = false)
}
// JSON API documentation endpoint for backward compatibility
get("/api/json") {
val apiDocumentation = ApiDocumentationData(
title = "Meldestelle Self-Contained Systems API",
description = "Unified API Gateway for all bounded contexts",
contexts = listOf(
ApiContext(
name = "Authentication Context",
path = "/auth",
description = "User authentication, registration, and profile management"
),
ApiContext(
name = "Master Data Context",
path = "/api/masterdata",
description = "Reference data management (countries, states, age classes, venues)"
),
ApiContext(
name = "Horse Registry Context",
path = "/api/horses",
description = "Horse registration, ownership, and pedigree management"
),
ApiContext(
name = "Event Management Context",
path = "/api/events",
description = "Event creation, management, and participant registration"
)
)
)
call.respond(
ApiResponse.success(
data = apiDocumentation,
message = "API documentation retrieved successfully"
)
)
}
}
/**
* Data class for API documentation response
*/
@Serializable
data class ApiDocumentationData(
val title: String,
val description: String,
val contexts: List<ApiContext>
)
/**
* Data class for API context information
*/
@Serializable
data class ApiContext(
val name: String,
val path: String,
val description: String
)
@@ -1,14 +0,0 @@
package at.mocode.infrastructure.gateway.routing
import at.mocode.core.domain.model.BaseDto
import kotlinx.serialization.Serializable
/**
* Health status information for the API Gateway and its contexts.
* This class is used to provide health status information to clients.
*/
@Serializable
data class HealthStatus(
val status: String,
val contexts: Map<String, String>
) : BaseDto
@@ -1,201 +0,0 @@
package at.mocode.infrastructure.gateway.routing
import at.mocode.infrastructure.gateway.discovery.ServiceDiscovery
import at.mocode.core.utils.config.AppConfig
import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.server.application.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.util.*
import kotlinx.serialization.Serializable
/**
* Simple error response for service routing errors
*/
@Serializable
data class ServiceErrorResponse(
val error: String,
val code: String,
val service: String? = null
)
/**
* Simple success response for service routing
*/
@Serializable
data class ServiceSuccessResponse(
val message: String,
val service: String,
val instance: ServiceInstanceInfo
)
@Serializable
data class ServiceInstanceInfo(
val id: String,
val name: String,
val host: String,
val port: Int
)
/**
* Configure dynamic service routing using Consul service discovery.
* This allows the API Gateway to discover services registered with Consul and route requests to them.
*/
fun Routing.serviceRoutes() {
val config = AppConfig
// Check if we're in a test environment
val isTestEnvironment = System.getProperty("kotlinx.coroutines.test") != null ||
Thread.currentThread().stackTrace.any { it.className.contains("test", ignoreCase = true) }
// Initialize service discovery if enabled and not in test environment
val serviceDiscovery = if (config.serviceDiscovery.enabled && !isTestEnvironment) {
try {
ServiceDiscovery(
consulHost = config.serviceDiscovery.consulHost,
consulPort = config.serviceDiscovery.consulPort
)
} catch (e: Exception) {
// If service discovery fails to initialize, log and continue without it
println("Service discovery initialization failed: ${e.message}")
null
}
} else null
// Define service routes with all HTTP methods
// Master Data Service Routes
route("/api/masterdata") {
get("{...}") { handleServiceRequest(call, "master-data", serviceDiscovery) }
post("{...}") { handleServiceRequest(call, "master-data", serviceDiscovery) }
put("{...}") { handleServiceRequest(call, "master-data", serviceDiscovery) }
delete("{...}") { handleServiceRequest(call, "master-data", serviceDiscovery) }
patch("{...}") { handleServiceRequest(call, "master-data", serviceDiscovery) }
}
// Horse Registry Service Routes
route("/api/horses") {
get("{...}") { handleServiceRequest(call, "horse-registry", serviceDiscovery) }
post("{...}") { handleServiceRequest(call, "horse-registry", serviceDiscovery) }
put("{...}") { handleServiceRequest(call, "horse-registry", serviceDiscovery) }
delete("{...}") { handleServiceRequest(call, "horse-registry", serviceDiscovery) }
patch("{...}") { handleServiceRequest(call, "horse-registry", serviceDiscovery) }
}
// Event Management Service Routes
route("/api/events") {
get("{...}") { handleServiceRequest(call, "event-management", serviceDiscovery) }
post("{...}") { handleServiceRequest(call, "event-management", serviceDiscovery) }
put("{...}") { handleServiceRequest(call, "event-management", serviceDiscovery) }
delete("{...}") { handleServiceRequest(call, "event-management", serviceDiscovery) }
patch("{...}") { handleServiceRequest(call, "event-management", serviceDiscovery) }
}
// Member Management Service Routes
route("/api/members") {
get("{...}") { handleServiceRequest(call, "member-management", serviceDiscovery) }
post("{...}") { handleServiceRequest(call, "member-management", serviceDiscovery) }
put("{...}") { handleServiceRequest(call, "member-management", serviceDiscovery) }
delete("{...}") { handleServiceRequest(call, "member-management", serviceDiscovery) }
patch("{...}") { handleServiceRequest(call, "member-management", serviceDiscovery) }
}
}
/**
* HTTP client for forwarding requests to backend services
*/
private val httpClient = HttpClient(CIO) {
install(ContentNegotiation) {
json()
}
}
/**
* Handle a service request by discovering the service and forwarding the request.
* This implementation forwards the complete HTTP request to the backend service.
*/
private suspend fun handleServiceRequest(
call: ApplicationCall,
serviceName: String,
serviceDiscovery: ServiceDiscovery?
) {
try {
// Check if service discovery is available
if (serviceDiscovery == null) {
val errorResponse = ServiceErrorResponse(
error = "Service discovery is not available",
code = "SERVICE_DISCOVERY_DISABLED"
)
call.respond(HttpStatusCode.ServiceUnavailable, errorResponse)
return
}
// Get service instance
val serviceInstance = serviceDiscovery.getServiceInstance(serviceName)
if (serviceInstance == null) {
val errorResponse = ServiceErrorResponse(
error = "Service $serviceName is not available",
code = "SERVICE_NOT_FOUND",
service = serviceName
)
call.respond(HttpStatusCode.ServiceUnavailable, errorResponse)
return
}
// Build target URL
val targetUrl = "http://${serviceInstance.host}:${serviceInstance.port}${call.request.uri}"
// Forward the request to the backend service
val response = httpClient.request(targetUrl) {
method = call.request.httpMethod
// Copy all headers except Host and Content-Length (handled automatically)
call.request.headers.forEach { name, values ->
if (name.lowercase() !in listOf("host", "content-length")) {
values.forEach { value ->
header(name, value)
}
}
}
// Copy request body if present
if (call.request.httpMethod in listOf(HttpMethod.Post, HttpMethod.Put, HttpMethod.Patch)) {
val requestBody = call.receiveText()
if (requestBody.isNotEmpty()) {
setBody(requestBody)
}
}
}
// Forward the response back to the client
call.response.status(response.status)
// Copy response headers
response.headers.forEach { name, values ->
if (name.lowercase() !in listOf("content-length", "transfer-encoding")) {
values.forEach { value ->
call.response.header(name, value)
}
}
}
// Copy response body
val responseBody = response.bodyAsText()
call.respondText(responseBody, response.contentType())
} catch (e: Exception) {
val errorResponse = ServiceErrorResponse(
error = "Error routing request to service $serviceName: ${e.message}",
code = "SERVICE_ERROR",
service = serviceName
)
call.respond(HttpStatusCode.InternalServerError, errorResponse)
}
}
@@ -1,104 +0,0 @@
package at.mocode.infrastructure.gateway.validation
import at.mocode.core.domain.model.ApiResponse
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.request.*
import io.ktor.server.response.*
/**
* Klasse für die Validierung von API-Anfragen.
* Bietet Methoden zum Validieren und Verarbeiten von Request-Daten.
*/
class RequestValidator {
companion object {
/**
* Validiert und verarbeitet eine Anfrage.
*
* @param call Der ApplicationCall
* @param validator Eine Funktion, die den Request validiert und eine Liste von Fehlern zurückgibt
* @param processor Eine Funktion, die den validierten Request verarbeitet
* @return true, wenn die Validierung erfolgreich war, false sonst
*/
suspend inline fun <reified T : Any> validateAndProcess(
call: ApplicationCall,
crossinline validator: (T) -> List<String>,
crossinline processor: suspend (T) -> Unit
): Boolean {
try {
// Request-Daten lesen
val request = call.receive<T>()
// Validierung durchführen
val errors = validator(request)
if (errors.isNotEmpty()) {
call.respond(
HttpStatusCode.BadRequest,
ApiResponse.error<T>("Validierungsfehler")
)
return false
}
// Request verarbeiten
processor(request)
return true
} catch (e: Exception) {
call.respond(
HttpStatusCode.BadRequest,
ApiResponse.error<T>("Fehler bei der Anfrageverarbeitung: ${e.message}")
)
return false
}
}
/**
* Validiert Pflichtfelder in einem Request.
*
* @param fields Map von Feldnamen zu Feldwerten
* @return Liste von Fehlermeldungen für fehlende Pflichtfelder
*/
fun validateRequiredFields(vararg fields: Pair<String, Any?>): List<String> {
return fields
.filter { (_, value) ->
when (value) {
null -> true
is String -> value.isBlank()
is Collection<*> -> value.isEmpty()
else -> false
}
}
.map { (name, _) -> "Das Feld '$name' ist erforderlich" }
}
/**
* Validiert die Länge eines Textfeldes.
*
* @param name Name des Feldes
* @param value Wert des Feldes
* @param minLength Minimale Länge
* @param maxLength Maximale Länge
* @return Fehlermeldung, wenn die Länge ungültig ist, sonst null
*/
fun validateStringLength(name: String, value: String?, minLength: Int, maxLength: Int): String? {
if (value == null) return null
return when {
value.length < minLength -> "Das Feld '$name' muss mindestens $minLength Zeichen enthalten"
value.length > maxLength -> "Das Feld '$name' darf höchstens $maxLength Zeichen enthalten"
else -> null
}
}
/**
* Validiert eine E-Mail-Adresse.
*
* @param email Die zu validierende E-Mail-Adresse
* @return true, wenn die E-Mail-Adresse gültig ist, false sonst
*/
fun isValidEmail(email: String?): Boolean {
if (email == null) return false
val emailRegex = "^[A-Za-z0-9+_.-]+@[A-Za-z0-9.-]+$"
return email.matches(emailRegex.toRegex())
}
}
}
@@ -0,0 +1,16 @@
# Port, auf dem das Gateway läuft
server:
port: 8080
# Name, unter dem sich das Gateway in Consul registriert
spring:
application:
name: api-gateway
cloud:
gateway:
# Aktiviert die automatische Routen-Erstellung basierend auf Consul
discovery:
locator:
enabled: true
# Macht Routen-Namen klein (z.B. /members-service/** statt /MEMBERS-SERVICE/**)
lower-case-service-id: true
@@ -1,463 +0,0 @@
package at.mocode.infrastructure.gateway
import at.mocode.core.domain.model.ApiResponse
import at.mocode.infrastructure.gateway.routing.ApiGatewayInfo
import at.mocode.infrastructure.gateway.routing.HealthStatus
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.server.testing.*
import kotlinx.serialization.json.Json
import org.junit.jupiter.api.DisplayName
import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
/**
* Integration tests for the API Gateway.
*
* These tests verify that all API endpoints are working correctly
* and that the OpenAPI/Swagger integration is functioning properly.
*
* Tests are organized into nested classes by functionality area.
*/
class ApiIntegrationTest {
private val json = Json { ignoreUnknownKeys = true }
/**
* Helper function to verify common ApiResponse structure
*/
private fun verifyApiResponseStructure(responseText: String) {
assertTrue(responseText.contains("\"success\""), "Response should contain 'success' field")
assertTrue(responseText.contains("\"data\""), "Response should contain 'data' field")
assertTrue(responseText.contains("\"message\""), "Response should contain 'message' field")
}
/**
* Tests for core API Gateway functionality
*/
@Nested
@DisplayName("Core API Gateway Tests")
inner class CoreApiTests {
@Test
fun testApiGatewayInfo() = testApplication {
application {
module()
}
client.get("/").apply {
assertEquals(HttpStatusCode.OK, status, "Status should be OK")
val responseText = bodyAsText()
assertTrue(responseText.contains("Meldestelle API Gateway"), "Response should contain gateway name")
// Parse response as ApiResponse
val response = json.decodeFromString<ApiResponse<ApiGatewayInfo>>(responseText)
assertTrue(response.success, "Response should indicate success")
assertNotNull(response.data, "Response data should not be null")
assertEquals("Meldestelle API Gateway", response.data!!.name, "Gateway name should match")
assertEquals("1.0.0", response.data!!.version, "Gateway version should match")
// Verify all expected contexts are available
val expectedContexts = listOf("authentication", "master-data", "horse-registry")
expectedContexts.forEach { context ->
assertTrue(response.data!!.availableContexts.contains(context),
"Available contexts should contain $context")
}
// Verify ApiResponse structure
verifyApiResponseStructure(responseText)
}
}
@Test
fun testHealthCheck() = testApplication {
application {
module()
}
client.get("/health").apply {
assertEquals(HttpStatusCode.OK, status, "Health check status should be OK")
val responseText = bodyAsText()
// Parse response as ApiResponse
val response = json.decodeFromString<ApiResponse<HealthStatus>>(responseText)
assertTrue(response.success, "Health check response should indicate success")
assertNotNull(response.data, "Health check data should not be null")
assertEquals("UP", response.data!!.status, "Health status should be UP")
// Verify all expected contexts are available in health check
val expectedContexts = listOf("authentication", "master-data", "horse-registry")
expectedContexts.forEach { context ->
assertTrue(response.data!!.contexts.containsKey(context),
"Health contexts should contain $context")
}
// Verify ApiResponse structure
verifyApiResponseStructure(responseText)
}
}
@Test
fun testNotFoundEndpoint() = testApplication {
application {
module()
}
client.get("/nonexistent").apply {
assertEquals(HttpStatusCode.NotFound, status, "Non-existent endpoint should return 404")
val responseText = bodyAsText()
assertTrue(responseText.contains("Endpoint not found"),
"Response should indicate endpoint not found")
// Verify error response format
assertTrue(responseText.contains("\"success\":false"),
"Error response should have success=false")
}
}
@Test
fun testInvalidMethod() = testApplication {
application {
module()
}
client.delete("/").apply {
// Either method not allowed or not found is acceptable
assertTrue(
status == HttpStatusCode.MethodNotAllowed || status == HttpStatusCode.NotFound,
"Invalid method should return 405 Method Not Allowed or 404 Not Found"
)
}
}
}
/**
* Tests for API documentation and Swagger UI
*/
@Nested
@DisplayName("Documentation Tests")
inner class DocumentationTests {
@Test
fun testApiDocumentation() = testApplication {
application {
module()
}
client.get("/api").apply {
assertEquals(HttpStatusCode.OK, status, "API documentation status should be OK")
val responseText = bodyAsText()
// Verify documentation contains expected sections
val expectedSections = listOf(
"Meldestelle Self-Contained Systems API",
"Authentication Context",
"Master Data Context",
"Horse Registry Context"
)
expectedSections.forEach { section ->
assertTrue(responseText.contains(section),
"API documentation should contain section: $section")
}
}
}
@Test
fun testSwaggerUI() = testApplication {
application {
module()
}
client.get("/swagger").apply {
// Swagger UI should be accessible (might return HTML or redirect)
assertTrue(
status.isSuccess() || status == HttpStatusCode.Found,
"Swagger UI should be accessible or redirect"
)
// If it's HTML, it should contain Swagger-related content
if (status.isSuccess()) {
val responseText = bodyAsText()
assertTrue(
responseText.contains("swagger") || responseText.contains("openapi"),
"Swagger UI response should contain swagger-related content"
)
}
}
}
}
/**
* Tests for API technical features like CORS and content negotiation
*/
@Nested
@DisplayName("API Technical Features")
inner class TechnicalFeatureTests {
@Test
fun testCorsHeaders() = testApplication {
application {
module()
}
// Test preflight request
client.options("/") {
header(HttpHeaders.Origin, "http://localhost:3000")
header(HttpHeaders.AccessControlRequestMethod, "GET")
}.apply {
assertTrue(status.isSuccess(), "CORS preflight request should succeed")
// Verify CORS headers
assertTrue(
headers.contains(HttpHeaders.AccessControlAllowOrigin),
"Response should contain Access-Control-Allow-Origin header"
)
assertTrue(
headers.contains(HttpHeaders.AccessControlAllowMethods),
"Response should contain Access-Control-Allow-Methods header"
)
}
// Test actual request with Origin header
client.get("/") {
header(HttpHeaders.Origin, "http://localhost:3000")
}.apply {
assertEquals(HttpStatusCode.OK, status, "CORS actual request should succeed")
assertTrue(
headers.contains(HttpHeaders.AccessControlAllowOrigin),
"Response should contain Access-Control-Allow-Origin header"
)
}
}
@Test
fun testContentNegotiation() = testApplication {
application {
module()
}
// Test JSON content type
client.get("/") {
header(HttpHeaders.Accept, "application/json")
}.apply {
assertEquals(HttpStatusCode.OK, status, "Content negotiation request should succeed")
assertEquals(
ContentType.Application.Json.withCharset(Charsets.UTF_8),
contentType(),
"Response content type should be application/json"
)
}
// Test with no Accept header (should default to JSON)
client.get("/").apply {
assertEquals(HttpStatusCode.OK, status, "Default content type request should succeed")
assertEquals(
ContentType.Application.Json.withCharset(Charsets.UTF_8),
contentType(),
"Default response content type should be application/json"
)
}
}
}
/**
* Tests for Master Data endpoints
*/
@Nested
@DisplayName("Master Data Endpoints")
inner class MasterDataTests {
@Test
fun testCountriesEndpoint() = testApplication {
application {
module()
}
client.get("/api/masterdata/countries").apply {
assertEquals(HttpStatusCode.OK, status, "Countries endpoint should return OK")
val responseText = bodyAsText()
// Verify response format
verifyApiResponseStructure(responseText)
assertTrue(responseText.contains("\"success\":true"),
"Response should indicate success")
}
}
@Test
fun testActiveCountriesEndpoint() = testApplication {
application {
module()
}
client.get("/api/masterdata/countries/active").apply {
assertEquals(HttpStatusCode.OK, status, "Active countries endpoint should return OK")
val responseText = bodyAsText()
// Verify response format
verifyApiResponseStructure(responseText)
assertTrue(responseText.contains("\"success\":true"),
"Response should indicate success")
}
}
@Test
fun testCountriesWithPagination() = testApplication {
application {
module()
}
client.get("/api/masterdata/countries?limit=5&offset=0").apply {
assertEquals(HttpStatusCode.OK, status, "Countries with pagination should return OK")
val responseText = bodyAsText()
// Verify response format
verifyApiResponseStructure(responseText)
assertTrue(responseText.contains("\"success\":true"),
"Response should indicate success")
}
}
}
/**
* Tests for Horse Registry endpoints
*/
@Nested
@DisplayName("Horse Registry Endpoints")
inner class HorseRegistryTests {
@Test
fun testHorsesEndpointRequiresAuth() = testApplication {
application {
module()
}
client.get("/api/horses").apply {
// Should return unauthorized or redirect to login
assertTrue(
status == HttpStatusCode.Unauthorized || status == HttpStatusCode.Found,
"Horses endpoint should require authentication"
)
}
}
@Test
fun testHorseStatsEndpointRequiresAuth() = testApplication {
application {
module()
}
client.get("/api/horses/stats").apply {
// Should require authentication
assertTrue(
status == HttpStatusCode.Unauthorized || status == HttpStatusCode.Found,
"Horse stats endpoint should require authentication"
)
}
}
}
/**
* Tests for Authentication endpoints
*/
@Nested
@DisplayName("Authentication Endpoints")
inner class AuthenticationTests {
@Test
fun testRegistrationEndpoint() = testApplication {
application {
module()
}
client.post("/auth/register") {
contentType(ContentType.Application.Json)
setBody("""
{
"email": "test@example.com",
"password": "TestPassword123!",
"firstName": "Test",
"lastName": "User",
"phoneNumber": "+43123456789"
}
""".trimIndent())
}.apply {
// Should process the request (might fail due to validation or database issues)
// But should not return server error
assertTrue(status.value in 200..499,
"Registration endpoint should process request without server error")
// If it's a client error, it should be due to validation or existing user
if (status.value in 400..499) {
val responseText = bodyAsText()
assertTrue(
responseText.contains("validation") ||
responseText.contains("exist") ||
responseText.contains("already"),
"Client error should be due to validation or existing user"
)
}
}
}
@Test
fun testLoginEndpoint() = testApplication {
application {
module()
}
client.post("/auth/login") {
contentType(ContentType.Application.Json)
setBody("""
{
"email": "test@example.com",
"password": "TestPassword123!"
}
""".trimIndent())
}.apply {
// Should process the request without server error
assertTrue(status.value in 200..499,
"Login endpoint should process request without server error")
// If it's a client error, it should be due to invalid credentials
if (status.value in 400..499) {
val responseText = bodyAsText()
assertTrue(
responseText.contains("invalid") ||
responseText.contains("credentials") ||
responseText.contains("unauthorized"),
"Client error should be due to invalid credentials"
)
}
}
}
@Test
fun testInvalidLoginRequest() = testApplication {
application {
module()
}
// Test with missing password
client.post("/auth/login") {
contentType(ContentType.Application.Json)
setBody("""
{
"email": "test@example.com"
}
""".trimIndent())
}.apply {
// Should return a client error
assertTrue(status.value in 400..499,
"Invalid login request should return client error")
val responseText = bodyAsText()
assertTrue(
responseText.contains("validation") ||
responseText.contains("missing") ||
responseText.contains("required"),
"Error should indicate validation failure or missing field"
)
}
}
}
}