Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions core/src/main/kotlin/EvaluationProxy.kt
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,28 @@ class EvaluationProxyResponseException(
data class EvaluationProxyResponse(
val status: HttpStatusCode,
val body: String,
val bytes: ByteArray? = null,
Comment thread
vaibhav-jain-exp marked this conversation as resolved.
) {
companion object {
fun error(
status: HttpStatusCode,
message: String,
): EvaluationProxyResponse {
return EvaluationProxyResponse(status, message)
return EvaluationProxyResponse(status, body = message, bytes = null)
}

inline fun <reified T> json(
status: HttpStatusCode,
response: T,
): EvaluationProxyResponse {
return EvaluationProxyResponse(status, json.encodeToString<T>(response))
return EvaluationProxyResponse(status, body = json.encodeToString<T>(response), bytes = null)
}

fun bytes(
status: HttpStatusCode,
payload: ByteArray,
): EvaluationProxyResponse {
return EvaluationProxyResponse(status = status, body = "", bytes = payload)
}
}
}
Expand Down
9 changes: 5 additions & 4 deletions core/src/main/kotlin/cohort/CohortApi.kt
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ internal interface CohortApi {
lastModified: Long?,
maxCohortSize: Int,
storage: CohortStorage,
)
): Boolean
}

internal class CohortApiV1(
Expand Down Expand Up @@ -134,7 +134,7 @@ internal class CohortApiV1(
lastModified: Long?,
maxCohortSize: Int,
storage: CohortStorage,
) {
): Boolean {
log.debug("streamCohort({}): start - maxCohortSize={}, lastModified={}", cohortId, maxCohortSize, lastModified)
val response =
retry(
Expand All @@ -158,7 +158,7 @@ internal class CohortApiV1(
}
log.debug("streamCohort({}): status={}", cohortId, response.status)
when (response.status) {
HttpStatusCode.NoContent -> return
HttpStatusCode.NoContent -> return false
HttpStatusCode.PayloadTooLarge -> throw CohortTooLargeException(cohortId, maxCohortSize)
else -> {
val input = response.bodyAsChannel().toInputStream()
Expand Down Expand Up @@ -205,7 +205,7 @@ internal class CohortApiV1(
val lm = parsedLastModified
if (id != null && lm != null) {
if (lm <= existingLastModified) {
return
return false
}
}
ensureWriter()
Expand All @@ -232,6 +232,7 @@ internal class CohortApiV1(
}
ensureWriter()
writer!!.complete(memberCount)
return true
}
}
}
Expand Down
33 changes: 33 additions & 0 deletions core/src/main/kotlin/cohort/CohortBlobCache.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.amplitude.cohort

import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock

/**
* Simple in-memory cache for gzipped cohort blobs.
* Key: "{cohortId}"
*/
internal class CohortBlobCache {
private data class CacheEntry(val bytes: ByteArray)

private val lock = Mutex()
private val map = HashMap<String, CacheEntry>()
Comment thread
vaibhav-jain-exp marked this conversation as resolved.

suspend fun get(key: String): ByteArray? =
lock.withLock {
map[key]?.bytes
}

suspend fun put(
key: String,
bytes: ByteArray,
) = lock.withLock {
map[key] = CacheEntry(bytes)
}

suspend fun remove(key: String) =
lock.withLock {
map.remove(key)
null
}
}
5 changes: 3 additions & 2 deletions core/src/main/kotlin/cohort/CohortLoader.kt
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,16 @@ internal class CohortLoader(
try {
try {
val storageCohort = cohortStorage.getCohortDescription(cohortId)
try {
val modified =
Metrics.with({ CohortDownload }, { e -> CohortDownloadFailure(e) }) {
cohortApi.streamCohort(cohortId, storageCohort?.lastModified, maxCohortSize, cohortStorage)
}
if (modified) {
val updated = cohortStorage.getCohortDescription(cohortId)
if (updated != null) {
log.info("Cohort download/save completed. {}", updated)
}
} catch (_: CohortNotModifiedException) {
} else {
log.debug("loadCohort: cohort not modified - cohortId={}", cohortId)
}
} catch (t: Throwable) {
Expand Down
137 changes: 134 additions & 3 deletions core/src/main/kotlin/cohort/CohortStorage.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,19 @@ import com.amplitude.util.logger
import com.amplitude.util.redis.Redis
import com.amplitude.util.redis.RedisKey
import com.amplitude.util.redis.createRedisConnections
import com.squareup.moshi.JsonWriter
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import okio.buffer
import okio.sink
import java.io.ByteArrayOutputStream
import java.util.Base64
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.GZIPOutputStream
import kotlin.time.Duration

// Constants
Expand Down Expand Up @@ -66,6 +74,11 @@ internal interface CohortStorage {
*/
fun createWriter(description: CohortDescription): CohortIngestionWriter

/**
* Get a pre-gzipped JSON blob for the given cohortId at its latest lastModified.
*/
suspend fun getCohortBlob(cohortId: String): ByteArray?

/**
* Attempt to acquire a distributed lock for cohort loading.
* Returns true if lock was acquired, false if another instance is already loading.
Expand Down Expand Up @@ -103,6 +116,7 @@ internal fun getCohortStorage(
connections.readOnly,
redisConfiguration.scanLimit,
redisConfiguration.pipelineBatchSize,
CohortBlobCache(),
)
} else {
InMemoryCohortStorage()
Expand Down Expand Up @@ -168,6 +182,30 @@ internal class InMemoryCohortStorage : CohortStorage {
}
}

override suspend fun getCohortBlob(cohortId: String): ByteArray? {
val cohort = getCohort(cohortId) ?: return null
val baos = ByteArrayOutputStream()
GZIPOutputStream(baos, false).use { gz ->
val sink = gz.sink().buffer()
val jw = JsonWriter.of(sink)

jw.beginObject()
jw.name("cohortId").value(cohort.id)
jw.name("groupType").value(cohort.groupType)
jw.name("lastModified").value(cohort.lastModified)
jw.name("size").value(cohort.size.toLong())

jw.name("memberIds").beginArray()
for (id in cohort.members) jw.value(id)
jw.endArray()

jw.endObject()
jw.flush()
sink.flush()
}
return baos.toByteArray()
}

override suspend fun tryLockCohortLoading(
cohortId: String,
lockTimeoutSeconds: Int,
Expand All @@ -189,11 +227,15 @@ internal class RedisCohortStorage(
private val readOnlyRedis: Redis,
private val scanLimit: Long,
private val pipelineBatchSize: Int,
private val cohortBlobCache: CohortBlobCache,
) : CohortStorage {
companion object {
val log by logger()
}

// Track inflight blob loads to avoid duplicate reads
private val inflightBlobLoads = ConcurrentHashMap<String, CompletableDeferred<ByteArray?>>()

/**
* Stream a Redis Set via SSCAN and pipeline membership updates in sub-batches.
* The provided [pipeline] function is invoked once per SSCAN chunk with the
Expand Down Expand Up @@ -265,6 +307,7 @@ internal class RedisCohortStorage(
}

override suspend fun deleteCohort(description: CohortDescription) {
cohortBlobCache.remove(description.id)
redis.hdel(RedisKey.CohortDescriptions(prefix, projectId), description.id)
val cohortMembersKey =
RedisKey.CohortMembers(
Expand Down Expand Up @@ -376,10 +419,65 @@ internal class RedisCohortStorage(
}
}

val jsonEncodedDescription = json.encodeToString(description.copy(size = finalSize))
redis.hset(RedisKey.CohortDescriptions(prefix, projectId), mapOf(description.id to jsonEncodedDescription))
// Build and store a pre-gzipped JSON blob for this cohort version in Redis for fast fanout.
val cohortId = description.id
val cohortLastModified = description.lastModified
val blobKey = RedisKey.CohortBlob(prefix, projectId, cohortId, cohortLastModified)
val gzBytes = buildCohortBlobGzip(description, finalSize)
val b64 = Base64.getEncoder().encodeToString(gzBytes)
redis.set(blobKey, b64)

// Remove the old blob from the cache - it will be fetched again in the next /cohort/{cohortId} request
cohortBlobCache.remove(cohortId)

// Publish the cohort description only after successful blob store
val updatedDescription = description.copy(size = finalSize)
val jsonEncodedDescription = json.encodeToString(updatedDescription)
redis.hset(
RedisKey.CohortDescriptions(prefix, projectId),
mapOf(description.id to jsonEncodedDescription),
)
}
}
}

/**
* Build a gzipped JSON blob for this cohort version.
*/
private suspend fun buildCohortBlobGzip(
description: CohortDescription,
finalSize: Int,
): ByteArray {
val baos = ByteArrayOutputStream()
GZIPOutputStream(baos, false).use { gz ->
val sink = gz.sink().buffer()
val jw = JsonWriter.of(sink)

jw.beginObject()
jw.name("cohortId").value(description.id)
jw.name("groupType").value(description.groupType)
jw.name("lastModified").value(description.lastModified)
jw.name("size").value(finalSize.toLong())

jw.name("memberIds").beginArray()
val newKey =
RedisKey.CohortMembers(
prefix,
projectId,
description.id,
description.groupType,
description.lastModified,
)
readOnlyRedis.sscanChunked(newKey, REDIS_SCAN_CHUNK_SIZE) { chunk ->
for (id in chunk) jw.value(id)
}
jw.endArray()

jw.endObject()
jw.flush()
sink.flush()
}
return baos.toByteArray()
}

override suspend fun tryLockCohortLoading(
Expand All @@ -399,11 +497,44 @@ internal class RedisCohortStorage(
}
}

override suspend fun getCohortBlob(cohortId: String): ByteArray? {
val description = getCohortDescription(cohortId) ?: return null
val cohortKey = description.id
cohortBlobCache.get(cohortKey)?.let {
return it
}
// Attempt to read from Redis blob key only (read-through) with single-flight
val newDeferred = CompletableDeferred<ByteArray?>()
val existing = inflightBlobLoads.putIfAbsent(cohortKey, newDeferred)
if (existing != null) {
return existing.await()
} else {
try {
val blobKey = RedisKey.CohortBlob(prefix, projectId, description.id, description.lastModified)
val b64 = readOnlyRedis.get(blobKey)
val gz = b64?.let { runCatching { Base64.getDecoder().decode(it) }.getOrNull() }
if (gz != null) {
cohortBlobCache.put(cohortKey, gz)
}
newDeferred.complete(gz)
return gz
} catch (t: Throwable) {
newDeferred.completeExceptionally(t)
throw t
} finally {
inflightBlobLoads.remove(cohortKey, newDeferred)
}
}
}

private suspend fun getCohortMembers(
cohortId: String,
cohortGroupType: String,
cohortLastModified: Long,
): Set<String>? {
return redis.sscan(RedisKey.CohortMembers(prefix, projectId, cohortId, cohortGroupType, cohortLastModified), scanLimit)
return readOnlyRedis.sscan(
RedisKey.CohortMembers(prefix, projectId, cohortId, cohortGroupType, cohortLastModified),
scanLimit,
)
}
}
12 changes: 7 additions & 5 deletions core/src/main/kotlin/project/ProjectProxy.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import com.amplitude.assignment.AssignmentTracker
import com.amplitude.cohort.CohortApiV1
import com.amplitude.cohort.CohortLoader
import com.amplitude.cohort.CohortStorage
import com.amplitude.cohort.GetCohortResponse
import com.amplitude.cohort.USER_GROUP_TYPE
import com.amplitude.deployment.DeploymentApiV2
import com.amplitude.deployment.DeploymentLoader
Expand Down Expand Up @@ -90,10 +89,13 @@ internal class ProjectProxy(
if (cohortDescription.lastModified == lastModified) {
return EvaluationProxyResponse.error(HttpStatusCode.NoContent, "Cohort not modified")
}
val cohort =
cohortStorage.getCohort(cohortId)
?: return EvaluationProxyResponse.error(HttpStatusCode.NotFound, "Cohort members not found")
return EvaluationProxyResponse.json(HttpStatusCode.OK, GetCohortResponse.fromCohort(cohort))
cohortStorage.getCohortBlob(cohortId)?.let { gz ->
return EvaluationProxyResponse.bytes(
status = HttpStatusCode.OK,
payload = gz,
)
}
return EvaluationProxyResponse.error(HttpStatusCode.NotFound, "Cohort members not found")
}

suspend fun getCohortMemberships(
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/kotlin/util/redis/RedisKey.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ internal sealed class RedisKey(val value: String) {
val cohortLastModified: Long,
) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:{projects:$projectId:cohort:$cohortId}:$cohortGroupType:$cohortLastModified")

/**
* Gzipped JSON blob for a cohort (base64-encoded gzipped bytes of a JSON object).
*/
data class CohortBlob(
val prefix: String,
val projectId: String,
val cohortId: String,
val cohortLastModified: Long,
) : RedisKey("$prefix:$STORAGE_PROTOCOL_VERSION:{projects:$projectId:cohort:$cohortId}:blob:$cohortLastModified")

data class UserCohortMemberships(
val prefix: String,
val projectId: String,
Expand Down
Loading