Skip to content

feat: Added store option for sort command. #5095

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 83 additions & 13 deletions src/server/generic_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ extern "C" {

#include "base/flags.h"
#include "base/logging.h"
#include "core/qlist.h"
#include "redis/rdb.h"
#include "server/acl/acl_commands_def.h"
#include "server/blocking_controller.h"
Expand Down Expand Up @@ -1438,10 +1439,33 @@ OpResultTyped<SortEntryList> OpFetchSortEntries(const OpArgs& op_args, std::stri
return success ? res : OpStatus::INVALID_NUMERIC_RESULT;
}

template <typename IteratorBegin, typename IteratorEnd>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need a template here? what are the types of the iterators?

Copy link
Contributor Author

@H4R5H1T-007 H4R5H1T-007 May 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two types of iterator present as SortedEntryList is variant of two vector and hence iterator type will also change based on the arguments so I added this template. Let me know if I should handle this using some different method.

OpResult<uint32_t> OpStore(const OpArgs& op_args, std::string_view key, IteratorBegin&& start_it,
IteratorEnd&& end_it) {
uint32_t len = 0;

QList* ql_v2 = CompactObj::AllocateMR<QList>();
QList::Where where = QList::TAIL;
for (auto it = start_it; it != end_it; ++it) {
ql_v2->Push(it->key, where);
}
len = ql_v2->Size();

PrimeValue pv;
pv.InitRobj(OBJ_LIST, kEncodingQL2, ql_v2);

// This would overwrite existing value if any with new list.
auto op_res = op_args.GetDbSlice().AddOrUpdate(op_args.db_cntx, key, std::move(pv), 0);
RETURN_ON_BAD_STATUS(op_res);

return len;
}

void GenericFamily::Sort(CmdArgList args, const CommandContext& cmd_cntx) {
std::string_view key = ArgS(args, 0);
bool alpha = false;
bool reversed = false;
std::optional<std::string_view> store_key;
std::optional<std::pair<size_t, size_t>> bounds;
auto* builder = cmd_cntx.rb;
for (size_t i = 1; i < args.size(); i++) {
Expand All @@ -1463,29 +1487,57 @@ void GenericFamily::Sort(CmdArgList args, const CommandContext& cmd_cntx) {
}
bounds = {offset, limit};
i += 2;
} else if (arg == "STORE") {
if (i + 1 >= args.size()) {
return builder->SendError(kSyntaxErr);
}
store_key = ArgS(args, i + 1);
i += 1;
} else {
LOG_EVERY_T(ERROR, 1) << "Unsupported option " << arg;
return builder->SendError(kSyntaxErr, kSyntaxErrType);
}
}

OpResultTyped<SortEntryList> fetch_result =
cmd_cntx.tx->ScheduleSingleHopT([&](Transaction* t, EngineShard* shard) {
return OpFetchSortEntries(t->GetOpArgs(shard), key, alpha);
});
ShardId source_sid = Shard(key, shard_set->size());
OpResultTyped<SortEntryList> fetch_result;
auto fetch_cb = [&](Transaction* t, EngineShard* shard) {
ShardId shard_id = shard->shard_id();
if (shard_id == source_sid) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pls add a comment: in case of SORT option, we fetch only on the source shard

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

fetch_result = OpFetchSortEntries(t->GetOpArgs(shard), key, alpha);
}
return OpStatus::OK;
};

if (store_key) {
cmd_cntx.tx->Execute(std::move(fetch_cb), false);
} else {
cmd_cntx.tx->Execute(std::move(fetch_cb), true);
}

// OpResultTyped<SortEntryList> fetch_result =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the old code

// cmd_cntx.tx->ScheduleSingleHopT([&](Transaction* t, EngineShard* shard) {
// return OpFetchSortEntries(t->GetOpArgs(shard), key, alpha);
// });

if (fetch_result == OpStatus::WRONG_TYPE)
if (fetch_result == OpStatus::WRONG_TYPE) {
cmd_cntx.tx->Conclude();
return builder->SendError(fetch_result.status());
}

