From 33bfac37bab7b754604c3fd8a5563aa1398c651d Mon Sep 17 00:00:00 2001 From: chzhoo Date: Tue, 18 Nov 2025 21:27:15 +0800 Subject: [PATCH] Optimize zset memory usage by embedding element in skiplist (#2508) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit By default, when the number of elements in a zset exceeds 128, the underlying data structure adopts a skiplist. We can reduce memory usage by embedding elements into the skiplist nodes. Change the `zskiplistNode` memory layout as follows: ``` Before +-------------+ +-----> | element-sds | | +-------------+ | +------------------+-------+------------------+---------+-----+---------+ | element--pointer | score | backward-pointer | level-0 | ... | level-N | +------------------+-------+------------------+---------+-----+---------+ After +-------+------------------+---------+-----+---------+-------------+ + score | backward-pointer | level-0 | ... | level-N | element-sds | +-------+------------------+---------+-----+---------+-------------+ ``` Before the embedded SDS representation, we include one byte representing the size of the SDS header, i.e. the offset into the SDS representation where that actual string starts. The memory saving is therefore one pointer minus one byte = 7 bytes per element, regardless of other factors such as element size or number of elements. ### Benchmark step I generated the test data using the following lua script && cli command. And check memory usage using the `info` command. **lua script** ``` local start_idx = tonumber(ARGV[1]) local end_idx = tonumber(ARGV[2]) local elem_count = tonumber(ARGV[3]) for i = start_idx, end_idx do local key = "zset:" .. string.format("%012d", i) local members = {} for j = 0, elem_count - 1 do table.insert(members, j) table.insert(members, "member:" .. j) end redis.call("ZADD", key, unpack(members)) end return "OK: Created " .. (end_idx - start_idx + 1) .. " zsets" ``` **valkey-cli command** `valkey-cli EVAL "$(catcreate_zsets.lua)" 0 0 100000 ${ZSET_ELEMENT_NUM}` ### Benchmark result |number of elements in a zset | memory usage before optimization | memory usage after optimization | change | |-------|-------|-------|-------| | 129 | 1047MB | 943MB | -9.9% | | 256 | 2010MB| 1803MB| -10.3%| | 512 | 3904MB|3483MB| -10.8%| --------- Signed-off-by: chzhoo Co-authored-by: Viktor Söderqvist --- src/aof.c | 2 +- src/db.c | 4 +- src/debug.c | 3 +- src/defrag.c | 18 ++-- src/geo.c | 4 +- src/module.c | 9 +- src/object.c | 6 +- src/rdb.c | 4 +- src/server.c | 2 +- src/server.h | 5 +- src/sort.c | 5 +- src/t_zset.c | 156 ++++++++++++++++++++++------------ src/valkey-check-rdb.c | 3 +- tests/modules/zset.c | 99 +++++++++++++++++++++ tests/unit/moduleapi/zset.tcl | 69 +++++++++++++++ 15 files changed, 303 insertions(+), 86 deletions(-) diff --git a/src/aof.c b/src/aof.c index 23dfcc67a..1ec77dc0e 100644 --- a/src/aof.c +++ b/src/aof.c @@ -1922,7 +1922,7 @@ int rewriteSortedSetObject(rio *r, robj *key, robj *o) { return 0; } } - sds ele = node->ele; + sds ele = zslGetNodeElement(node); if (!rioWriteBulkDouble(r, node->score) || !rioWriteBulkString(r, ele, sdslen(ele))) { hashtableResetIterator(&iter); return 0; diff --git a/src/db.c b/src/db.c index f398191d2..91aa67a19 100644 --- a/src/db.c +++ b/src/db.c @@ -1054,7 +1054,7 @@ void hashtableScanCallback(void *privdata, void *entry) { key = (sds)entry; } else if (o->type == OBJ_ZSET) { zskiplistNode *node = (zskiplistNode *)entry; - key = node->ele; + key = zslGetNodeElement(node); /* zset data is copied after filtering by key */ } else if (o->type == OBJ_HASH) { key = entryGetField(entry); @@ -1077,7 +1077,7 @@ void hashtableScanCallback(void *privdata, void *entry) { if (o->type == OBJ_ZSET) { /* zset data is copied */ zskiplistNode *node = (zskiplistNode *)entry; - key = sdsdup(node->ele); + key = sdsdup(zslGetNodeElement(node)); if (!data->only_keys) { char buf[MAX_LONG_DOUBLE_CHARS]; int len = ld2string(buf, sizeof(buf), node->score, LD_STR_AUTO); diff --git a/src/debug.c b/src/debug.c index 56dbdfbb2..ee0720f04 100644 --- a/src/debug.c +++ b/src/debug.c @@ -215,7 +215,8 @@ void xorObjectDigest(serverDb *db, robj *keyobj, unsigned char *digest, robj *o) const int len = fpconv_dtoa(node->score, buf); buf[len] = '\0'; memset(eledigest, 0, 20); - mixDigest(eledigest, node->ele, sdslen(node->ele)); + sds ele = zslGetNodeElement(node); + mixDigest(eledigest, ele, sdslen(ele)); mixDigest(eledigest, buf, strlen(buf)); xorDigest(digest, eledigest, 20); } diff --git a/src/defrag.c b/src/defrag.c index 79a8b473d..f4cb5f0b2 100644 --- a/src/defrag.c +++ b/src/defrag.c @@ -266,12 +266,11 @@ static void activeDefragZsetNode(void *privdata, void *entry_ref) { zskiplistNode **node_ref = (zskiplistNode **)entry_ref; zskiplistNode *node = *node_ref; - /* defragment node internals */ - sds newsds = activeDefragSds(node->ele); - if (newsds) node->ele = newsds; + size_t allocation_size; + zskiplistNode *newnode = activeDefragAllocWithoutFree(node, &allocation_size); + if (newnode == NULL) return; const double score = node->score; - const sds ele = node->ele; /* find skiplist pointers that need to be updated if we end up moving the * skiplist node. */ @@ -283,7 +282,7 @@ static void activeDefragZsetNode(void *privdata, void *entry_ref) { zskiplistNode *next = x->level[i].forward; while (next && (next->score < score || - (next->score == score && sdscmp(next->ele, ele) < 0))) { + (next->score == score && next != node))) { x = next; next = x->level[i].forward; } @@ -292,12 +291,9 @@ static void activeDefragZsetNode(void *privdata, void *entry_ref) { /* should have arrived at intended node */ serverAssert(x->level[0].forward == node); - /* try to defrag the skiplist record itself */ - zskiplistNode *newnode = activeDefragAlloc(node); - if (newnode) { - zslUpdateNode(zsl, node, newnode, update); - *node_ref = newnode; /* update hashtable pointer */ - } + zslUpdateNode(zsl, node, newnode, update); + *node_ref = newnode; /* update hashtable pointer */ + allocatorDefragFree(node, allocation_size); } #define DEFRAG_SDS_DICT_NO_VAL 0 diff --git a/src/geo.c b/src/geo.c index d7895262c..38027476b 100644 --- a/src/geo.c +++ b/src/geo.c @@ -322,7 +322,8 @@ int geoGetPointsInRange(robj *zobj, double min, double max, GeoShape *shape, geo if (!zslValueLteMax(ln->score, &range)) break; if (geoWithinShape(shape, ln->score, xy, &distance) == C_OK) { /* Append the new element. */ - geoArrayAppend(ga, xy, distance, ln->score, sdsdup(ln->ele)); + sds ele = zslGetNodeElement(ln); + geoArrayAppend(ga, xy, distance, ln->score, sdsdup(ele)); } if (ga->used && limit && ga->used >= limit) break; ln = ln->level[0].forward; @@ -825,6 +826,7 @@ void georadiusGeneric(client *c, int srcKeyIndex, int flags) { totelelen += elelen; znode = zslInsert(zs->zsl, score, gp->member); serverAssert(hashtableAdd(zs->ht, znode)); + sdsfree(gp->member); gp->member = NULL; } diff --git a/src/module.c b/src/module.c index 5c770141b..ab9361153 100644 --- a/src/module.c +++ b/src/module.c @@ -5165,7 +5165,8 @@ ValkeyModuleString *VM_ZsetRangeCurrentElement(ValkeyModuleKey *key, double *sco } else if (key->value->encoding == OBJ_ENCODING_SKIPLIST) { zskiplistNode *ln = key->u.zset.current; if (score) *score = ln->score; - str = createStringObject(ln->ele, sdslen(ln->ele)); + sds ele = zslGetNodeElement(ln); + str = createStringObject(ele, sdslen(ele)); } else { serverPanic("Unsupported zset encoding"); } @@ -5222,7 +5223,7 @@ int VM_ZsetRangeNext(ValkeyModuleKey *key) { key->u.zset.er = 1; return 0; } else if (key->u.zset.type == VALKEYMODULE_ZSET_RANGE_LEX) { - if (!zslLexValueLteMax(next->ele, &key->u.zset.lrs)) { + if (!zslLexValueLteMax(zslGetNodeElement(next), &key->u.zset.lrs)) { key->u.zset.er = 1; return 0; } @@ -5284,7 +5285,7 @@ int VM_ZsetRangePrev(ValkeyModuleKey *key) { key->u.zset.er = 1; return 0; } else if (key->u.zset.type == VALKEYMODULE_ZSET_RANGE_LEX) { - if (!zslLexValueGteMin(prev->ele, &key->u.zset.lrs)) { + if (!zslLexValueGteMin(zslGetNodeElement(prev), &key->u.zset.lrs)) { key->u.zset.er = 1; return 0; } @@ -11418,7 +11419,7 @@ static void moduleScanKeyHashtableCallback(void *privdata, void *entry) { /* no value */ } else if (o->type == OBJ_ZSET) { zskiplistNode *node = (zskiplistNode *)entry; - key = node->ele; + key = zslGetNodeElement(node); value = createStringObjectFromLongDouble(node->score, 0); } else if (o->type == OBJ_HASH) { key = entryGetField(entry); diff --git a/src/object.c b/src/object.c index 41a7bd50c..13efeb502 100644 --- a/src/object.c +++ b/src/object.c @@ -667,8 +667,9 @@ void dismissZsetObject(robj *o, size_t size_hint) { if (size_hint / zsl->length >= server.page_size) { zskiplistNode *zn = zsl->tail; while (zn != NULL) { - dismissSds(zn->ele); - zn = zn->backward; + zskiplistNode *next = zn->backward; + dismissMemory(zn, 0); + zn = next; } } @@ -1190,7 +1191,6 @@ size_t objectComputeSize(robj *key, robj *o, size_t sample_size, int dbid) { asize += sizeof(zset) + sizeof(zskiplist) + hashtableMemUsage(ht) + zmalloc_size(zsl->header); while (znode != NULL && samples < sample_size) { - elesize += sdsAllocSize(znode->ele); elesize += zmalloc_size(znode); samples++; znode = znode->level[0].forward; diff --git a/src/rdb.c b/src/rdb.c index 6d0f8af61..9b07d4d36 100644 --- a/src/rdb.c +++ b/src/rdb.c @@ -959,7 +959,8 @@ ssize_t rdbSaveObject(rio *rdb, robj *o, robj *key, int dbid) { * O(1) instead of O(log(N)). */ zskiplistNode *zn = zsl->tail; while (zn != NULL) { - if ((n = rdbSaveRawString(rdb, (unsigned char *)zn->ele, sdslen(zn->ele))) == -1) { + sds ele = zslGetNodeElement(zn); + if ((n = rdbSaveRawString(rdb, (unsigned char *)ele, sdslen(ele))) == -1) { return -1; } nwritten += n; @@ -2095,6 +2096,7 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { totelelen += sdslen(sdsele); znode = zslInsert(zs->zsl, score, sdsele); + sdsfree(sdsele); if (!hashtableAdd(zs->ht, znode)) { rdbReportCorruptRDB("Duplicate zset fields detected"); decrRefCount(o); diff --git a/src/server.c b/src/server.c index eeaec5d9d..44bd09ad4 100644 --- a/src/server.c +++ b/src/server.c @@ -561,7 +561,7 @@ hashtableType setHashtableType = { const void *zsetHashtableGetKey(const void *element) { const zskiplistNode *node = element; - return node->ele; + return zslGetNodeElement(node); } /* Sorted sets hash (note: a skiplist is used in addition to the hash table) */ diff --git a/src/server.h b/src/server.h index 7411031d8..b930b7511 100644 --- a/src/server.h +++ b/src/server.h @@ -1423,7 +1423,6 @@ struct sharedObjectsStruct { /* ZSETs use a specialized version of Skiplists */ typedef struct zskiplistNode { - sds ele; double score; struct zskiplistNode *backward; struct zskiplistLevel { @@ -1434,6 +1433,7 @@ typedef struct zskiplistNode { * So we use it in order to hold the height of the node, which is the number of levels. */ unsigned long span; } level[]; + /* After the level[], sds header length (1 byte) and an embedded sds element are stored. */ } zskiplistNode; typedef struct zskiplist { @@ -3275,8 +3275,9 @@ typedef struct { zskiplist *zslCreate(void); void zslFree(zskiplist *zsl); -zskiplistNode *zslInsert(zskiplist *zsl, double score, sds ele); +zskiplistNode *zslInsert(zskiplist *zsl, double score, const_sds ele); zskiplistNode *zslNthInRange(zskiplist *zsl, zrangespec *range, long n, long *rank); +sds zslGetNodeElement(const zskiplistNode *x); double zzlGetScore(unsigned char *sptr); void zzlNext(unsigned char *zl, unsigned char **eptr, unsigned char **sptr); void zzlPrev(unsigned char *zl, unsigned char **eptr, unsigned char **sptr); diff --git a/src/sort.c b/src/sort.c index 754ebef4a..dbd127452 100644 --- a/src/sort.c +++ b/src/sort.c @@ -434,7 +434,7 @@ void sortCommandGeneric(client *c, int readonly) { while (rangelen--) { serverAssertWithInfo(c, sortval, ln != NULL); - sdsele = ln->ele; + sdsele = zslGetNodeElement(ln); vector[j].obj = createStringObject(sdsele, sdslen(sdsele)); vector[j].u.score = 0; vector[j].u.cmpobj = NULL; @@ -451,7 +451,8 @@ void sortCommandGeneric(client *c, int readonly) { void *next; while (hashtableNext(&iter, &next)) { zskiplistNode *node = next; - vector[j].obj = createStringObject(node->ele, sdslen(node->ele)); + sds sdsele = zslGetNodeElement(node); + vector[j].obj = createStringObject(sdsele, sdslen(sdsele)); vector[j].u.score = 0; vector[j].u.cmpobj = NULL; j++; diff --git a/src/t_zset.c b/src/t_zset.c index 421788dd4..d499e86ab 100644 --- a/src/t_zset.c +++ b/src/t_zset.c @@ -115,15 +115,50 @@ static inline void zslSetNodeHeight(zskiplistNode *x, int height) { } /* Create a skiplist node with the specified number of levels. - * The SDS string 'ele' is referenced by the node after the call. */ -static zskiplistNode *zslCreateNode(int height, double score, sds ele) { - zskiplistNode *zn = zmalloc(sizeof(*zn) + height * sizeof(struct zskiplistLevel)); + * By embedding elements and levels into the skiplist nodes, + * we achieve good cache-friendliness and a compact memory structure. + * + * The memory layout is as follows: + * + * +-------+------------------+---------+-----+---------+-----------------+-------------+ + * | score | backward-pointer | level-0 | ... | level-N | sds-header-size | element-sds | + * +-------+------------------+---------+-----+---------+-----------------+-------------+ + * + * sds-header-size and element-sds are only valid for non-header nodes. + */ +static zskiplistNode *zslCreateNode(int height, double score, const_sds ele) { + size_t ele_sds_len = sdslen(ele); + char ele_sds_type = sdsReqType(ele_sds_len); + size_t ele_sds_size = sdsReqSize(ele_sds_len, ele_sds_type); + /* Allocate enough space for the node, levels, and the element sds. + * We include one extra byte representing the sds header size, + * which is the offset into the embedded sds data where the + * string content starts. */ + zskiplistNode *zn = zmalloc(sizeof(*zn) + height * sizeof(struct zskiplistLevel) + 1 + ele_sds_size); zn->score = score; - zn->ele = ele; zslSetNodeHeight(zn, height); + char *data = ((char *)(zn + 1)) + height * sizeof(struct zskiplistLevel); + *data++ = sdsHdrSize(ele_sds_type); + sdswrite(data, ele_sds_size, ele_sds_type, ele, ele_sds_len); return zn; } +static zskiplistNode *zslCreateHeaderNode(void) { + /* Allocate enough space for the node and levels. */ + zskiplistNode *zn = zmalloc(sizeof(*zn) + ZSKIPLIST_MAXLEVEL * sizeof(struct zskiplistLevel)); + zslSetNodeHeight(zn, ZSKIPLIST_MAXLEVEL); + return zn; +} + +/* Helper function to return the element string from a skip list node. */ +sds zslGetNodeElement(const zskiplistNode *x) { + unsigned char *data = (void *)(x + 1); + data += zslGetNodeHeight(x) * sizeof(struct zskiplistLevel); + uint8_t hdr_size = *(uint8_t *)data; + data += 1 + hdr_size; + return (sds)data; +} + /* Create a new skiplist. */ zskiplist *zslCreate(void) { int j; @@ -132,7 +167,7 @@ zskiplist *zslCreate(void) { zsl = zmalloc(sizeof(*zsl)); zsl->level = 1; zsl->length = 0; - zsl->header = zslCreateNode(ZSKIPLIST_MAXLEVEL, 0, NULL); + zsl->header = zslCreateHeaderNode(); for (j = 0; j < ZSKIPLIST_MAXLEVEL; j++) { zsl->header->level[j].forward = NULL; zsl->header->level[j].span = 0; @@ -142,11 +177,8 @@ zskiplist *zslCreate(void) { return zsl; } -/* Free the specified skiplist node. The referenced SDS string representation - * of the element is freed too, unless node->ele is set to NULL before calling - * this function. */ +/* Free the specified skiplist node. */ static void zslFreeNode(zskiplistNode *node) { - sdsfree(node->ele); zfree(node); } @@ -192,7 +224,7 @@ static int zslCompareNodes(const zskiplistNode *a, const zskiplistNode *b) { if (a->score > b->score) return 1; if (a->score < b->score) return -1; - return sdscmp(a->ele, b->ele); + return sdscmp(zslGetNodeElement(a), zslGetNodeElement(b)); } /* Insert a node in the skiplist. Assumes the element does not already exist in @@ -251,9 +283,8 @@ static zskiplistNode *zslInsertNode(zskiplist *zsl, zskiplistNode *node) { } /* Insert a new node in the skiplist. Assumes the element does not already - * exist (up to the caller to enforce that). The skiplist takes ownership - * of the passed SDS string 'ele'. */ -zskiplistNode *zslInsert(zskiplist *zsl, double score, sds ele) { + * exist (up to the caller to enforce that). The string 'ele' is copied. */ +zskiplistNode *zslInsert(zskiplist *zsl, double score, const_sds ele) { const int level = zslRandomLevel(); zskiplistNode *node = zslCreateNode(level, score, ele); zslInsertNode(zsl, node); @@ -330,13 +361,9 @@ static zskiplistNode *zslUpdateScore(zskiplist *zsl, zskiplistNode *node, double serverAssert(x->level[0].forward == node); zslDeleteNode(zsl, node, update); - /* update pointer inside hashtable with new node */ - zskiplistNode *new_node = zslInsert(zsl, newscore, node->ele); - /* We reused the old node->ele SDS string, free the node now - * since zslInsert created a new node */ - node->ele = NULL; - zslFreeNode(node); - return new_node; + node->score = newscore; /* reuse existing node to avoid memory allocation */ + zslInsertNode(zsl, node); + return node; } int zslValueGteMin(double value, zrangespec *spec) { @@ -460,8 +487,9 @@ static unsigned long zslDeleteRangeByScore(zskiplist *zsl, zrangespec *range, ha while (x && zslValueLteMax(x->score, range)) { zskiplistNode *next = x->level[0].forward; zslDeleteNode(zsl, x, update); - hashtableDelete(ht, x->ele); - zslFreeNode(x); /* Here is where x->ele is actually released. */ + sds ele = zslGetNodeElement(x); + hashtablePop(ht, ele, NULL); + zslFreeNode(x); removed++; x = next; } @@ -476,7 +504,10 @@ static unsigned long zslDeleteRangeByLex(zskiplist *zsl, zlexrangespec *range, h x = zsl->header; for (i = zsl->level - 1; i >= 0; i--) { - while (x->level[i].forward && !zslLexValueGteMin(x->level[i].forward->ele, range)) x = x->level[i].forward; + while (x->level[i].forward && + !zslLexValueGteMin(zslGetNodeElement(x->level[i].forward), range)) { + x = x->level[i].forward; + } update[i] = x; } @@ -484,10 +515,10 @@ static unsigned long zslDeleteRangeByLex(zskiplist *zsl, zlexrangespec *range, h x = x->level[0].forward; /* Delete nodes while in range. */ - while (x && zslLexValueLteMax(x->ele, range)) { + while (x && zslLexValueLteMax(zslGetNodeElement(x), range)) { zskiplistNode *next = x->level[0].forward; zslDeleteNode(zsl, x, update); - hashtableDelete(ht, x->ele); + hashtableDelete(ht, zslGetNodeElement(x)); zslFreeNode(x); /* Here is where x->ele is actually released. */ removed++; x = next; @@ -516,7 +547,7 @@ static unsigned long zslDeleteRangeByRank(zskiplist *zsl, unsigned int start, un while (x && traversed <= end) { zskiplistNode *next = x->level[0].forward; zslDeleteNode(zsl, x, update); - hashtableDelete(ht, x->ele); + hashtableDelete(ht, zslGetNodeElement(x)); zslFreeNode(x); removed++; traversed++; @@ -694,9 +725,11 @@ static int zslIsInLexRange(zskiplist *zsl, zlexrangespec *range) { int cmp = sdscmplex(range->min, range->max); if (cmp > 0 || (cmp == 0 && (range->minex || range->maxex))) return 0; x = zsl->tail; - if (x == NULL || !zslLexValueGteMin(x->ele, range)) return 0; + sds ele = zslGetNodeElement(x); + if (x == NULL || !zslLexValueGteMin(ele, range)) return 0; x = zsl->header->level[0].forward; - if (x == NULL || !zslLexValueLteMax(x->ele, range)) return 0; + ele = zslGetNodeElement(x); + if (x == NULL || !zslLexValueLteMax(ele, range)) return 0; return 1; } @@ -717,7 +750,7 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n) { /* Go forward while *OUT* of range at level of zsl->level-1. */ x = zsl->header; i = zsl->level - 1; - while (x->level[i].forward && !zslLexValueGteMin(x->level[i].forward->ele, range)) { + while (x->level[i].forward && !zslLexValueGteMin(zslGetNodeElement(x->level[i].forward), range)) { edge_rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } @@ -728,7 +761,7 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n) { if (n >= 0) { for (i = zsl->level - 2; i >= 0; i--) { /* Go forward while *OUT* of range. */ - while (x->level[i].forward && !zslLexValueGteMin(x->level[i].forward->ele, range)) { + while (x->level[i].forward && !zslLexValueGteMin(zslGetNodeElement(x->level[i].forward), range)) { /* Count the rank of the last element smaller than the range. */ edge_rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; @@ -748,11 +781,11 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n) { x = zslGetElementByRankFromNode(last_highest_level_node, zsl->level - 1, rank_diff); } /* Check if score <= max. */ - if (x && !zslLexValueLteMax(x->ele, range)) return NULL; + if (x && !zslLexValueLteMax(zslGetNodeElement(x), range)) return NULL; } else { for (i = zsl->level - 1; i >= 0; i--) { /* Go forward while *IN* range. */ - while (x->level[i].forward && zslLexValueLteMax(x->level[i].forward->ele, range)) { + while (x->level[i].forward && zslLexValueLteMax(zslGetNodeElement(x->level[i].forward), range)) { /* Count the rank of the last element in range. */ edge_rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; @@ -772,7 +805,7 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n) { x = zslGetElementByRankFromNode(last_highest_level_node, zsl->level - 1, rank_diff); } /* Check if score >= min. */ - if (x && !zslLexValueGteMin(x->ele, range)) return NULL; + if (x && !zslLexValueGteMin(zslGetNodeElement(x), range)) return NULL; } return x; @@ -1287,6 +1320,7 @@ void zsetConvertAndExpand(robj *zobj, int encoding, unsigned long cap) { ele = sdsnewlen((char *)vstr, vlen); node = zslInsert(zs->zsl, score, ele); + sdsfree(ele); serverAssert(hashtableAdd(zs->ht, node)); zzlNext(zl, &eptr, &sptr); } @@ -1308,7 +1342,7 @@ void zsetConvertAndExpand(robj *zobj, int encoding, unsigned long cap) { zfree(zs->zsl); while (node) { - zl = zzlInsertAt(zl, NULL, node->ele, node->score); + zl = zzlInsertAt(zl, NULL, zslGetNodeElement(node), node->score); next = node->level[0].forward; zslFreeNode(node); node = next; @@ -1513,7 +1547,6 @@ int zsetAdd(robj *zobj, double score, sds ele, int in_flags, int *out_flags, dou } return 1; } else if (!xx) { - ele = sdsdup(ele); zskiplistNode *new_node = zslInsert(zs->zsl, score, ele); serverAssert(hashtableAdd(zs->ht, new_node)); *out_flags |= ZADD_OUT_ADDED; @@ -1666,9 +1699,8 @@ robj *zsetDup(robj *o) { * O(1) instead of O(log(N)). */ ln = zsl->tail; while (llen--) { - ele = ln->ele; - sds new_ele = sdsdup(ele); - zskiplistNode *znode = zslInsert(new_zs->zsl, ln->score, new_ele); + ele = zslGetNodeElement(ln); + zskiplistNode *znode = zslInsert(new_zs->zsl, ln->score, ele); hashtableAdd(new_zs->ht, znode); ln = ln->backward; } @@ -1702,8 +1734,9 @@ static void zsetTypeRandomElement(robj *zsetobj, unsigned long zsetsize, listpac void *entry; hashtableFairRandomEntry(zs->ht, &entry); zskiplistNode *node = entry; - key->sval = (unsigned char *)node->ele; - key->slen = sdslen(node->ele); + sds ele = zslGetNodeElement(node); + key->sval = (unsigned char *)ele; + key->slen = sdslen(ele); if (score) *score = node->score; } else if (zsetobj->encoding == OBJ_ENCODING_LISTPACK) { listpackEntry val; @@ -2210,7 +2243,7 @@ static int zuiNext(zsetopsrc *op, zsetopval *val) { zzlPrev(it->zl.zl, &it->zl.eptr, &it->zl.sptr); } else if (op->encoding == OBJ_ENCODING_SKIPLIST) { if (it->sl.node == NULL) return 0; - val->ele = it->sl.node->ele; + val->ele = zslGetNodeElement(it->sl.node); val->score = it->sl.node->score; /* Move to next element. (going backwards, see zuiInitIterator) */ @@ -2338,7 +2371,8 @@ static size_t zsetHashtableGetMaxElementLength(hashtable *ht, size_t *totallen) void *next; while (hashtableNext(&iter, &next)) { zskiplistNode *node = next; - size_t elelen = sdslen(node->ele); + sds ele = zslGetNodeElement(node); + size_t elelen = sdslen(ele); if (elelen > maxelelen) maxelelen = elelen; if (totallen) (*totallen) += elelen; } @@ -2395,6 +2429,7 @@ static void zdiffAlgorithm1(zsetopsrc *src, long setnum, zset *dstzset, size_t * hashtableAdd(dstzset->ht, znode); if (sdslen(tmp) > *maxelelen) *maxelelen = sdslen(tmp); (*totelelen) += sdslen(tmp); + sdsfree(tmp); } } zuiClearIterator(&src[0]); @@ -2433,6 +2468,7 @@ static void zdiffAlgorithm2(zsetopsrc *src, long setnum, zset *dstzset, size_t * if (j == 0) { tmp = zuiNewSdsFromValue(&zval); znode = zslInsert(dstzset->zsl, zval.score, tmp); + sdsfree(tmp); hashtableAdd(dstzset->ht, znode); cardinality++; } else { @@ -2689,6 +2725,7 @@ static void zunionInterDiffGenericCommand(client *c, robj *dstkey, int numkeysIn hashtableAdd(dstzset->ht, znode); totelelen += sdslen(tmp); if (sdslen(tmp) > maxelelen) maxelelen = sdslen(tmp); + sdsfree(tmp); } } zuiClearIterator(&src[0]); @@ -2717,14 +2754,17 @@ static void zunionInterDiffGenericCommand(client *c, robj *dstkey, int numkeysIn /* If we don't have it, we need to create a new entry. */ void *existing; if (hashtableFindPositionForInsert(dstzset->ht, sdsval, &position, &existing)) { - zskiplistNode *new_node = zslCreateNode(zslRandomLevel(), score, zuiNewSdsFromValue(&zval)); + sds tmp_ele = zuiNewSdsFromValue(&zval); + zskiplistNode *new_node = zslCreateNode(zslRandomLevel(), score, tmp_ele); + sdsfree(tmp_ele); hashtableInsertAtPosition(dstzset->ht, new_node, &position); /* Remember the longest single element encountered, * to understand if it's possible to convert to listpack * at the end. */ - totelelen += sdslen(new_node->ele); - if (sdslen(new_node->ele) > maxelelen) { - maxelelen = sdslen(new_node->ele); + sds ele = zslGetNodeElement(new_node); + totelelen += sdslen(ele); + if (sdslen(ele) > maxelelen) { + maxelelen = sdslen(ele); } } else { /* Update the score with the score of the new instance @@ -2785,7 +2825,8 @@ static void zunionInterDiffGenericCommand(client *c, robj *dstkey, int numkeysIn while (zn != NULL) { if (withscores && c->resp > 2) addReplyArrayLen(c, 2); - addReplyBulkCBuffer(c, zn->ele, sdslen(zn->ele)); + sds ele = zslGetNodeElement(zn); + addReplyBulkCBuffer(c, ele, sdslen(ele)); if (withscores) addReplyDouble(c, zn->score); zn = zn->level[0].forward; } @@ -3086,7 +3127,7 @@ void genericZrangebyrankCommand(zrange_result_handler *handler, while (rangelen--) { serverAssertWithInfo(c, zobj, ln != NULL); - sds ele = ln->ele; + sds ele = zslGetNodeElement(ln); handler->emitResultFromCBuffer(handler, ele, sdslen(ele), ln->score); ln = reverse ? ln->backward : ln->level[0].forward; } @@ -3210,7 +3251,8 @@ void genericZrangebyscoreCommand(zrange_result_handler *handler, } rangelen++; - handler->emitResultFromCBuffer(handler, ln->ele, sdslen(ln->ele), ln->score); + sds ele = zslGetNodeElement(ln); + handler->emitResultFromCBuffer(handler, ele, sdslen(ele), ln->score); /* Move to next node */ if (reverse) { @@ -3471,14 +3513,15 @@ void genericZrangebylexCommand(zrange_result_handler *handler, while (ln && limit--) { /* Abort when the node is no longer in range. */ + sds ele = zslGetNodeElement(ln); if (reverse) { - if (!zslLexValueGteMin(ln->ele, range)) break; + if (!zslLexValueGteMin(ele, range)) break; } else { - if (!zslLexValueLteMax(ln->ele, range)) break; + if (!zslLexValueLteMax(ele, range)) break; } rangelen++; - handler->emitResultFromCBuffer(handler, ln->ele, sdslen(ln->ele), ln->score); + handler->emitResultFromCBuffer(handler, ele, sdslen(ele), ln->score); /* Move to next node */ if (reverse) { @@ -3880,7 +3923,7 @@ void genericZpopCommand(client *c, /* There must be an element in the sorted set. */ serverAssertWithInfo(c, zobj, zln != NULL); - ele = sdsdup(zln->ele); + ele = sdsdup(zslGetNodeElement(zln)); score = zln->score; } else { serverPanic("Unknown sorted set encoding"); @@ -4090,7 +4133,8 @@ void zrandmemberWithCountCommand(client *c, long l, int withscores) { serverAssert(hashtableFairRandomEntry(zs->ht, &entry)); zskiplistNode *node = entry; if (withscores && c->resp > 2) addReplyArrayLen(c, 2); - addReplyBulkCBuffer(c, node->ele, sdslen(node->ele)); + sds ele = zslGetNodeElement(node); + addReplyBulkCBuffer(c, ele, sdslen(ele)); if (withscores) addReplyDouble(c, node->score); if (c->flag.close_asap) break; } @@ -4189,7 +4233,7 @@ void zrandmemberWithCountCommand(client *c, long l, int withscores) { while (size > count) { void *element; hashtableFairRandomEntry(ht, &element); - hashtableDelete(ht, ((zskiplistNode *)element)->ele); + hashtableDelete(ht, zslGetNodeElement((zskiplistNode *)element)); size--; } hashtableResetIterator(&iter); @@ -4199,7 +4243,7 @@ void zrandmemberWithCountCommand(client *c, long l, int withscores) { void *next; while (hashtableNext(&iter, &next)) { zskiplistNode *node = (zskiplistNode *)next; - sds key = node->ele; + sds key = zslGetNodeElement(node); if (withscores && c->resp > 2) addReplyArrayLen(c, 2); addReplyBulkCBuffer(c, key, sdslen(key)); if (withscores) addReplyDouble(c, node->score); diff --git a/src/valkey-check-rdb.c b/src/valkey-check-rdb.c index 10aa12d31..79782accd 100644 --- a/src/valkey-check-rdb.c +++ b/src/valkey-check-rdb.c @@ -328,7 +328,8 @@ void computeDatasetProfile(int dbid, robj *keyobj, robj *o, long long expiretime const int len = fpconv_dtoa(node->score, buf); buf[len] = '\0'; - eleLen += sdslen(node->ele) + strlen(buf); + sds ele = zslGetNodeElement(node); + eleLen += sdslen(ele) + strlen(buf); statsRecordElementSize(eleLen, 1, stats); } hashtableResetIterator(&iter); diff --git a/tests/modules/zset.c b/tests/modules/zset.c index 0f0980f00..9a76c783d 100644 --- a/tests/modules/zset.c +++ b/tests/modules/zset.c @@ -2,6 +2,8 @@ #include #include +#define UNUSED(V) ((void) V) + /* ZSET.REM key element * * Removes an occurrence of an element from a sorted set. Replies with the @@ -69,6 +71,91 @@ int zset_incrby(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc) { return ValkeyModule_ReplyWithError(ctx, "ERR ZsetIncrby failed"); } +static int zset_internal_rangebylex(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int reverse) { + ValkeyModule_AutoMemory(ctx); + ValkeyModuleKey *key = ValkeyModule_OpenKey(ctx, argv[1], VALKEYMODULE_READ); + if (ValkeyModule_KeyType(key) != VALKEYMODULE_KEYTYPE_ZSET) { + return ValkeyModule_ReplyWithError(ctx, VALKEYMODULE_ERRORMSG_WRONGTYPE); + } + + if (reverse) { + if (ValkeyModule_ZsetLastInLexRange(key, argv[2], argv[3]) != VALKEYMODULE_OK) { + return ValkeyModule_ReplyWithError(ctx, "invalid range"); + } + } else { + if (ValkeyModule_ZsetFirstInLexRange(key, argv[2], argv[3]) != VALKEYMODULE_OK) { + return ValkeyModule_ReplyWithError(ctx, "invalid range"); + } + } + + int arraylen = 0; + ValkeyModule_ReplyWithArray(ctx, VALKEYMODULE_POSTPONED_LEN); + while (!ValkeyModule_ZsetRangeEndReached(key)) { + ValkeyModuleString *ele = ValkeyModule_ZsetRangeCurrentElement(key, NULL); + ValkeyModule_ReplyWithString(ctx, ele); + ValkeyModule_FreeString(ctx, ele); + if (reverse) { + ValkeyModule_ZsetRangePrev(key); + } else { + ValkeyModule_ZsetRangeNext(key); + } + arraylen += 1; + } + ValkeyModule_ZsetRangeStop(key); + ValkeyModule_CloseKey(key); + ValkeyModule_ReplySetArrayLength(ctx, arraylen); + return VALKEYMODULE_OK; +} + +/* ZSET.rangebylex key min max + * + * Returns members in a sorted set within a lexicographical range. + */ +int zset_rangebylex(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc) { + if (argc != 4) + return ValkeyModule_WrongArity(ctx); + return zset_internal_rangebylex(ctx, argv, 0); +} + +/* ZSET.revrangebylex key min max + * + * Returns members in a sorted set within a lexicographical range in reverse order. + */ +int zset_revrangebylex(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc) { + if (argc != 4) + return ValkeyModule_WrongArity(ctx); + return zset_internal_rangebylex(ctx, argv, 1); +} + +static void zset_members_cb(ValkeyModuleKey *key, ValkeyModuleString *field, ValkeyModuleString *value, void *privdata) { + UNUSED(key); + UNUSED(value); + ValkeyModuleCtx *ctx = (ValkeyModuleCtx *)privdata; + ValkeyModule_ReplyWithString(ctx, field); +} + +/* ZSET.members key + * + * Returns members in a sorted set. + */ +int zset_members(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc) { + if (argc != 2) + return ValkeyModule_WrongArity(ctx); + ValkeyModule_AutoMemory(ctx); + + ValkeyModuleKey *key = ValkeyModule_OpenKey(ctx, argv[1], VALKEYMODULE_READ); + if (ValkeyModule_KeyType(key) != VALKEYMODULE_KEYTYPE_ZSET) { + return ValkeyModule_ReplyWithError(ctx, VALKEYMODULE_ERRORMSG_WRONGTYPE); + } + + ValkeyModule_ReplyWithArray(ctx, ValkeyModule_ValueLength(key)); + ValkeyModuleScanCursor *c = ValkeyModule_ScanCursorCreate(); + while (ValkeyModule_ScanKey(key, c, zset_members_cb, ctx)); + ValkeyModule_CloseKey(key); + ValkeyModule_ScanCursorDestroy(c); + return VALKEYMODULE_OK; +} + int ValkeyModule_OnLoad(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc) { VALKEYMODULE_NOT_USED(argv); VALKEYMODULE_NOT_USED(argc); @@ -87,5 +174,17 @@ int ValkeyModule_OnLoad(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int arg 1, 1, 1) == VALKEYMODULE_ERR) return VALKEYMODULE_ERR; + if (ValkeyModule_CreateCommand(ctx, "zset.rangebylex", zset_rangebylex, "readonly", + 1, 1, 1) == VALKEYMODULE_ERR) + return VALKEYMODULE_ERR; + + if (ValkeyModule_CreateCommand(ctx, "zset.revrangebylex", zset_revrangebylex, "readonly", + 1, 1, 1) == VALKEYMODULE_ERR) + return VALKEYMODULE_ERR; + + if (ValkeyModule_CreateCommand(ctx, "zset.members", zset_members, "readonly", + 1, 1, 1) == VALKEYMODULE_ERR) + return VALKEYMODULE_ERR; + return VALKEYMODULE_OK; } diff --git a/tests/unit/moduleapi/zset.tcl b/tests/unit/moduleapi/zset.tcl index b6ab41d5f..b90d2cfe5 100644 --- a/tests/unit/moduleapi/zset.tcl +++ b/tests/unit/moduleapi/zset.tcl @@ -34,6 +34,75 @@ start_server {tags {"modules"}} { assert_equal {hello 100} [r zrange k 0 -1 withscores] } + test {Module zset rangebylex} { + # Should give wrong arity error + assert_error "ERR wrong number of arguments*" {r zset.rangebylex} + assert_error "ERR wrong number of arguments*" {r zset.revrangebylex} + + # Should give wrong type error + r del k + r set k v + assert_error "WRONGTYPE Operation against a key*" {r zset.rangebylex k - +} + + # Should give invalid range error + r del k + r zadd k 0 ele + assert_error "invalid range" {r zset.rangebylex k - a} + assert_error "invalid range" {r zset.revrangebylex k - a} + + # Check if the data structure of the sorted set is skiplist + r del k + r config set zset-max-listpack-entries 2 + r config set zset-max-listpack-value 64 + for {set i 0} {$i < 4} {incr i} { + r zadd k 0 "ele$i" + } + assert_equal {ele0 ele1 ele2 ele3} [r zset.rangebylex k - +] + assert_equal {ele3 ele2 ele1 ele0} [r zset.revrangebylex k - +] + assert_equal {ele1 ele2} [r zset.rangebylex k "(ele0" "(ele3"] + assert_equal {ele2 ele1} [r zset.revrangebylex k "(ele0" "(ele3"] + + # Check if the data structure of the sorted set is listpack + r del k + r config set zset-max-listpack-entries 128 + r config set zset-max-listpack-value 64 + for {set i 0} {$i < 4} {incr i} { + r zadd k 0 "ele$i" + } + assert_equal {ele0 ele1 ele2 ele3} [r zset.rangebylex k - +] + assert_equal {ele3 ele2 ele1 ele0} [r zset.revrangebylex k - +] + assert_equal {ele1 ele2} [r zset.rangebylex k "(ele0" "(ele3"] + assert_equal {ele2 ele1} [r zset.revrangebylex k "(ele0" "(ele3"] + } + + test {Module zset members} { + # Should give wrong arity error + assert_error "ERR wrong number of arguments*" {r zset.members} + + # Should give wrong type error + r del k + r set k v + assert_error "WRONGTYPE Operation against a key*" {r zset.members k} + + # Check if the data structure of the sorted set is skiplist + r del k + r config set zset-max-listpack-entries 2 + r config set zset-max-listpack-value 64 + for {set i 0} {$i < 4} {incr i} { + r zadd k 0 "ele$i" + } + assert_equal {ele0 ele1 ele2 ele3} [lsort [r zset.members k]] + + # Check if the data structure of the sorted set is listpack + r del k + r config set zset-max-listpack-entries 128 + r config set zset-max-listpack-value 64 + for {set i 0} {$i < 4} {incr i} { + r zadd k 0 "ele$i" + } + assert_equal {ele0 ele1 ele2 ele3} [lsort [r zset.members k]] + } + test "Unload the module - zset" { assert_equal {OK} [r module unload zset] }