Proof of Concept
Pre-fix PoC (demonstrating risk before fix):
1) Run Redis with the vector module built from this repository.
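2) Create a vector set and insert a vector with M (the HNSW neighbor-count parameter) set to an extreme value, e.g. M 1000000: VADD vec VALUES 2 0.0 0.0 elem1 M 1000000
   - FP32 expects a binary blob of 32-bit floats, so textual components must be given as VALUES followed by the component count.
   - An element name must follow the vector.
   - The M option itself is introduced by this patch (see the new `M` branch in the vset.c hunk below). This step therefore illustrates what happens when M is accepted without a range check.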
3) Observe rapid memory growth, or a crash, caused by allocating very large HNSW data structures. This could be exploited for denial of service (and possibly memory corruption), especially if multiple clients issue such commands.
Post-fix PoC (showing protection):
1) With the patch applied, attempt the same command. The server rejects it with ERR invalid M because M is outside [HNSW_MIN_M, HNSW_MAX_M] = [4, 4096], so no large allocation takes place: VADD vec VALUES 2 0.0 0.0 elem1 M 1000000 -> ERR invalid M
Alternate load PoC (RDB):
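1) Craft an RDB file containing a vector set whose stored M exceeds HNSW_MAX_M.
   - Encoding limit: the patch packs M into a 16-bit field of the saved HNSW configuration word (bits 8–23, above the 8-bit quantization type). The largest M an RDB file can carry is therefore 65535. A value such as 1000000 cannot be encoded, so use e.g. 65535 instead.
   - Older files: files written before this patch hold only the quantization type in that word. Their M bits decode as 0, and the loader falls back to the default M of 16.
   - Risk: if the loader does not bounds-check the decoded M, loading such a file allocates oversized HNSW structures.
   - Required behavior: with the fix, the decoded M must also be checked against [4, 4096], and loading should fail with an error when it is out of range.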
Code Diff
diff --git a/hnsw.c b/hnsw.c
index 33e93d670ce..bd5d5503628 100644
--- a/hnsw.c
+++ b/hnsw.c
@@ -62,8 +62,6 @@
* used when deleting nodes for the search step
* needed sometimes to reconnect nodes that remain
* orphaned of one link. */
-#define HNSW_DEFAULT_M 16 /* Useful if 0 is given at creation time. */
-#define HNSW_MAX_M 1024 /* Hard limit for M. */
static void (*hfree)(void *p) = free;
static void *(*hmalloc)(size_t s) = malloc;
diff --git a/hnsw.h b/hnsw.h
index ee6186785a2..dc94f4cec5c 100644
--- a/hnsw.h
+++ b/hnsw.h
@@ -11,6 +11,9 @@
#include <pthread.h>
#include <stdatomic.h>
+#define HNSW_DEFAULT_M 16 /* Used when 0 is given at creation time. */
+#define HNSW_MIN_M 4 /* Probably even too low already. */
+#define HNSW_MAX_M 4096 /* Safeguard sanity limit. */
#define HNSW_MAX_THREADS 32 /* Maximum number of concurrent threads */
/* Quantization types you can enable at creation time in hnsw_new() */
diff --git a/vset.c b/vset.c
index f51ea0ee9f8..0b1de2da27f 100644
--- a/vset.c
+++ b/vset.c
@@ -100,13 +100,13 @@ float *applyProjection(const float *input, const float *proj_matrix,
}
/* Create the vector as HNSW+Dictionary combined data structure. */
-struct vsetObject *createVectorSetObject(unsigned int dim, uint32_t quant_type) {
+struct vsetObject *createVectorSetObject(unsigned int dim, uint32_t quant_type, uint32_t hnsw_M) {
struct vsetObject *o;
o = RedisModule_Alloc(sizeof(*o));
if (!o) return NULL;
o->id = VectorSetTypeNextId++;
- o->hnsw = hnsw_new(dim,quant_type,0);
+ o->hnsw = hnsw_new(dim,quant_type,hnsw_M);
if (!o->hnsw) {
RedisModule_Free(o);
return NULL;
@@ -380,7 +380,8 @@ int VADD_CASReply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
return retval;
}
-/* VADD key [REDUCE dim] FP32|VALUES vector value [CAS] [NOQUANT] */
+/* VADD key [REDUCE dim] FP32|VALUES vector value [CAS] [NOQUANT] [BIN] [Q8]
+ * [M count] */
int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
RedisModule_AutoMemory(ctx); /* Use automatic memory management. */
@@ -392,6 +393,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
int consumed_args;
int cas = 0; // Threaded check-and-set style insert.
long long ef = VSET_DEFAULT_C_EF; // HNSW creation time EF for new nodes.
+ long long hnsw_create_M = HNSW_DEFAULT_M; // HNSW creation default M value.
float *vec = parseVector(argv, argc, 2, &dim, &reduce_dim, &consumed_args);
RedisModuleString *attrib = NULL; // Attributes if passed via ATTRIB.
if (!vec)
@@ -415,6 +417,15 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
return RedisModule_ReplyWithError(ctx, "ERR invalid EF");
}
j++; // skip argument.
+ } else if (!strcasecmp(opt, "M") && j+1 < argc) {
+ if (RedisModule_StringToLongLong(argv[j+1], &hnsw_create_M)
+ != REDISMODULE_OK || hnsw_create_M < HNSW_MIN_M ||
+ hnsw_create_M > HNSW_MAX_M)
+ {
+ RedisModule_Free(vec);
+ return RedisModule_ReplyWithError(ctx, "ERR invalid M");
+ }
+ j++; // skip argument.
} else if (!strcasecmp(opt, "SETATTR") && j+1 < argc) {
attrib = argv[j+1];
j++; // skip argument.
@@ -465,7 +476,7 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
* key would be left empty until the threaded part
* does not return. It's also pointless to try try
* doing threaded first elemetn insertion. */
- vset = createVectorSetObject(reduce_dim ? reduce_dim : dim, quant_type);
+ vset = createVectorSetObject(reduce_dim ? reduce_dim : dim, quant_type, hnsw_create_M);
/* Initialize projection if requested */
if (reduce_dim) {
@@ -485,7 +496,13 @@ int VADD_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
if (vset->hnsw->quant_type != quant_type) {
RedisModule_Free(vec);
return RedisModule_ReplyWithError(ctx,
- "ERR use the same quantization of the existing vector set");
+ "ERR asked quantization mismatch with existing vector set");
+ }
+
+ if (vset->hnsw->M != hnsw_create_M) {
+ RedisModule_Free(vec);
+ return RedisModule_ReplyWithError(ctx,
+ "ERR asked M value mismatch with existing vector set");
}
if ((vset->proj_matrix == NULL && vset->hnsw->vector_dim != dim) ||
@@ -1227,12 +1244,16 @@ int VINFO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
struct vsetObject *vset = RedisModule_ModuleTypeGetValue(key);
/* Reply with hash */
- RedisModule_ReplyWithMap(ctx, 7);
+ RedisModule_ReplyWithMap(ctx, 8);
/* Quantization type */
RedisModule_ReplyWithSimpleString(ctx, "quant-type");
RedisModule_ReplyWithSimpleString(ctx, vectorSetGetQuantName(vset));
+ /* HNSW M value */
+ RedisModule_ReplyWithSimpleString(ctx, "hnsw-m");
+ RedisModule_ReplyWithLongLong(ctx, vset->hnsw->M);
+
/* Vector dimensionality. */
RedisModule_ReplyWithSimpleString(ctx, "vector-dim");
RedisModule_ReplyWithLongLong(ctx, vset->hnsw->vector_dim);
@@ -1270,7 +1291,10 @@ void VectorSetRdbSave(RedisModuleIO *rdb, void *value) {
struct vsetObject *vset = value;
RedisModule_SaveUnsigned(rdb, vset->hnsw->vector_dim);
RedisModule_SaveUnsigned(rdb, vset->hnsw->node_count);
- RedisModule_SaveUnsigned(rdb, vset->hnsw->quant_type);
+
+ uint32_t hnsw_config = (vset->hnsw->quant_type & 0xff) |
+ ((vset->hnsw->M & 0xffff) << 8);
+ RedisModule_SaveUnsigned(rdb, hnsw_config);
uint32_t save_flags = 0;
if (vset->proj_matrix) save_flags |= SAVE_FLAG_HAS_PROJMATRIX;
@@ -1316,9 +1340,23 @@ void *VectorSetRdbLoad(RedisModuleIO *rdb, int encver) {
uint32_t dim = RedisModule_LoadUnsigned(rdb);
uint64_t elements = RedisModule_LoadUnsigned(rdb);
- uint32_t quant_type = RedisModule_LoadUnsigned(rdb);
+ uint32_t hnsw_config = RedisModule_LoadUnsigned(rdb);
+ uint32_t quant_type = hnsw_config & 0xff;
+ uint32_t hnsw_m = (hnsw_config >> 8) & 0xffff;
- struct vsetObject *vset = createVectorSetObject(dim,quant_type);
+ /* M == 0 means the RDB predates this configuration parameter
+  * (older files stored only the quantization type in this word):
+  * use the default M, which is what older versions always created. */
+ if (hnsw_m == 0) hnsw_m = HNSW_DEFAULT_M;
+ /* RDB contents are untrusted input: reject M values that VADD would
+  * refuse, instead of allocating an oversized HNSW graph. */
+ if (hnsw_m < HNSW_MIN_M || hnsw_m > HNSW_MAX_M) {
+     RedisModule_LogIOError(rdb, "warning",
+         "Invalid HNSW M value %u in vector set RDB payload",
+         (unsigned)hnsw_m);
+     return NULL;
+ }
+ struct vsetObject *vset = createVectorSetObject(dim,quant_type,hnsw_m);
if (!vset) return NULL;
/* Load projection matrix if present */