if (fetch_result.status() == OpStatus::INVALID_NUMERIC_RESULT)
if (fetch_result.status() == OpStatus::INVALID_NUMERIC_RESULT) {
cmd_cntx.tx->Conclude();
return builder->SendError("One or more scores can't be converted into double");
}

auto* rb = static_cast<RedisReplyBuilder*>(builder);
if (!fetch_result.ok())
if (!fetch_result.ok()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would unify all the non-ok conditions here:

if (!fetch_result.ok()) {
   cmd_cntx.tx->Conclude();
   if (fetch_result.status() == OpStatus::INVALID_NUMERIC_RESULT) 
     return builder->SendError(...);
   ....
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup that will be better. Will do that

cmd_cntx.tx->Conclude();
return rb->SendEmptyArray();
}

auto result_type = fetch_result.type();
auto sort_call = [builder, bounds, reversed, result_type](auto& entries) {
auto sort_call = [builder, bounds, reversed, result_type, &store_key, &cmd_cntx](auto& entries) {
using value_t = typename std::decay_t<decltype(entries)>::value_type;
auto cmp = reversed ? &value_t::greater : &value_t::less;
if (bounds) {
Expand All @@ -1504,11 +1556,29 @@ void GenericFamily::Sort(CmdArgList args, const CommandContext& cmd_cntx) {

bool is_set = (result_type == OBJ_SET || result_type == OBJ_ZSET);
auto* rb = static_cast<RedisReplyBuilder*>(builder);
rb->StartCollection(std::distance(start_it, end_it),
is_set ? RedisReplyBuilder::SET : RedisReplyBuilder::ARRAY);
if (store_key) {
ShardId dest_sid = Shard(store_key.value(), shard_set->size());
OpResult<uint32_t> store_len;
auto store_callback = [&](Transaction* t, EngineShard* shard) {
ShardId shard_id = shard->shard_id();
if (shard_id == dest_sid) {
store_len = OpStore(t->GetOpArgs(shard), store_key.value(), start_it, end_it);
}
return OpStatus::OK;
};
cmd_cntx.tx->Execute(std::move(store_callback), true);
if (store_len) {
rb->SendLong(store_len.value());
} else {
rb->SendError(store_len.status());
}
} else {
rb->StartCollection(std::distance(start_it, end_it),
is_set ? RedisReplyBuilder::SET : RedisReplyBuilder::ARRAY);

for (auto it = start_it; it != end_it; ++it) {
rb->SendBulkString(it->key);
for (auto it = start_it; it != end_it; ++it) {
rb->SendBulkString(it->key);
}
}
};

Expand Down Expand Up @@ -1965,7 +2035,7 @@ void GenericFamily::Register(CommandRegistry* registry) {
<< CI{"DUMP", CO::READONLY, 2, 1, 1, acl::kDump}.HFUNC(Dump)
<< CI{"UNLINK", CO::WRITE, -2, 1, -1, acl::kUnlink}.HFUNC(Unlink)
<< CI{"STICK", CO::WRITE, -2, 1, -1, acl::kStick}.HFUNC(Stick)
<< CI{"SORT", CO::READONLY, -2, 1, 1, acl::kSort}.HFUNC(Sort)
<< CI{"SORT", CO::WRITE, -2, 1, -1, acl::kSort}.HFUNC(Sort)
<< CI{"MOVE", CO::WRITE | CO::GLOBAL_TRANS | CO::NO_AUTOJOURNAL, 3, 1, 1, acl::kMove}.HFUNC(
Move)
<< CI{"RESTORE", CO::WRITE, -4, 1, 1, acl::kRestore}.HFUNC(Restore)
Expand Down
87 changes: 87 additions & 0 deletions src/server/generic_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,93 @@ TEST_F(GenericFamilyTest, SortBug3636) {
ASSERT_THAT(resp, ArrLen(17));
}

TEST_F(GenericFamilyTest, SortStore) {
// Test list sort with params
Run({"del", "list-1"});
Run({"del", "list-2"});
Run({"lpush", "list-1", "3.5", "1.2", "10.1", "2.20", "200"});
// numeric
auto resp = Run({"sort", "list-1", "store", "list-2"});
EXPECT_EQ(5, resp.GetInt());
ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(),
ElementsAre("1.2", "2.20", "3.5", "10.1", "200"));

// string
resp = Run({"sort", "list-1", "ALPHA", "store", "list-2"});
EXPECT_EQ(5, resp.GetInt());
ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(),
ElementsAre("1.2", "10.1", "2.20", "200", "3.5"));

// desc numeric
resp = Run({"sort", "list-1", "DESC", "store", "list-2"});
EXPECT_EQ(5, resp.GetInt());
ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(),
ElementsAre("200", "10.1", "3.5", "2.20", "1.2"));

// desc string
resp = Run({"sort", "list-1", "ALPHA", "DESC", "store", "list-2"});
EXPECT_EQ(5, resp.GetInt());
ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(),
ElementsAre("3.5", "200", "2.20", "10.1", "1.2"));

// limits
resp = Run({"sort", "list-1", "LIMIT", "0", "5", "store", "list-2"});
EXPECT_EQ(5, resp.GetInt());
ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(),
ElementsAre("1.2", "2.20", "3.5", "10.1", "200"));
resp = Run({"sort", "list-1", "LIMIT", "0", "10", "store", "list-2"});
EXPECT_EQ(5, resp.GetInt());
ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(),
ElementsAre("1.2", "2.20", "3.5", "10.1", "200"));
resp = Run({"sort", "list-1", "LIMIT", "2", "2", "store", "list-2"});
EXPECT_EQ(2, resp.GetInt());
ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(), ElementsAre("3.5", "10.1"));
resp = Run({"sort", "list-1", "LIMIT", "1", "1", "store", "list-2"});
EXPECT_EQ(1, resp.GetInt());
ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}), "2.20");
resp = Run({"sort", "list-1", "LIMIT", "4", "2", "store", "list-2"});
EXPECT_EQ(1, resp.GetInt());
ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}), "200");
resp = Run({"sort", "list-1", "LIMIT", "5", "2", "store", "list-2"});
EXPECT_EQ(0, resp.GetInt());
ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}), ArrLen(0));

// Test set sort
Run({"del", "set-1"});
Run({"del", "list-3"});
Run({"sadd", "set-1", "5.3", "4.4", "60", "99.9", "100", "9"});
resp = Run({"sort", "set-1", "store", "list-3"});
EXPECT_EQ(6, resp.GetInt());
ASSERT_THAT(Run({"lrange", "list-3", "0", "-1"}).GetVec(),
ElementsAre("4.4", "5.3", "9", "60", "99.9", "100"));

// Test sorted set sort
Run({"del", "zset-1"});
Run({"del", "list-4"});
Run({"zadd", "zset-1", "0", "3.3", "0", "30.1", "0", "8.2"});
resp = Run({"sort", "zset-1", "store", "list-4"});
EXPECT_EQ(3, resp.GetInt());
ASSERT_THAT(Run({"lrange", "list-4", "0", "-1"}).GetVec(), ElementsAre("3.3", "8.2", "30.1"));

// Same key overwrite.
Run({"del", "list-1"});
Run({"del", "list-2"});
Run({"lpush", "list-1", "3.5", "1.2", "10.1", "2.20", "200"});
resp = Run({"sort", "list-1", "store", "list-1"});
EXPECT_EQ(5, resp.GetInt());
ASSERT_THAT(Run({"lrange", "list-1", "0", "-1"}).GetVec(),
ElementsAre("1.2", "2.20", "3.5", "10.1", "200"));

// Check that the keys should not expire after some time.
Run({"del", "list-1"});
Run({"del", "list-2"});
Run({"lpush", "list-1", "3.5", "1.2", "10.1", "2.20", "200"});
Run({"sort", "list-1", "store", "list-2"});
AdvanceTime(5000);
ASSERT_THAT(Run({"lrange", "list-2", "0", "-1"}).GetVec(),
ElementsAre("1.2", "2.20", "3.5", "10.1", "200"));
}

TEST_F(GenericFamilyTest, TimeNoKeys) {
auto resp = Run({"time"});
EXPECT_THAT(resp, ArrLen(2));
Expand Down
Loading