From bcf0cd5e92b95037eb12e31fd264a4b6ecbc4927 Mon Sep 17 00:00:00 2001 From: hdt3213 <hdt3213@outlook.com> Date: Sat, 3 Apr 2021 20:14:12 +0800 Subject: [PATCH] reformat code --- README.md | 17 +- README_CN.md | 34 +- src/cluster/client.go | 46 +- src/cluster/del.go | 166 ++-- src/cluster/idgenerator/snowflake.go | 112 +-- src/cluster/mset.go | 264 +++--- src/cluster/rename.go | 57 +- src/cluster/router.go | 182 ++-- src/cluster/transaction.go | 304 +++--- src/cmd/main.go | 6 +- src/config/config.go | 12 +- src/datastruct/dict/dict.go | 22 +- src/datastruct/dict/dict_test.go | 452 ++++----- src/datastruct/dict/simple.go | 142 +-- src/datastruct/list/linked.go | 542 +++++------ src/datastruct/list/linked_test.go | 364 +++---- src/datastruct/lock/lock_map.go | 219 +++-- src/datastruct/set/set_test.go | 52 +- src/datastruct/sortedset/border.go | 124 +-- src/datastruct/sortedset/skiplist.go | 555 ++++++----- src/datastruct/sortedset/sortedset.go | 345 ++++--- src/datastruct/utils/utils.go | 44 +- src/db/aof.go | 30 +- src/db/db.go | 382 ++++---- src/db/hash.go | 748 +++++++-------- src/db/list.go | 792 ++++++++-------- src/db/router.go | 180 ++-- src/db/server.go | 20 +- src/db/sortedset.go | 1098 +++++++++++----------- src/interface/db/db.go | 6 +- src/interface/redis/client.go | 12 +- src/interface/redis/reply.go | 2 +- src/interface/tcp/handler.go | 10 +- src/lib/consistenthash/consistenthash.go | 100 +- src/lib/files/files.go | 88 +- src/lib/geohash/geohash.go | 154 +-- src/lib/geohash/geohash_test.go | 54 +- src/lib/geohash/neighbor.go | 200 ++-- src/lib/marshal/gob/gob.go | 22 +- src/lib/sync/atomic/bool.go | 18 +- src/lib/sync/wait/wait.go | 46 +- src/lib/wildcard/wildcard.go | 152 +-- src/lib/wildcard/wildcard_test.go | 128 +-- src/pubsub/hub.go | 20 +- src/pubsub/pubsub.go | 220 ++--- src/redis/client/client.go | 535 ++++++----- src/redis/client/client_test.go | 176 ++-- src/redis/reply/consts.go | 32 +- src/redis/reply/errors.go | 52 +- src/redis/reply/reply.go | 119 ++- src/redis/server/client.go | 132 +-- src/redis/server/handler.go | 340 +++---- src/tcp/echo.go | 119 ++- src/tcp/server.go | 125 ++- 54 files changed, 5082 insertions(+), 5091 deletions(-) diff --git a/README.md b/README.md index 2ac9029c..383916d5 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,13 @@ [中文版](https://github.com/HDT3213/godis/blob/master/README_CN.md) -`Godis` is a simple implementation of Redis Server, which intents to provide an example of writing a high concurrent middleware using golang. +`Godis` is a simple implementation of Redis Server, which intents to provide an example of writing a high concurrent +middleware using golang. Please be advised, NEVER think about using this in production environment. -This repository implemented most features of redis, including 5 data structures, ttl, publish/subscribe, AOF persistence and server side cluster mode. +This repository implemented most features of redis, including 5 data structures, ttl, publish/subscribe, AOF persistence +and server side cluster mode. If you could read Chinese, you can find more details in [My Blog](https://www.cnblogs.com/Finley/category/1598973.html). @@ -22,7 +24,7 @@ You could use redis-cli or other redis client to connect godis server, which lis The program will try to read config file path from environment variable `CONFIG`. -If environment variable is not set, then the program try to read `redis.conf` in the working directory. +If environment variable is not set, then the program try to read `redis.conf` in the working directory. If there is no such file, then the program will run with default config. @@ -35,8 +37,7 @@ peers localhost:7379,localhost:7389 // other node in cluster self localhost:6399 // self address ``` -We provide node1.conf and node2.conf for demonstration. -use following command line to start a two-node-cluster: +We provide node1.conf and node2.conf for demonstration. use following command line to start a two-node-cluster: ```bash CONFIG=node1.conf ./godis-darwin & @@ -151,7 +152,7 @@ Supported Commands: If you want to read my code in this repository, here is a simple guidance. - cmd: only the entry point -- config: config parser +- config: config parser - interface: some interface definitions - lib: some utils, such as logger, sync utils and wildcard @@ -167,7 +168,7 @@ I suggest focusing on the following directories: - sortedset: a sorted set implements based on skiplist - db: the implements of the redis db - db.go: the basement of database - - router.go: it find handler for commands + - router.go: it find handler for commands - keys.go: handlers for keys commands - string.go: handlers for string commands - list.go: handlers for list commands @@ -176,7 +177,7 @@ I suggest focusing on the following directories: - sortedset.go: handlers for sorted set commands - pubsub.go: implements of publish / subscribe - aof.go: implements of AOF persistence and rewrite - + # License This project is licensed under the [GPL license](https://github.com/HDT3213/godis/blob/master/LICENSE). \ No newline at end of file diff --git a/README_CN.md b/README_CN.md index 113314e1..45e67fc0 100644 --- a/README_CN.md +++ b/README_CN.md @@ -2,8 +2,8 @@ Godis 是一个用 Go 语言实现的 Redis 服务器。本项目旨在为尝试 **请注意:不要在生产环境使用使用此项目** -Godis 实现了 Redis 的大多数功能,包括5种数据结构、TTL、发布订阅以及 AOF 持久化。可以在[我的博客](https://www.cnblogs.com/Finley/category/1598973.html)了解更多关于 Godis 的信息。 - +Godis 实现了 Redis 的大多数功能,包括5种数据结构、TTL、发布订阅以及 AOF 持久化。可以在[我的博客](https://www.cnblogs.com/Finley/category/1598973.html)了解更多关于 +Godis 的信息。 # 运行 Godis @@ -149,19 +149,19 @@ redis-cli -p 6399 - tcp: tcp 服务器实现 - redis: redis 协议解析器 - datastruct: redis 的各类数据结构实现 - - dict: hash 表 - - list: 链表 - - lock: 用于锁定 key 的锁组件 - - set: 基于hash表的集合 - - sortedset: 基于跳表实现的有序集合 + - dict: hash 表 + - list: 链表 + - lock: 用于锁定 key 的锁组件 + - set: 基于hash表的集合 + - sortedset: 基于跳表实现的有序集合 - db: redis 存储引擎实现 - - db.go: 引擎的基础功能 - - router.go: 将命令路由给响应的处理函数 - - keys.go: del、ttl、expire 等通用命令实现 - - string.go: get、set 等字符串命令实现 - - list.go: lpush、lindex 等列表命令实现 - - hash.go: hget、hset 等哈希表命令实现 - - set.go: sadd 等集合命令实现 - - sortedset.go: zadd 等有序集合命令实现 - - pubsub.go: 发布订阅命令实现 - - aof.go: aof持久化实现 \ No newline at end of file + - db.go: 引擎的基础功能 + - router.go: 将命令路由给响应的处理函数 + - keys.go: del、ttl、expire 等通用命令实现 + - string.go: get、set 等字符串命令实现 + - list.go: lpush、lindex 等列表命令实现 + - hash.go: hget、hset 等哈希表命令实现 + - set.go: sadd 等集合命令实现 + - sortedset.go: zadd 等有序集合命令实现 + - pubsub.go: 发布订阅命令实现 + - aof.go: aof持久化实现 \ No newline at end of file diff --git a/src/cluster/client.go b/src/cluster/client.go index 121d55d7..e7733993 100644 --- a/src/cluster/client.go +++ b/src/cluster/client.go @@ -1,45 +1,45 @@ package cluster import ( - "context" - "errors" - "github.com/HDT3213/godis/src/redis/client" - "github.com/jolestar/go-commons-pool/v2" + "context" + "errors" + "github.com/HDT3213/godis/src/redis/client" + "github.com/jolestar/go-commons-pool/v2" ) type ConnectionFactory struct { - Peer string + Peer string } func (f *ConnectionFactory) MakeObject(ctx context.Context) (*pool.PooledObject, error) { - c, err := client.MakeClient(f.Peer) - if err != nil { - return nil, err - } - c.Start() - return pool.NewPooledObject(c), nil + c, err := client.MakeClient(f.Peer) + if err != nil { + return nil, err + } + c.Start() + return pool.NewPooledObject(c), nil } func (f *ConnectionFactory) DestroyObject(ctx context.Context, object *pool.PooledObject) error { - c, ok := object.Object.(*client.Client) - if !ok { - return errors.New("type mismatch") - } - c.Close() - return nil + c, ok := object.Object.(*client.Client) + if !ok { + return errors.New("type mismatch") + } + c.Close() + return nil } func (f *ConnectionFactory) ValidateObject(ctx context.Context, object *pool.PooledObject) bool { - // do validate - return true + // do validate + return true } func (f *ConnectionFactory) ActivateObject(ctx context.Context, object *pool.PooledObject) error { - // do activate - return nil + // do activate + return nil } func (f *ConnectionFactory) PassivateObject(ctx context.Context, object *pool.PooledObject) error { - // do passivate - return nil + // do passivate + return nil } diff --git a/src/cluster/del.go b/src/cluster/del.go index 4fc047e9..6ef63038 100644 --- a/src/cluster/del.go +++ b/src/cluster/del.go @@ -1,99 +1,99 @@ package cluster import ( - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" - "strconv" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" + "strconv" ) func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'del' command") - } - keys := make([]string, len(args)-1) - for i := 1; i < len(args); i++ { - keys[i-1] = string(args[i]) - } - groupMap := cluster.groupBy(keys) - if len(groupMap) == 1 { // do fast - for peer, group := range groupMap { // only one group - return cluster.Relay(peer, c, makeArgs("DEL", group...)) - } - } - // prepare - var errReply redis.Reply - txId := cluster.idGenerator.NextId() - txIdStr := strconv.FormatInt(txId, 10) - rollback := false - for peer, group := range groupMap { - args := []string{txIdStr} - args = append(args, group...) - var resp redis.Reply - if peer == cluster.self { - resp = PrepareDel(cluster, c, makeArgs("PrepareDel", args...)) - } else { - resp = cluster.Relay(peer, c, makeArgs("PrepareDel", args...)) - } - if reply.IsErrorReply(resp) { - errReply = resp - rollback = true - break - } - } - var respList []redis.Reply - if rollback { - // rollback - RequestRollback(cluster, c, txId, groupMap) - } else { - // commit - respList, errReply = RequestCommit(cluster, c, txId, groupMap) - if errReply != nil { - rollback = true - } - } - if !rollback { - var deleted int64 = 0 - for _, resp := range respList { - intResp := resp.(*reply.IntReply) - deleted += intResp.Code - } - return reply.MakeIntReply(int64(deleted)) - } - return errReply + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'del' command") + } + keys := make([]string, len(args)-1) + for i := 1; i < len(args); i++ { + keys[i-1] = string(args[i]) + } + groupMap := cluster.groupBy(keys) + if len(groupMap) == 1 { // do fast + for peer, group := range groupMap { // only one group + return cluster.Relay(peer, c, makeArgs("DEL", group...)) + } + } + // prepare + var errReply redis.Reply + txId := cluster.idGenerator.NextId() + txIdStr := strconv.FormatInt(txId, 10) + rollback := false + for peer, group := range groupMap { + args := []string{txIdStr} + args = append(args, group...) + var resp redis.Reply + if peer == cluster.self { + resp = PrepareDel(cluster, c, makeArgs("PrepareDel", args...)) + } else { + resp = cluster.Relay(peer, c, makeArgs("PrepareDel", args...)) + } + if reply.IsErrorReply(resp) { + errReply = resp + rollback = true + break + } + } + var respList []redis.Reply + if rollback { + // rollback + RequestRollback(cluster, c, txId, groupMap) + } else { + // commit + respList, errReply = RequestCommit(cluster, c, txId, groupMap) + if errReply != nil { + rollback = true + } + } + if !rollback { + var deleted int64 = 0 + for _, resp := range respList { + intResp := resp.(*reply.IntReply) + deleted += intResp.Code + } + return reply.MakeIntReply(int64(deleted)) + } + return errReply } // args: PrepareDel id keys... func PrepareDel(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) < 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'preparedel' command") - } - txId := string(args[1]) - keys := make([]string, 0, len(args)-2) - for i := 2; i < len(args); i++ { - arg := args[i] - keys = append(keys, string(arg)) - } - txArgs := makeArgs("DEL", keys...) // actual args for cluster.db - tx := NewTransaction(cluster, c, txId, txArgs, keys) - cluster.transactions.Put(txId, tx) - err := tx.prepare() - if err != nil { - return reply.MakeErrReply(err.Error()) - } - return &reply.OkReply{} + if len(args) < 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'preparedel' command") + } + txId := string(args[1]) + keys := make([]string, 0, len(args)-2) + for i := 2; i < len(args); i++ { + arg := args[i] + keys = append(keys, string(arg)) + } + txArgs := makeArgs("DEL", keys...) // actual args for cluster.db + tx := NewTransaction(cluster, c, txId, txArgs, keys) + cluster.transactions.Put(txId, tx) + err := tx.prepare() + if err != nil { + return reply.MakeErrReply(err.Error()) + } + return &reply.OkReply{} } // invoker should provide lock func CommitDel(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply { - keys := make([]string, len(tx.args)) - for i, v := range tx.args { - keys[i] = string(v) - } - keys = keys[1:] + keys := make([]string, len(tx.args)) + for i, v := range tx.args { + keys[i] = string(v) + } + keys = keys[1:] - deleted := cluster.db.Removes(keys...) - if deleted > 0 { - cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args)) - } - return reply.MakeIntReply(int64(deleted)) + deleted := cluster.db.Removes(keys...) + if deleted > 0 { + cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args)) + } + return reply.MakeIntReply(int64(deleted)) } diff --git a/src/cluster/idgenerator/snowflake.go b/src/cluster/idgenerator/snowflake.go index e2234ce4..4946bf10 100644 --- a/src/cluster/idgenerator/snowflake.go +++ b/src/cluster/idgenerator/snowflake.go @@ -1,85 +1,85 @@ package idgenerator import ( - "hash/fnv" - "log" - "sync" - "time" + "hash/fnv" + "log" + "sync" + "time" ) const ( - workerIdBits int64 = 5 - datacenterIdBits int64 = 5 - sequenceBits int64 = 12 + workerIdBits int64 = 5 + datacenterIdBits int64 = 5 + sequenceBits int64 = 12 - maxWorkerId int64 = -1 ^ (-1 << uint64(workerIdBits)) - maxDatacenterId int64 = -1 ^ (-1 << uint64(datacenterIdBits)) - maxSequence int64 = -1 ^ (-1 << uint64(sequenceBits)) + maxWorkerId int64 = -1 ^ (-1 << uint64(workerIdBits)) + maxDatacenterId int64 = -1 ^ (-1 << uint64(datacenterIdBits)) + maxSequence int64 = -1 ^ (-1 << uint64(sequenceBits)) - timeLeft uint8 = 22 - dataLeft uint8 = 17 - workLeft uint8 = 12 + timeLeft uint8 = 22 + dataLeft uint8 = 17 + workLeft uint8 = 12 - twepoch int64 = 1525705533000 + twepoch int64 = 1525705533000 ) type IdGenerator struct { - mu *sync.Mutex - lastStamp int64 - workerId int64 - dataCenterId int64 - sequence int64 + mu *sync.Mutex + lastStamp int64 + workerId int64 + dataCenterId int64 + sequence int64 } func MakeGenerator(cluster string, node string) *IdGenerator { - fnv64 := fnv.New64() - _, _ = fnv64.Write([]byte(cluster)) - dataCenterId := int64(fnv64.Sum64()) + fnv64 := fnv.New64() + _, _ = fnv64.Write([]byte(cluster)) + dataCenterId := int64(fnv64.Sum64()) - fnv64.Reset() - _, _ = fnv64.Write([]byte(node)) - workerId := int64(fnv64.Sum64()) + fnv64.Reset() + _, _ = fnv64.Write([]byte(node)) + workerId := int64(fnv64.Sum64()) - return &IdGenerator{ - mu: &sync.Mutex{}, - lastStamp: -1, - dataCenterId: dataCenterId, - workerId: workerId, - sequence: 1, - } + return &IdGenerator{ + mu: &sync.Mutex{}, + lastStamp: -1, + dataCenterId: dataCenterId, + workerId: workerId, + sequence: 1, + } } func (w *IdGenerator) getCurrentTime() int64 { - return time.Now().UnixNano() / 1e6 + return time.Now().UnixNano() / 1e6 } func (w *IdGenerator) NextId() int64 { - w.mu.Lock() - defer w.mu.Unlock() + w.mu.Lock() + defer w.mu.Unlock() - timestamp := w.getCurrentTime() - if timestamp < w.lastStamp { - log.Fatal("can not generate id") - } - if w.lastStamp == timestamp { - w.sequence = (w.sequence + 1) & maxSequence - if w.sequence == 0 { - for timestamp <= w.lastStamp { - timestamp = w.getCurrentTime() - } - } - } else { - w.sequence = 0 - } - w.lastStamp = timestamp + timestamp := w.getCurrentTime() + if timestamp < w.lastStamp { + log.Fatal("can not generate id") + } + if w.lastStamp == timestamp { + w.sequence = (w.sequence + 1) & maxSequence + if w.sequence == 0 { + for timestamp <= w.lastStamp { + timestamp = w.getCurrentTime() + } + } + } else { + w.sequence = 0 + } + w.lastStamp = timestamp - return ((timestamp - twepoch) << timeLeft) | (w.dataCenterId << dataLeft) | (w.workerId << workLeft) | w.sequence + return ((timestamp - twepoch) << timeLeft) | (w.dataCenterId << dataLeft) | (w.workerId << workLeft) | w.sequence } func (w *IdGenerator) tilNextMillis() int64 { - timestamp := w.getCurrentTime() - if timestamp <= w.lastStamp { - timestamp = w.getCurrentTime() - } - return timestamp + timestamp := w.getCurrentTime() + if timestamp <= w.lastStamp { + timestamp = w.getCurrentTime() + } + return timestamp } diff --git a/src/cluster/mset.go b/src/cluster/mset.go index 73fbaf09..0db52d67 100644 --- a/src/cluster/mset.go +++ b/src/cluster/mset.go @@ -1,159 +1,159 @@ package cluster import ( - "fmt" - "github.com/HDT3213/godis/src/db" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" - "strconv" + "fmt" + "github.com/HDT3213/godis/src/db" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" + "strconv" ) func MGet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command") - } - keys := make([]string, len(args)-1) - for i := 1; i < len(args); i++ { - keys[i-1] = string(args[i]) - } + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command") + } + keys := make([]string, len(args)-1) + for i := 1; i < len(args); i++ { + keys[i-1] = string(args[i]) + } - resultMap := make(map[string][]byte) - groupMap := cluster.groupBy(keys) - for peer, group := range groupMap { - resp := cluster.Relay(peer, c, makeArgs("MGET", group...)) - if reply.IsErrorReply(resp) { - errReply := resp.(reply.ErrorReply) - return reply.MakeErrReply(fmt.Sprintf("ERR during get %s occurs: %v", group[0], errReply.Error())) - } - arrReply, _ := resp.(*reply.MultiBulkReply) - for i, v := range arrReply.Args { - key := group[i] - resultMap[key] = v - } - } - result := make([][]byte, len(keys)) - for i, k := range keys { - result[i] = resultMap[k] - } - return reply.MakeMultiBulkReply(result) + resultMap := make(map[string][]byte) + groupMap := cluster.groupBy(keys) + for peer, group := range groupMap { + resp := cluster.Relay(peer, c, makeArgs("MGET", group...)) + if reply.IsErrorReply(resp) { + errReply := resp.(reply.ErrorReply) + return reply.MakeErrReply(fmt.Sprintf("ERR during get %s occurs: %v", group[0], errReply.Error())) + } + arrReply, _ := resp.(*reply.MultiBulkReply) + for i, v := range arrReply.Args { + key := group[i] + resultMap[key] = v + } + } + result := make([][]byte, len(keys)) + for i, k := range keys { + result[i] = resultMap[k] + } + return reply.MakeMultiBulkReply(result) } // args: PrepareMSet id keys... func PrepareMSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) < 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'preparemset' command") - } - txId := string(args[1]) - size := (len(args) - 2) / 2 - keys := make([]string, size) - for i := 0; i < size; i++ { - keys[i] = string(args[2*i+2]) - } + if len(args) < 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'preparemset' command") + } + txId := string(args[1]) + size := (len(args) - 2) / 2 + keys := make([]string, size) + for i := 0; i < size; i++ { + keys[i] = string(args[2*i+2]) + } - txArgs := [][]byte{ - []byte("MSet"), - } // actual args for cluster.db - txArgs = append(txArgs, args[2:]...) - tx := NewTransaction(cluster, c, txId, txArgs, keys) - cluster.transactions.Put(txId, tx) - err := tx.prepare() - if err != nil { - return reply.MakeErrReply(err.Error()) - } - return &reply.OkReply{} + txArgs := [][]byte{ + []byte("MSet"), + } // actual args for cluster.db + txArgs = append(txArgs, args[2:]...) + tx := NewTransaction(cluster, c, txId, txArgs, keys) + cluster.transactions.Put(txId, tx) + err := tx.prepare() + if err != nil { + return reply.MakeErrReply(err.Error()) + } + return &reply.OkReply{} } // invoker should provide lock func CommitMSet(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply { - size := len(tx.args) / 2 - keys := make([]string, size) - values := make([][]byte, size) - for i := 0; i < size; i++ { - keys[i] = string(tx.args[2*i+1]) - values[i] = tx.args[2*i+2] - } - for i, key := range keys { - value := values[i] - cluster.db.Put(key, &db.DataEntity{Data: value}) - } - cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args)) - return &reply.OkReply{} + size := len(tx.args) / 2 + keys := make([]string, size) + values := make([][]byte, size) + for i := 0; i < size; i++ { + keys[i] = string(tx.args[2*i+1]) + values[i] = tx.args[2*i+2] + } + for i, key := range keys { + value := values[i] + cluster.db.Put(key, &db.DataEntity{Data: value}) + } + cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args)) + return &reply.OkReply{} } func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - argCount := len(args) - 1 - if argCount%2 != 0 || argCount < 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command") - } + argCount := len(args) - 1 + if argCount%2 != 0 || argCount < 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command") + } - size := argCount / 2 - keys := make([]string, size) - valueMap := make(map[string]string) - for i := 0; i < size; i++ { - keys[i] = string(args[2*i+1]) - valueMap[keys[i]] = string(args[2*i+2]) - } + size := argCount / 2 + keys := make([]string, size) + valueMap := make(map[string]string) + for i := 0; i < size; i++ { + keys[i] = string(args[2*i+1]) + valueMap[keys[i]] = string(args[2*i+2]) + } - groupMap := cluster.groupBy(keys) - if len(groupMap) == 1 { // do fast - for peer := range groupMap { - return cluster.Relay(peer, c, args) - } - } + groupMap := cluster.groupBy(keys) + if len(groupMap) == 1 { // do fast + for peer := range groupMap { + return cluster.Relay(peer, c, args) + } + } - //prepare - var errReply redis.Reply - txId := cluster.idGenerator.NextId() - txIdStr := strconv.FormatInt(txId, 10) - rollback := false - for peer, group := range groupMap { - peerArgs := []string{txIdStr} - for _, k := range group { - peerArgs = append(peerArgs, k, valueMap[k]) - } - var resp redis.Reply - if peer == cluster.self { - resp = PrepareMSet(cluster, c, makeArgs("PrepareMSet", peerArgs...)) - } else { - resp = cluster.Relay(peer, c, makeArgs("PrepareMSet", peerArgs...)) - } - if reply.IsErrorReply(resp) { - errReply = resp - rollback = true - break - } - } - if rollback { - // rollback - RequestRollback(cluster, c, txId, groupMap) - } else { - _, errReply = RequestCommit(cluster, c, txId, groupMap) - rollback = errReply != nil - } - if !rollback { - return &reply.OkReply{} - } - return errReply + //prepare + var errReply redis.Reply + txId := cluster.idGenerator.NextId() + txIdStr := strconv.FormatInt(txId, 10) + rollback := false + for peer, group := range groupMap { + peerArgs := []string{txIdStr} + for _, k := range group { + peerArgs = append(peerArgs, k, valueMap[k]) + } + var resp redis.Reply + if peer == cluster.self { + resp = PrepareMSet(cluster, c, makeArgs("PrepareMSet", peerArgs...)) + } else { + resp = cluster.Relay(peer, c, makeArgs("PrepareMSet", peerArgs...)) + } + if reply.IsErrorReply(resp) { + errReply = resp + rollback = true + break + } + } + if rollback { + // rollback + RequestRollback(cluster, c, txId, groupMap) + } else { + _, errReply = RequestCommit(cluster, c, txId, groupMap) + rollback = errReply != nil + } + if !rollback { + return &reply.OkReply{} + } + return errReply } func MSetNX(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - argCount := len(args) - 1 - if argCount%2 != 0 || argCount < 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command") - } - var peer string - size := argCount / 2 - for i := 0; i < size; i++ { - key := string(args[2*i]) - currentPeer := cluster.peerPicker.Get(key) - if peer == "" { - peer = currentPeer - } else { - if peer != currentPeer { - return reply.MakeErrReply("ERR msetnx must within one slot in cluster mode") - } - } - } - return cluster.Relay(peer, c, args) + argCount := len(args) - 1 + if argCount%2 != 0 || argCount < 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command") + } + var peer string + size := argCount / 2 + for i := 0; i < size; i++ { + key := string(args[2*i]) + currentPeer := cluster.peerPicker.Get(key) + if peer == "" { + peer = currentPeer + } else { + if peer != currentPeer { + return reply.MakeErrReply("ERR msetnx must within one slot in cluster mode") + } + } + } + return cluster.Relay(peer, c, args) } diff --git a/src/cluster/rename.go b/src/cluster/rename.go index 42761c73..6793ae33 100644 --- a/src/cluster/rename.go +++ b/src/cluster/rename.go @@ -1,40 +1,39 @@ package cluster import ( - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" ) // TODO: support multiplex slots func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'rename' command") - } - src := string(args[1]) - dest := string(args[2]) - - srcPeer := cluster.peerPicker.Get(src) - destPeer := cluster.peerPicker.Get(dest) - - if srcPeer != destPeer { - return reply.MakeErrReply("ERR rename must within one slot in cluster mode") - } - return cluster.Relay(srcPeer, c, args) + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'rename' command") + } + src := string(args[1]) + dest := string(args[2]) + + srcPeer := cluster.peerPicker.Get(src) + destPeer := cluster.peerPicker.Get(dest) + + if srcPeer != destPeer { + return reply.MakeErrReply("ERR rename must within one slot in cluster mode") + } + return cluster.Relay(srcPeer, c, args) } func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'renamenx' command") - } - src := string(args[1]) - dest := string(args[2]) - - srcPeer := cluster.peerPicker.Get(src) - destPeer := cluster.peerPicker.Get(dest) - - if srcPeer != destPeer { - return reply.MakeErrReply("ERR rename must within one slot in cluster mode") - } - return cluster.Relay(srcPeer, c, args) + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'renamenx' command") + } + src := string(args[1]) + dest := string(args[2]) + + srcPeer := cluster.peerPicker.Get(src) + destPeer := cluster.peerPicker.Get(dest) + + if srcPeer != destPeer { + return reply.MakeErrReply("ERR rename must within one slot in cluster mode") + } + return cluster.Relay(srcPeer, c, args) } - diff --git a/src/cluster/router.go b/src/cluster/router.go index 5b2f0ae7..8874d012 100644 --- a/src/cluster/router.go +++ b/src/cluster/router.go @@ -3,106 +3,106 @@ package cluster import "github.com/HDT3213/godis/src/interface/redis" func MakeRouter() map[string]CmdFunc { - routerMap := make(map[string]CmdFunc) - routerMap["ping"] = Ping + routerMap := make(map[string]CmdFunc) + routerMap["ping"] = Ping - routerMap["commit"] = Commit - routerMap["rollback"] = Rollback - routerMap["del"] = Del - routerMap["preparedel"] = PrepareDel - routerMap["preparemset"] = PrepareMSet + routerMap["commit"] = Commit + routerMap["rollback"] = Rollback + routerMap["del"] = Del + routerMap["preparedel"] = PrepareDel + routerMap["preparemset"] = PrepareMSet - routerMap["expire"] = defaultFunc - routerMap["expireat"] = defaultFunc - routerMap["pexpire"] = defaultFunc - routerMap["pexpireat"] = defaultFunc - routerMap["ttl"] = defaultFunc - routerMap["pttl"] = defaultFunc - routerMap["persist"] = defaultFunc - routerMap["exists"] = defaultFunc - routerMap["type"] = defaultFunc - routerMap["rename"] = Rename - routerMap["renamenx"] = RenameNx + routerMap["expire"] = defaultFunc + routerMap["expireat"] = defaultFunc + routerMap["pexpire"] = defaultFunc + routerMap["pexpireat"] = defaultFunc + routerMap["ttl"] = defaultFunc + routerMap["pttl"] = defaultFunc + routerMap["persist"] = defaultFunc + routerMap["exists"] = defaultFunc + routerMap["type"] = defaultFunc + routerMap["rename"] = Rename + routerMap["renamenx"] = RenameNx - routerMap["set"] = defaultFunc - routerMap["setnx"] = defaultFunc - routerMap["setex"] = defaultFunc - routerMap["psetex"] = defaultFunc - routerMap["mset"] = MSet - routerMap["mget"] = MGet - routerMap["msetnx"] = MSetNX - routerMap["get"] = defaultFunc - routerMap["getset"] = defaultFunc - routerMap["incr"] = defaultFunc - routerMap["incrby"] = defaultFunc - routerMap["incrbyfloat"] = defaultFunc - routerMap["decr"] = defaultFunc - routerMap["decrby"] = defaultFunc + routerMap["set"] = defaultFunc + routerMap["setnx"] = defaultFunc + routerMap["setex"] = defaultFunc + routerMap["psetex"] = defaultFunc + routerMap["mset"] = MSet + routerMap["mget"] = MGet + routerMap["msetnx"] = MSetNX + routerMap["get"] = defaultFunc + routerMap["getset"] = defaultFunc + routerMap["incr"] = defaultFunc + routerMap["incrby"] = defaultFunc + routerMap["incrbyfloat"] = defaultFunc + routerMap["decr"] = defaultFunc + routerMap["decrby"] = defaultFunc - routerMap["lpush"] = defaultFunc - routerMap["lpushx"] = defaultFunc - routerMap["rpush"] = defaultFunc - routerMap["rpushx"] = defaultFunc - routerMap["lpop"] = defaultFunc - routerMap["rpop"] = defaultFunc - //routerMap["rpoplpush"] = RPopLPush - routerMap["lrem"] = defaultFunc - routerMap["llen"] = defaultFunc - routerMap["lindex"] = defaultFunc - routerMap["lset"] = defaultFunc - routerMap["lrange"] = defaultFunc + routerMap["lpush"] = defaultFunc + routerMap["lpushx"] = defaultFunc + routerMap["rpush"] = defaultFunc + routerMap["rpushx"] = defaultFunc + routerMap["lpop"] = defaultFunc + routerMap["rpop"] = defaultFunc + //routerMap["rpoplpush"] = RPopLPush + routerMap["lrem"] = defaultFunc + routerMap["llen"] = defaultFunc + routerMap["lindex"] = defaultFunc + routerMap["lset"] = defaultFunc + routerMap["lrange"] = defaultFunc - routerMap["hset"] = defaultFunc - routerMap["hsetnx"] = defaultFunc - routerMap["hget"] = defaultFunc - routerMap["hexists"] = defaultFunc - routerMap["hdel"] = defaultFunc - routerMap["hlen"] = defaultFunc - routerMap["hmget"] = defaultFunc - routerMap["hmset"] = defaultFunc - routerMap["hkeys"] = defaultFunc - routerMap["hvals"] = defaultFunc - routerMap["hgetall"] = defaultFunc - routerMap["hincrby"] = defaultFunc - routerMap["hincrbyfloat"] = defaultFunc + routerMap["hset"] = defaultFunc + routerMap["hsetnx"] = defaultFunc + routerMap["hget"] = defaultFunc + routerMap["hexists"] = defaultFunc + routerMap["hdel"] = defaultFunc + routerMap["hlen"] = defaultFunc + routerMap["hmget"] = defaultFunc + routerMap["hmset"] = defaultFunc + routerMap["hkeys"] = defaultFunc + routerMap["hvals"] = defaultFunc + routerMap["hgetall"] = defaultFunc + routerMap["hincrby"] = defaultFunc + routerMap["hincrbyfloat"] = defaultFunc - routerMap["sadd"] = defaultFunc - routerMap["sismember"] = defaultFunc - routerMap["srem"] = defaultFunc - routerMap["scard"] = defaultFunc - routerMap["smembers"] = defaultFunc - routerMap["sinter"] = defaultFunc - routerMap["sinterstore"] = defaultFunc - routerMap["sunion"] = defaultFunc - routerMap["sunionstore"] = defaultFunc - routerMap["sdiff"] = defaultFunc - routerMap["sdiffstore"] = defaultFunc - routerMap["srandmember"] = defaultFunc + routerMap["sadd"] = defaultFunc + routerMap["sismember"] = defaultFunc + routerMap["srem"] = defaultFunc + routerMap["scard"] = defaultFunc + routerMap["smembers"] = defaultFunc + routerMap["sinter"] = defaultFunc + routerMap["sinterstore"] = defaultFunc + routerMap["sunion"] = defaultFunc + routerMap["sunionstore"] = defaultFunc + routerMap["sdiff"] = defaultFunc + routerMap["sdiffstore"] = defaultFunc + routerMap["srandmember"] = defaultFunc - routerMap["zadd"] = defaultFunc - routerMap["zscore"] = defaultFunc - routerMap["zincrby"] = defaultFunc - routerMap["zrank"] = defaultFunc - routerMap["zcount"] = defaultFunc - routerMap["zrevrank"] = defaultFunc - routerMap["zcard"] = defaultFunc - routerMap["zrange"] = defaultFunc - routerMap["zrevrange"] = defaultFunc - routerMap["zrangebyscore"] = defaultFunc - routerMap["zrevrangebyscore"] = defaultFunc - routerMap["zrem"] = defaultFunc - routerMap["zremrangebyscore"] = defaultFunc - routerMap["zremrangebyrank"] = defaultFunc + routerMap["zadd"] = defaultFunc + routerMap["zscore"] = defaultFunc + routerMap["zincrby"] = defaultFunc + routerMap["zrank"] = defaultFunc + routerMap["zcount"] = defaultFunc + routerMap["zrevrank"] = defaultFunc + routerMap["zcard"] = defaultFunc + routerMap["zrange"] = defaultFunc + routerMap["zrevrange"] = defaultFunc + routerMap["zrangebyscore"] = defaultFunc + routerMap["zrevrangebyscore"] = defaultFunc + routerMap["zrem"] = defaultFunc + routerMap["zremrangebyscore"] = defaultFunc + routerMap["zremrangebyrank"] = defaultFunc - //routerMap["flushdb"] = FlushDB - //routerMap["flushall"] = FlushAll - //routerMap["keys"] = Keys + //routerMap["flushdb"] = FlushDB + //routerMap["flushall"] = FlushAll + //routerMap["keys"] = Keys - return routerMap + return routerMap } func defaultFunc(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - key := string(args[1]) - peer := cluster.peerPicker.Get(key) - return cluster.Relay(peer, c, args) -} \ No newline at end of file + key := string(args[1]) + peer := cluster.peerPicker.Get(key) + return cluster.Relay(peer, c, args) +} diff --git a/src/cluster/transaction.go b/src/cluster/transaction.go index 20dcae89..22b8b9c8 100644 --- a/src/cluster/transaction.go +++ b/src/cluster/transaction.go @@ -1,188 +1,188 @@ package cluster import ( - "context" - "fmt" - "github.com/HDT3213/godis/src/db" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/lib/marshal/gob" - "github.com/HDT3213/godis/src/redis/reply" - "strconv" - "strings" - "time" + "context" + "fmt" + "github.com/HDT3213/godis/src/db" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/lib/marshal/gob" + "github.com/HDT3213/godis/src/redis/reply" + "strconv" + "strings" + "time" ) type Transaction struct { - id string // transaction id - args [][]byte // cmd args - cluster *Cluster - conn redis.Connection - - keys []string // related keys - undoLog map[string][]byte // store data for undoLog - - lockUntil time.Time - ctx context.Context - cancel context.CancelFunc - status int8 + id string // transaction id + args [][]byte // cmd args + cluster *Cluster + conn redis.Connection + + keys []string // related keys + undoLog map[string][]byte // store data for undoLog + + lockUntil time.Time + ctx context.Context + cancel context.CancelFunc + status int8 } const ( - maxLockTime = 3 * time.Second + maxLockTime = 3 * time.Second - CreatedStatus = 0 - PreparedStatus = 1 - CommitedStatus = 2 - RollbackedStatus = 3 + CreatedStatus = 0 + PreparedStatus = 1 + CommitedStatus = 2 + RollbackedStatus = 3 ) func NewTransaction(cluster *Cluster, c redis.Connection, id string, args [][]byte, keys []string) *Transaction { - return &Transaction{ - id: id, - args: args, - cluster: cluster, - conn: c, - keys: keys, - status: CreatedStatus, - } + return &Transaction{ + id: id, + args: args, + cluster: cluster, + conn: c, + keys: keys, + status: CreatedStatus, + } } // t should contains Keys field func (tx *Transaction) prepare() error { - // lock keys - tx.cluster.db.Locks(tx.keys...) - - // use context to manage - //tx.lockUntil = time.Now().Add(maxLockTime) - //ctx, cancel := context.WithDeadline(context.Background(), tx.lockUntil) - //tx.ctx = ctx - //tx.cancel = cancel - - // build undoLog - tx.undoLog = make(map[string][]byte) - for _, key := range tx.keys { - entity, ok := tx.cluster.db.Get(key) - if ok { - blob, err := gob.Marshal(entity) - if err != nil { - return err - } - tx.undoLog[key] = blob - } else { - tx.undoLog[key] = []byte{} // entity was nil, should be removed while rollback - } - } - tx.status = PreparedStatus - return nil + // lock keys + tx.cluster.db.Locks(tx.keys...) + + // use context to manage + //tx.lockUntil = time.Now().Add(maxLockTime) + //ctx, cancel := context.WithDeadline(context.Background(), tx.lockUntil) + //tx.ctx = ctx + //tx.cancel = cancel + + // build undoLog + tx.undoLog = make(map[string][]byte) + for _, key := range tx.keys { + entity, ok := tx.cluster.db.Get(key) + if ok { + blob, err := gob.Marshal(entity) + if err != nil { + return err + } + tx.undoLog[key] = blob + } else { + tx.undoLog[key] = []byte{} // entity was nil, should be removed while rollback + } + } + tx.status = PreparedStatus + return nil } func (tx *Transaction) rollback() error { - for key, blob := range tx.undoLog { - if len(blob) > 0 { - entity := &db.DataEntity{} - err := gob.UnMarshal(blob, entity) - if err != nil { - return err - } - tx.cluster.db.Put(key, entity) - } else { - tx.cluster.db.Remove(key) - } - } - if tx.status != CommitedStatus { - tx.cluster.db.UnLocks(tx.keys...) - } - tx.status = RollbackedStatus - return nil + for key, blob := range tx.undoLog { + if len(blob) > 0 { + entity := &db.DataEntity{} + err := gob.UnMarshal(blob, entity) + if err != nil { + return err + } + tx.cluster.db.Put(key, entity) + } else { + tx.cluster.db.Remove(key) + } + } + if tx.status != CommitedStatus { + tx.cluster.db.UnLocks(tx.keys...) + } + tx.status = RollbackedStatus + return nil } // rollback local transaction func Rollback(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command") - } - txId := string(args[1]) - raw, ok := cluster.transactions.Get(txId) - if !ok { - return reply.MakeIntReply(0) - } - tx, _ := raw.(*Transaction) - err := tx.rollback() - if err != nil { - return reply.MakeErrReply(err.Error()) - } - return reply.MakeIntReply(1) + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command") + } + txId := string(args[1]) + raw, ok := cluster.transactions.Get(txId) + if !ok { + return reply.MakeIntReply(0) + } + tx, _ := raw.(*Transaction) + err := tx.rollback() + if err != nil { + return reply.MakeErrReply(err.Error()) + } + return reply.MakeIntReply(1) } // commit local transaction as a worker func Commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command") - } - txId := string(args[1]) - raw, ok := cluster.transactions.Get(txId) - if !ok { - return reply.MakeIntReply(0) - } - tx, _ := raw.(*Transaction) - - // finish transaction - defer func() { - cluster.db.UnLocks(tx.keys...) - tx.status = CommitedStatus - //cluster.transactions.Remove(tx.id) // cannot remove, may rollback after commit - }() - - cmd := strings.ToLower(string(tx.args[0])) - var result redis.Reply - if cmd == "del" { - result = CommitDel(cluster, c, tx) - } else if cmd == "mset" { - result = CommitMSet(cluster, c, tx) - } - - if reply.IsErrorReply(result) { - // failed - err2 := tx.rollback() - return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result)) - } - - return result + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command") + } + txId := string(args[1]) + raw, ok := cluster.transactions.Get(txId) + if !ok { + return reply.MakeIntReply(0) + } + tx, _ := raw.(*Transaction) + + // finish transaction + defer func() { + cluster.db.UnLocks(tx.keys...) + tx.status = CommitedStatus + //cluster.transactions.Remove(tx.id) // cannot remove, may rollback after commit + }() + + cmd := strings.ToLower(string(tx.args[0])) + var result redis.Reply + if cmd == "del" { + result = CommitDel(cluster, c, tx) + } else if cmd == "mset" { + result = CommitMSet(cluster, c, tx) + } + + if reply.IsErrorReply(result) { + // failed + err2 := tx.rollback() + return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result)) + } + + return result } // request all node commit transaction as leader func RequestCommit(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) ([]redis.Reply, reply.ErrorReply) { - var errReply reply.ErrorReply - txIdStr := strconv.FormatInt(txId, 10) - respList := make([]redis.Reply, 0, len(peers)) - for peer := range peers { - var resp redis.Reply - if peer == cluster.self { - resp = Commit(cluster, c, makeArgs("commit", txIdStr)) - } else { - resp = cluster.Relay(peer, c, makeArgs("commit", txIdStr)) - } - if reply.IsErrorReply(resp) { - errReply = resp.(reply.ErrorReply) - break - } - respList = append(respList, resp) - } - if errReply != nil { - RequestRollback(cluster, c, txId, peers) - return nil, errReply - } - return respList, nil + var errReply reply.ErrorReply + txIdStr := strconv.FormatInt(txId, 10) + respList := make([]redis.Reply, 0, len(peers)) + for peer := range peers { + var resp redis.Reply + if peer == cluster.self { + resp = Commit(cluster, c, makeArgs("commit", txIdStr)) + } else { + resp = cluster.Relay(peer, c, makeArgs("commit", txIdStr)) + } + if reply.IsErrorReply(resp) { + errReply = resp.(reply.ErrorReply) + break + } + respList = append(respList, resp) + } + if errReply != nil { + RequestRollback(cluster, c, txId, peers) + return nil, errReply + } + return respList, nil } // request all node rollback transaction as leader func RequestRollback(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) { - txIdStr := strconv.FormatInt(txId, 10) - for peer := range peers { - if peer == cluster.self { - Rollback(cluster, c, makeArgs("rollback", txIdStr)) - } else { - cluster.Relay(peer, c, makeArgs("rollback", txIdStr)) - } - } -} \ No newline at end of file + txIdStr := strconv.FormatInt(txId, 10) + for peer := range peers { + if peer == cluster.self { + Rollback(cluster, c, makeArgs("rollback", txIdStr)) + } else { + cluster.Relay(peer, c, makeArgs("rollback", txIdStr)) + } + } +} diff --git a/src/cmd/main.go b/src/cmd/main.go index e1ad6355..8e0d29f2 100644 --- a/src/cmd/main.go +++ b/src/cmd/main.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/HDT3213/godis/src/config" "github.com/HDT3213/godis/src/lib/logger" - RedisServer "github.com/HDT3213/godis/src/redis/server" + RedisServer "github.com/HDT3213/godis/src/redis/server" "github.com/HDT3213/godis/src/tcp" "os" ) @@ -23,6 +23,6 @@ func main() { }) tcp.ListenAndServe(&tcp.Config{ - Address: fmt.Sprintf("%s:%d", config.Properties.Bind, config.Properties.Port), - }, RedisServer.MakeHandler()) + Address: fmt.Sprintf("%s:%d", config.Properties.Bind, config.Properties.Port), + }, RedisServer.MakeHandler()) } diff --git a/src/config/config.go b/src/config/config.go index 0cf1ccb0..34f37c26 100644 --- a/src/config/config.go +++ b/src/config/config.go @@ -23,12 +23,12 @@ type PropertyHolder struct { var Properties *PropertyHolder func init() { - // default config - Properties = &PropertyHolder{ - Bind: "127.0.0.1", - Port: 6379, - AppendOnly: false, - } + // default config + Properties = &PropertyHolder{ + Bind: "127.0.0.1", + Port: 6379, + AppendOnly: false, + } } func LoadConfig(configFilename string) *PropertyHolder { diff --git a/src/datastruct/dict/dict.go b/src/datastruct/dict/dict.go index 5129bb08..f2428bbf 100644 --- a/src/datastruct/dict/dict.go +++ b/src/datastruct/dict/dict.go @@ -1,16 +1,16 @@ package dict -type Consumer func(key string, val interface{})bool +type Consumer func(key string, val interface{}) bool type Dict interface { - Get(key string) (val interface{}, exists bool) - Len() int - Put(key string, val interface{}) (result int) - PutIfAbsent(key string, val interface{}) (result int) - PutIfExists(key string, val interface{}) (result int) - Remove(key string) (result int) - ForEach(consumer Consumer) - Keys() []string - RandomKeys(limit int) []string - RandomDistinctKeys(limit int) []string + Get(key string) (val interface{}, exists bool) + Len() int + Put(key string, val interface{}) (result int) + PutIfAbsent(key string, val interface{}) (result int) + PutIfExists(key string, val interface{}) (result int) + Remove(key string) (result int) + ForEach(consumer Consumer) + Keys() []string + RandomKeys(limit int) []string + RandomDistinctKeys(limit int) []string } diff --git a/src/datastruct/dict/dict_test.go b/src/datastruct/dict/dict_test.go index 285de31b..2cb551d2 100644 --- a/src/datastruct/dict/dict_test.go +++ b/src/datastruct/dict/dict_test.go @@ -1,244 +1,244 @@ package dict import ( - "strconv" - "sync" - "testing" + "strconv" + "sync" + "testing" ) func TestPut(t *testing.T) { - d := MakeConcurrent(0) - count := 100 - var wg sync.WaitGroup - wg.Add(count) - for i := 0; i < count; i++ { - go func(i int) { - // insert - key := "k" + strconv.Itoa(i) - ret := d.Put(key, i) - if ret != 1 { // insert 1 - t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key) - } - val, ok := d.Get(key) - if ok { - intVal, _ := val.(int) - if intVal != i { - t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key) - } - } else { - _, ok := d.Get(key) - t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok)) - } - wg.Done() - }(i) - } - wg.Wait() + d := MakeConcurrent(0) + count := 100 + var wg sync.WaitGroup + wg.Add(count) + for i := 0; i < count; i++ { + go func(i int) { + // insert + key := "k" + strconv.Itoa(i) + ret := d.Put(key, i) + if ret != 1 { // insert 1 + t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key) + } + val, ok := d.Get(key) + if ok { + intVal, _ := val.(int) + if intVal != i { + t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key) + } + } else { + _, ok := d.Get(key) + t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok)) + } + wg.Done() + }(i) + } + wg.Wait() } func TestPutIfAbsent(t *testing.T) { - d := MakeConcurrent(0) - count := 100 - var wg sync.WaitGroup - wg.Add(count) - for i := 0; i < count; i++ { - go func(i int) { - // insert - key := "k" + strconv.Itoa(i) - ret := d.PutIfAbsent(key, i) - if ret != 1 { // insert 1 - t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key) - } - val, ok := d.Get(key) - if ok { - intVal, _ := val.(int) - if intVal != i { - t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + - ", key: " + key) - } - } else { - _, ok := d.Get(key) - t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok)) - } - - // update - ret = d.PutIfAbsent(key, i * 10) - if ret != 0 { // no update - t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret)) - } - val, ok = d.Get(key) - if ok { - intVal, _ := val.(int) - if intVal != i { - t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key) - } - } else { - t.Error("put test failed: expected true, actual: false, key: " + key) - } - wg.Done() - }(i) - } - wg.Wait() + d := MakeConcurrent(0) + count := 100 + var wg sync.WaitGroup + wg.Add(count) + for i := 0; i < count; i++ { + go func(i int) { + // insert + key := "k" + strconv.Itoa(i) + ret := d.PutIfAbsent(key, i) + if ret != 1 { // insert 1 + t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key) + } + val, ok := d.Get(key) + if ok { + intVal, _ := val.(int) + if intVal != i { + t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + + ", key: " + key) + } + } else { + _, ok := d.Get(key) + t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok)) + } + + // update + ret = d.PutIfAbsent(key, i*10) + if ret != 0 { // no update + t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret)) + } + val, ok = d.Get(key) + if ok { + intVal, _ := val.(int) + if intVal != i { + t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key) + } + } else { + t.Error("put test failed: expected true, actual: false, key: " + key) + } + wg.Done() + }(i) + } + wg.Wait() } func TestPutIfExists(t *testing.T) { - d := MakeConcurrent(0) - count := 100 - var wg sync.WaitGroup - wg.Add(count) - for i := 0; i < count; i++ { - go func(i int) { - // insert - key := "k" + strconv.Itoa(i) - // insert - ret := d.PutIfExists(key, i) - if ret != 0 { // insert - t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret)) - } - - d.Put(key, i) - ret = d.PutIfExists(key, 10 * i) - val, ok := d.Get(key) - if ok { - intVal, _ := val.(int) - if intVal != 10 * i { - t.Error("put test failed: expected " + strconv.Itoa(10 * i) + ", actual: " + strconv.Itoa(intVal)) - } - } else { - _, ok := d.Get(key) - t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok)) - } - wg.Done() - }(i) - } - wg.Wait() + d := MakeConcurrent(0) + count := 100 + var wg sync.WaitGroup + wg.Add(count) + for i := 0; i < count; i++ { + go func(i int) { + // insert + key := "k" + strconv.Itoa(i) + // insert + ret := d.PutIfExists(key, i) + if ret != 0 { // insert + t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret)) + } + + d.Put(key, i) + ret = d.PutIfExists(key, 10*i) + val, ok := d.Get(key) + if ok { + intVal, _ := val.(int) + if intVal != 10*i { + t.Error("put test failed: expected " + strconv.Itoa(10*i) + ", actual: " + strconv.Itoa(intVal)) + } + } else { + _, ok := d.Get(key) + t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok)) + } + wg.Done() + }(i) + } + wg.Wait() } func TestRemove(t *testing.T) { - d := MakeConcurrent(0) - - // remove head node - for i := 0; i < 100; i++ { - // insert - key := "k" + strconv.Itoa(i) - d.Put(key, i) - } - for i := 0; i < 100; i++ { - key := "k" + strconv.Itoa(i) - - val, ok := d.Get(key) - if ok { - intVal, _ := val.(int) - if intVal != i { - t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) - } - } else { - t.Error("put test failed: expected true, actual: false") - } - - ret := d.Remove(key) - if ret != 1 { - t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key) - } - _, ok = d.Get(key) - if ok { - t.Error("remove test failed: expected true, actual false") - } - ret = d.Remove(key) - if ret != 0 { - t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) - } - } - - // remove tail node - d = MakeConcurrent(0) - for i := 0; i < 100; i++ { - // insert - key := "k" + strconv.Itoa(i) - d.Put(key, i) - } - for i := 9; i >= 0; i-- { - key := "k" + strconv.Itoa(i) - - val, ok := d.Get(key) - if ok { - intVal, _ := val.(int) - if intVal != i { - t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) - } - } else { - t.Error("put test failed: expected true, actual: false") - } - - ret := d.Remove(key) - if ret != 1 { - t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret)) - } - _, ok = d.Get(key) - if ok { - t.Error("remove test failed: expected true, actual false") - } - ret = d.Remove(key) - if ret != 0 { - t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) - } - } - - // remove middle node - d = MakeConcurrent(0) - d.Put("head", 0) - for i := 0; i < 10; i++ { - // insert - key := "k" + strconv.Itoa(i) - d.Put(key, i) - } - d.Put("tail", 0) - for i := 9; i >= 0; i-- { - key := "k" + strconv.Itoa(i) - - val, ok := d.Get(key) - if ok { - intVal, _ := val.(int) - if intVal != i { - t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) - } - } else { - t.Error("put test failed: expected true, actual: false") - } - - ret := d.Remove(key) - if ret != 1 { - t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret)) - } - _, ok = d.Get(key) - if ok { - t.Error("remove test failed: expected true, actual false") - } - ret = d.Remove(key) - if ret != 0 { - t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) - } - } + d := MakeConcurrent(0) + + // remove head node + for i := 0; i < 100; i++ { + // insert + key := "k" + strconv.Itoa(i) + d.Put(key, i) + } + for i := 0; i < 100; i++ { + key := "k" + strconv.Itoa(i) + + val, ok := d.Get(key) + if ok { + intVal, _ := val.(int) + if intVal != i { + t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) + } + } else { + t.Error("put test failed: expected true, actual: false") + } + + ret := d.Remove(key) + if ret != 1 { + t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key) + } + _, ok = d.Get(key) + if ok { + t.Error("remove test failed: expected true, actual false") + } + ret = d.Remove(key) + if ret != 0 { + t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) + } + } + + // remove tail node + d = MakeConcurrent(0) + for i := 0; i < 100; i++ { + // insert + key := "k" + strconv.Itoa(i) + d.Put(key, i) + } + for i := 9; i >= 0; i-- { + key := "k" + strconv.Itoa(i) + + val, ok := d.Get(key) + if ok { + intVal, _ := val.(int) + if intVal != i { + t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) + } + } else { + t.Error("put test failed: expected true, actual: false") + } + + ret := d.Remove(key) + if ret != 1 { + t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret)) + } + _, ok = d.Get(key) + if ok { + t.Error("remove test failed: expected true, actual false") + } + ret = d.Remove(key) + if ret != 0 { + t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) + } + } + + // remove middle node + d = MakeConcurrent(0) + d.Put("head", 0) + for i := 0; i < 10; i++ { + // insert + key := "k" + strconv.Itoa(i) + d.Put(key, i) + } + d.Put("tail", 0) + for i := 9; i >= 0; i-- { + key := "k" + strconv.Itoa(i) + + val, ok := d.Get(key) + if ok { + intVal, _ := val.(int) + if intVal != i { + t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) + } + } else { + t.Error("put test failed: expected true, actual: false") + } + + ret := d.Remove(key) + if ret != 1 { + t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret)) + } + _, ok = d.Get(key) + if ok { + t.Error("remove test failed: expected true, actual false") + } + ret = d.Remove(key) + if ret != 0 { + t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) + } + } } func TestForEach(t *testing.T) { - d := MakeConcurrent(0) - size := 100 - for i := 0; i < size; i++ { - // insert - key := "k" + strconv.Itoa(i) - d.Put(key, i) - } - i := 0 - d.ForEach(func(key string, value interface{})bool { - intVal, _ := value.(int) - expectedKey := "k" + strconv.Itoa(intVal) - if key != expectedKey { - t.Error("remove test failed: expected " + expectedKey + ", actual: " + key) - } - i++ - return true - }) - if i != size { - t.Error("remove test failed: expected " + strconv.Itoa(size) + ", actual: " + strconv.Itoa(i)) - } -} \ No newline at end of file + d := MakeConcurrent(0) + size := 100 + for i := 0; i < size; i++ { + // insert + key := "k" + strconv.Itoa(i) + d.Put(key, i) + } + i := 0 + d.ForEach(func(key string, value interface{}) bool { + intVal, _ := value.(int) + expectedKey := "k" + strconv.Itoa(intVal) + if key != expectedKey { + t.Error("remove test failed: expected " + expectedKey + ", actual: " + key) + } + i++ + return true + }) + if i != size { + t.Error("remove test failed: expected " + strconv.Itoa(size) + ", actual: " + strconv.Itoa(i)) + } +} diff --git a/src/datastruct/dict/simple.go b/src/datastruct/dict/simple.go index ac39b675..276a8711 100644 --- a/src/datastruct/dict/simple.go +++ b/src/datastruct/dict/simple.go @@ -1,108 +1,108 @@ package dict type SimpleDict struct { - m map[string]interface{} + m map[string]interface{} } func MakeSimple() *SimpleDict { - return &SimpleDict{ - m: make(map[string]interface{}), - } + return &SimpleDict{ + m: make(map[string]interface{}), + } } func (dict *SimpleDict) Get(key string) (val interface{}, exists bool) { - val, ok := dict.m[key] - return val, ok + val, ok := dict.m[key] + return val, ok } func (dict *SimpleDict) Len() int { - if dict.m == nil { - panic("m is nil") - } - return len(dict.m) + if dict.m == nil { + panic("m is nil") + } + return len(dict.m) } func (dict *SimpleDict) Put(key string, val interface{}) (result int) { - _, existed := dict.m[key] - dict.m[key] = val - if existed { - return 0 - } else { - return 1 - } + _, existed := dict.m[key] + dict.m[key] = val + if existed { + return 0 + } else { + return 1 + } } func (dict *SimpleDict) PutIfAbsent(key string, val interface{}) (result int) { - _, existed := dict.m[key] - if existed { - return 0 - } else { - dict.m[key] = val - return 1 - } + _, existed := dict.m[key] + if existed { + return 0 + } else { + dict.m[key] = val + return 1 + } } func (dict *SimpleDict) PutIfExists(key string, val interface{}) (result int) { - _, existed := dict.m[key] - if existed { - dict.m[key] = val - return 1 - } else { - return 0 - } + _, existed := dict.m[key] + if existed { + dict.m[key] = val + return 1 + } else { + return 0 + } } func (dict *SimpleDict) Remove(key string) (result int) { - _, existed := dict.m[key] - delete(dict.m, key) - if existed { - return 1 - } else { - return 0 - } + _, existed := dict.m[key] + delete(dict.m, key) + if existed { + return 1 + } else { + return 0 + } } func (dict *SimpleDict) Keys() []string { - result := make([]string, len(dict.m)) - i := 0 - for k := range dict.m { - result[i] = k - } - return result + result := make([]string, len(dict.m)) + i := 0 + for k := range dict.m { + result[i] = k + } + return result } func (dict *SimpleDict) ForEach(consumer Consumer) { - for k, v := range dict.m { - if !consumer(k, v) { - break - } - } + for k, v := range dict.m { + if !consumer(k, v) { + break + } + } } func (dict *SimpleDict) RandomKeys(limit int) []string { - result := make([]string, limit) - for i := 0; i < limit; i++ { - for k := range dict.m { - result[i] = k - break - } - } - return result + result := make([]string, limit) + for i := 0; i < limit; i++ { + for k := range dict.m { + result[i] = k + break + } + } + return result } func (dict *SimpleDict) RandomDistinctKeys(limit int) []string { - size := limit - if size > len(dict.m) { - size = len(dict.m) - } - result := make([]string, size) - i := 0 - for k := range dict.m { - if i == limit { - break - } - result[i] = k - i++ - } - return result + size := limit + if size > len(dict.m) { + size = len(dict.m) + } + result := make([]string, size) + i := 0 + for k := range dict.m { + if i == limit { + break + } + result[i] = k + i++ + } + return result } diff --git a/src/datastruct/list/linked.go b/src/datastruct/list/linked.go index 5f79428a..1a5aabb0 100644 --- a/src/datastruct/list/linked.go +++ b/src/datastruct/list/linked.go @@ -3,324 +3,324 @@ package list import "github.com/HDT3213/godis/src/datastruct/utils" type LinkedList struct { - first *node - last *node - size int + first *node + last *node + size int } type node struct { - val interface{} - prev *node - next * node + val interface{} + prev *node + next *node } -func (list *LinkedList)Add(val interface{}) { - if list == nil { - panic("list is nil") - } - n := &node{ - val: val, - } - if list.last == nil { - // empty list - list.first = n - list.last = n - } else { - n.prev = list.last - list.last.next = n - list.last = n - } - list.size++ +func (list *LinkedList) Add(val interface{}) { + if list == nil { + panic("list is nil") + } + n := &node{ + val: val, + } + if list.last == nil { + // empty list + list.first = n + list.last = n + } else { + n.prev = list.last + list.last.next = n + list.last = n + } + list.size++ } -func (list *LinkedList)find(index int)(n *node) { - if index < list.size / 2 { - n := list.first - for i := 0; i < index; i++ { - n = n.next - } - return n - } else { - n := list.last - for i := list.size - 1; i > index; i-- { - n = n.prev - } - return n - } +func (list *LinkedList) find(index int) (n *node) { + if index < list.size/2 { + n := list.first + for i := 0; i < index; i++ { + n = n.next + } + return n + } else { + n := list.last + for i := list.size - 1; i > index; i-- { + n = n.prev + } + return n + } } -func (list *LinkedList)Get(index int)(val interface{}) { - if list == nil { - panic("list is nil") - } - if index < 0 || index >= list.size { - panic("index out of bound") - } - return list.find(index).val +func (list *LinkedList) Get(index int) (val interface{}) { + if list == nil { + panic("list is nil") + } + if index < 0 || index >= list.size { + panic("index out of bound") + } + return list.find(index).val } -func (list *LinkedList)Set(index int, val interface{}) { - if list == nil { - panic("list is nil") - } - if index < 0 || index > list.size { - panic("index out of bound") - } - n := list.find(index) - n.val = val +func (list *LinkedList) Set(index int, val interface{}) { + if list == nil { + panic("list is nil") + } + if index < 0 || index > list.size { + panic("index out of bound") + } + n := list.find(index) + n.val = val } -func (list *LinkedList)Insert(index int, val interface{}) { - if list == nil { - panic("list is nil") - } - if index < 0 || index > list.size { - panic("index out of bound") - } +func (list *LinkedList) Insert(index int, val interface{}) { + if list == nil { + panic("list is nil") + } + if index < 0 || index > list.size { + panic("index out of bound") + } - if index == list.size { - list.Add(val) - return - } else { - // list is not empty - pivot := list.find(index) - n := &node{ - val: val, - prev: pivot.prev, - next: pivot, - } - if pivot.prev == nil { - list.first = n - } else { - pivot.prev.next = n - } - pivot.prev = n - list.size++ - } + if index == list.size { + list.Add(val) + return + } else { + // list is not empty + pivot := list.find(index) + n := &node{ + val: val, + prev: pivot.prev, + next: pivot, + } + if pivot.prev == nil { + list.first = n + } else { + pivot.prev.next = n + } + pivot.prev = n + list.size++ + } } -func (list *LinkedList)removeNode(n *node) { - if n.prev == nil { - list.first = n.next - } else { - n.prev.next = n.next - } - if n.next == nil { - list.last = n.prev - } else { - n.next.prev = n.prev - } +func (list *LinkedList) removeNode(n *node) { + if n.prev == nil { + list.first = n.next + } else { + n.prev.next = n.next + } + if n.next == nil { + list.last = n.prev + } else { + n.next.prev = n.prev + } - // for gc - n.prev = nil - n.next = nil + // for gc + n.prev = nil + n.next = nil - list.size-- + list.size-- } -func (list *LinkedList)Remove(index int)(val interface{}) { - if list == nil { - panic("list is nil") - } - if index < 0 || index >= list.size { - panic("index out of bound") - } +func (list *LinkedList) Remove(index int) (val interface{}) { + if list == nil { + panic("list is nil") + } + if index < 0 || index >= list.size { + panic("index out of bound") + } - n := list.find(index) - list.removeNode(n) - return n.val + n := list.find(index) + list.removeNode(n) + return n.val } -func (list *LinkedList)RemoveLast()(val interface{}) { - if list == nil { - panic("list is nil") - } - if list.last == nil { - // empty list - return nil - } - n := list.last - list.removeNode(n) - return n.val +func (list *LinkedList) RemoveLast() (val interface{}) { + if list == nil { + panic("list is nil") + } + if list.last == nil { + // empty list + return nil + } + n := list.last + list.removeNode(n) + return n.val } -func (list *LinkedList)RemoveAllByVal(val interface{})int { - if list == nil { - panic("list is nil") - } - n := list.first - removed := 0 - for n != nil { - var toRemoveNode *node - if utils.Equals(n.val, val) { - toRemoveNode = n - } - if n.next == nil { - if toRemoveNode != nil { - removed++ - list.removeNode(toRemoveNode) - } - break - } else { - n = n.next - } - if toRemoveNode != nil { - removed++ - list.removeNode(toRemoveNode) - } - } - return removed +func (list *LinkedList) RemoveAllByVal(val interface{}) int { + if list == nil { + panic("list is nil") + } + n := list.first + removed := 0 + for n != nil { + var toRemoveNode *node + if utils.Equals(n.val, val) { + toRemoveNode = n + } + if n.next == nil { + if toRemoveNode != nil { + removed++ + list.removeNode(toRemoveNode) + } + break + } else { + n = n.next + } + if toRemoveNode != nil { + removed++ + list.removeNode(toRemoveNode) + } + } + return removed } /** * remove at most `count` values of the specified value in this list * scan from left to right */ -func (list *LinkedList) RemoveByVal(val interface{}, count int)int { - if list == nil { - panic("list is nil") - } - n := list.first - removed := 0 - for n != nil { - var toRemoveNode *node - if utils.Equals(n.val, val) { - toRemoveNode = n - } - if n.next == nil { - if toRemoveNode != nil { - removed++ - list.removeNode(toRemoveNode) - } - break - } else { - n = n.next - } +func (list *LinkedList) RemoveByVal(val interface{}, count int) int { + if list == nil { + panic("list is nil") + } + n := list.first + removed := 0 + for n != nil { + var toRemoveNode *node + if utils.Equals(n.val, val) { + toRemoveNode = n + } + if n.next == nil { + if toRemoveNode != nil { + removed++ + list.removeNode(toRemoveNode) + } + break + } else { + n = n.next + } - if toRemoveNode != nil { - removed++ - list.removeNode(toRemoveNode) - } - if removed == count { - break - } - } - return removed + if toRemoveNode != nil { + removed++ + list.removeNode(toRemoveNode) + } + if removed == count { + break + } + } + return removed } -func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int)int { - if list == nil { - panic("list is nil") - } - n := list.last - removed := 0 - for n != nil { - var toRemoveNode *node - if utils.Equals(n.val, val) { - toRemoveNode = n - } - if n.prev == nil { - if toRemoveNode != nil { - removed++ - list.removeNode(toRemoveNode) - } - break - } else { - n = n.prev - } +func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int) int { + if list == nil { + panic("list is nil") + } + n := list.last + removed := 0 + for n != nil { + var toRemoveNode *node + if utils.Equals(n.val, val) { + toRemoveNode = n + } + if n.prev == nil { + if toRemoveNode != nil { + removed++ + list.removeNode(toRemoveNode) + } + break + } else { + n = n.prev + } - if toRemoveNode != nil { - removed++ - list.removeNode(toRemoveNode) - } - if removed == count { - break - } - } - return removed + if toRemoveNode != nil { + removed++ + list.removeNode(toRemoveNode) + } + if removed == count { + break + } + } + return removed } -func (list *LinkedList)Len()int { - if list == nil { - panic("list is nil") - } - return list.size +func (list *LinkedList) Len() int { + if list == nil { + panic("list is nil") + } + return list.size } -func (list *LinkedList)ForEach(consumer func(int, interface{})bool) { - if list == nil { - panic("list is nil") - } - n := list.first - i := 0 - for n != nil { - goNext := consumer(i, n.val) - if !goNext || n.next == nil { - break - } else { - i++ - n = n.next - } - } +func (list *LinkedList) ForEach(consumer func(int, interface{}) bool) { + if list == nil { + panic("list is nil") + } + n := list.first + i := 0 + for n != nil { + goNext := consumer(i, n.val) + if !goNext || n.next == nil { + break + } else { + i++ + n = n.next + } + } } -func (list *LinkedList)Contains(val interface{})bool { - contains := false - list.ForEach(func(i int, actual interface{}) bool { - if actual == val { - contains = true - return false - } - return true - }) - return contains +func (list *LinkedList) Contains(val interface{}) bool { + contains := false + list.ForEach(func(i int, actual interface{}) bool { + if actual == val { + contains = true + return false + } + return true + }) + return contains } -func (list *LinkedList)Range(start int, stop int)[]interface{} { - if list == nil { - panic("list is nil") - } - if start < 0 || start >= list.size { - panic("`start` out of range") - } - if stop < start || stop > list.size { - panic("`stop` out of range") - } +func (list *LinkedList) Range(start int, stop int) []interface{} { + if list == nil { + panic("list is nil") + } + if start < 0 || start >= list.size { + panic("`start` out of range") + } + if stop < start || stop > list.size { + panic("`stop` out of range") + } - sliceSize := stop - start - slice := make([]interface{}, sliceSize) - n := list.first - i := 0 - sliceIndex := 0 - for n != nil { - if i >= start && i < stop { - slice[sliceIndex] = n.val - sliceIndex++ - } else if i >= stop { - break - } - if n.next == nil { - break - } else { - i++ - n = n.next - } - } - return slice + sliceSize := stop - start + slice := make([]interface{}, sliceSize) + n := list.first + i := 0 + sliceIndex := 0 + for n != nil { + if i >= start && i < stop { + slice[sliceIndex] = n.val + sliceIndex++ + } else if i >= stop { + break + } + if n.next == nil { + break + } else { + i++ + n = n.next + } + } + return slice } func Make(vals ...interface{}) *LinkedList { - list := LinkedList{} - for _, v := range vals { - list.Add(v) - } - return &list + list := LinkedList{} + for _, v := range vals { + list.Add(v) + } + return &list } func MakeBytesList(vals ...[]byte) *LinkedList { - list := LinkedList{} - for _, v := range vals { - list.Add(v) - } - return &list -} \ No newline at end of file + list := LinkedList{} + for _, v := range vals { + list.Add(v) + } + return &list +} diff --git a/src/datastruct/list/linked_test.go b/src/datastruct/list/linked_test.go index f2e82af1..d624498e 100644 --- a/src/datastruct/list/linked_test.go +++ b/src/datastruct/list/linked_test.go @@ -1,215 +1,215 @@ package list import ( - "testing" - "strconv" - "strings" + "strconv" + "strings" + "testing" ) func ToString(list *LinkedList) string { - arr := make([]string, list.size) - list.ForEach(func(i int, v interface{}) bool { - integer, _ := v.(int) - arr[i] = strconv.Itoa(integer) - return true - }) - return "[" + strings.Join(arr, ", ") + "]" + arr := make([]string, list.size) + list.ForEach(func(i int, v interface{}) bool { + integer, _ := v.(int) + arr[i] = strconv.Itoa(integer) + return true + }) + return "[" + strings.Join(arr, ", ") + "]" } func TestAdd(t *testing.T) { - list := Make() - for i := 0; i < 10; i++ { - list.Add(i) - } - list.ForEach(func(i int, v interface{}) bool { - intVal, _ := v.(int) - if intVal != i { - t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) - } - return true - }) + list := Make() + for i := 0; i < 10; i++ { + list.Add(i) + } + list.ForEach(func(i int, v interface{}) bool { + intVal, _ := v.(int) + if intVal != i { + t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) + } + return true + }) } func TestGet(t *testing.T) { - list := Make() - for i := 0; i < 10; i++ { - list.Add(i) - } - for i := 0; i < 10; i++ { - v := list.Get(i) - k, _ := v.(int) - if i != k { - t.Error("get test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(k)) - } - } + list := Make() + for i := 0; i < 10; i++ { + list.Add(i) + } + for i := 0; i < 10; i++ { + v := list.Get(i) + k, _ := v.(int) + if i != k { + t.Error("get test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(k)) + } + } } func TestRemove(t *testing.T) { - list := Make() - for i := 0; i < 10; i++ { - list.Add(i) - } - for i := 9; i >= 0; i-- { - list.Remove(i) - if i != list.Len() { - t.Error("remove test fail: expected size " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(list.Len())) - } - list.ForEach(func(i int, v interface{}) bool { - intVal, _ := v.(int) - if intVal != i { - t.Error("remove test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) - } - return true - }) - } + list := Make() + for i := 0; i < 10; i++ { + list.Add(i) + } + for i := 9; i >= 0; i-- { + list.Remove(i) + if i != list.Len() { + t.Error("remove test fail: expected size " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(list.Len())) + } + list.ForEach(func(i int, v interface{}) bool { + intVal, _ := v.(int) + if intVal != i { + t.Error("remove test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) + } + return true + }) + } } func TestRemoveVal(t *testing.T) { - list := Make() - for i := 0; i < 10; i++ { - list.Add(i) - list.Add(i) - } - for index := 0; index < list.Len(); index++ { - list.RemoveAllByVal(index) - list.ForEach(func(i int, v interface{}) bool { - intVal, _ := v.(int) - if intVal == index { - t.Error("remove test fail: found " + strconv.Itoa(index) + " at index: " + strconv.Itoa(i)) - } - return true - }) - } + list := Make() + for i := 0; i < 10; i++ { + list.Add(i) + list.Add(i) + } + for index := 0; index < list.Len(); index++ { + list.RemoveAllByVal(index) + list.ForEach(func(i int, v interface{}) bool { + intVal, _ := v.(int) + if intVal == index { + t.Error("remove test fail: found " + strconv.Itoa(index) + " at index: " + strconv.Itoa(i)) + } + return true + }) + } - list = Make() - for i := 0; i < 10; i++ { - list.Add(i) - list.Add(i) - } - for i := 0; i < 10; i++ { - list.RemoveByVal(i, 1) - } - list.ForEach(func(i int, v interface{}) bool { - intVal, _ := v.(int) - if intVal != i { - t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) - } - return true - }) - for i := 0; i < 10; i++ { - list.RemoveByVal(i, 1) - } - if list.Len() != 0 { - t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len())) - } + list = Make() + for i := 0; i < 10; i++ { + list.Add(i) + list.Add(i) + } + for i := 0; i < 10; i++ { + list.RemoveByVal(i, 1) + } + list.ForEach(func(i int, v interface{}) bool { + intVal, _ := v.(int) + if intVal != i { + t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) + } + return true + }) + for i := 0; i < 10; i++ { + list.RemoveByVal(i, 1) + } + if list.Len() != 0 { + t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len())) + } - list = Make() - for i := 0; i < 10; i++ { - list.Add(i) - list.Add(i) - } - for i := 0; i < 10; i++ { - list.ReverseRemoveByVal(i, 1) - } - list.ForEach(func(i int, v interface{}) bool { - intVal, _ := v.(int) - if intVal != i { - t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) - } - return true - }) - for i := 0; i < 10; i++ { - list.ReverseRemoveByVal(i, 1) - } - if list.Len() != 0 { - t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len())) - } + list = Make() + for i := 0; i < 10; i++ { + list.Add(i) + list.Add(i) + } + for i := 0; i < 10; i++ { + list.ReverseRemoveByVal(i, 1) + } + list.ForEach(func(i int, v interface{}) bool { + intVal, _ := v.(int) + if intVal != i { + t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) + } + return true + }) + for i := 0; i < 10; i++ { + list.ReverseRemoveByVal(i, 1) + } + if list.Len() != 0 { + t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len())) + } } func TestInsert(t *testing.T) { - list := Make() - for i := 0; i < 10; i++ { - list.Add(i) - } - for i := 0; i < 10; i++ { - list.Insert(i*2, i) + list := Make() + for i := 0; i < 10; i++ { + list.Add(i) + } + for i := 0; i < 10; i++ { + list.Insert(i*2, i) - list.ForEach(func(j int, v interface{}) bool { - var expected int - if j < (i + 1) * 2 { - if j%2 == 0 { - expected = j / 2 - } else { - expected = (j - 1) / 2 - } - } else { - expected = j - i - 1 - } - actual, _ := list.Get(j).(int) - if actual != expected { - t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual)) - } - return true - }) + list.ForEach(func(j int, v interface{}) bool { + var expected int + if j < (i+1)*2 { + if j%2 == 0 { + expected = j / 2 + } else { + expected = (j - 1) / 2 + } + } else { + expected = j - i - 1 + } + actual, _ := list.Get(j).(int) + if actual != expected { + t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual)) + } + return true + }) - for j := 0; j < list.Len(); j++ { - var expected int - if j < (i + 1) * 2 { - if j%2 == 0 { - expected = j / 2 - } else { - expected = (j - 1) / 2 - } - } else { - expected = j - i - 1 - } - actual, _ := list.Get(j).(int) - if actual != expected { - t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual)) - } - } + for j := 0; j < list.Len(); j++ { + var expected int + if j < (i+1)*2 { + if j%2 == 0 { + expected = j / 2 + } else { + expected = (j - 1) / 2 + } + } else { + expected = j - i - 1 + } + actual, _ := list.Get(j).(int) + if actual != expected { + t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual)) + } + } - } + } } func TestRemoveLast(t *testing.T) { - list := Make() - for i := 0; i < 10; i++ { - list.Add(i) - } - for i := 9; i >= 0; i-- { - val := list.RemoveLast() - intVal, _ := val.(int) - if intVal != i { - t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) - } - } + list := Make() + for i := 0; i < 10; i++ { + list.Add(i) + } + for i := 9; i >= 0; i-- { + val := list.RemoveLast() + intVal, _ := val.(int) + if intVal != i { + t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) + } + } } func TestRange(t *testing.T) { - list := Make() - size := 10 - for i := 0; i < size; i++ { - list.Add(i) - } - for start := 0; start < size; start++ { - for stop := start; stop < size; stop++ { - slice := list.Range(start, stop) - if len(slice) != stop - start { - t.Error("expected " + strconv.Itoa(stop - start) + ", get: " + strconv.Itoa(len(slice)) + - ", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]") - } - sliceIndex := 0 - for i := start; i < stop; i++ { - val := slice[sliceIndex] - intVal, _ := val.(int) - if intVal != i { - t.Error("expected " + strconv.Itoa(i) + ", get: " + strconv.Itoa(intVal) + - ", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]") - } - sliceIndex++ - } - } - } -} \ No newline at end of file + list := Make() + size := 10 + for i := 0; i < size; i++ { + list.Add(i) + } + for start := 0; start < size; start++ { + for stop := start; stop < size; stop++ { + slice := list.Range(start, stop) + if len(slice) != stop-start { + t.Error("expected " + strconv.Itoa(stop-start) + ", get: " + strconv.Itoa(len(slice)) + + ", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]") + } + sliceIndex := 0 + for i := start; i < stop; i++ { + val := slice[sliceIndex] + intVal, _ := val.(int) + if intVal != i { + t.Error("expected " + strconv.Itoa(i) + ", get: " + strconv.Itoa(intVal) + + ", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]") + } + sliceIndex++ + } + } + } +} diff --git a/src/datastruct/lock/lock_map.go b/src/datastruct/lock/lock_map.go index 36a4e904..8457d3f0 100644 --- a/src/datastruct/lock/lock_map.go +++ b/src/datastruct/lock/lock_map.go @@ -1,153 +1,152 @@ package lock import ( - "fmt" - "runtime" - "sort" - "strconv" - "strings" - "sync" - "testing" - "time" + "fmt" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "testing" + "time" ) const ( - prime32 = uint32(16777619) + prime32 = uint32(16777619) ) type Locks struct { - table []*sync.RWMutex + table []*sync.RWMutex } func Make(tableSize int) *Locks { - table := make([]*sync.RWMutex, tableSize) - for i := 0; i < tableSize; i++ { - table[i] = &sync.RWMutex{} - } - return &Locks{ - table: table, - } + table := make([]*sync.RWMutex, tableSize) + for i := 0; i < tableSize; i++ { + table[i] = &sync.RWMutex{} + } + return &Locks{ + table: table, + } } func fnv32(key string) uint32 { - hash := uint32(2166136261) - for i := 0; i < len(key); i++ { - hash *= prime32 - hash ^= uint32(key[i]) - } - return hash + hash := uint32(2166136261) + for i := 0; i < len(key); i++ { + hash *= prime32 + hash ^= uint32(key[i]) + } + return hash } func (locks *Locks) spread(hashCode uint32) uint32 { - if locks == nil { - panic("dict is nil") - } - tableSize := uint32(len(locks.table)) - return (tableSize - 1) & uint32(hashCode) + if locks == nil { + panic("dict is nil") + } + tableSize := uint32(len(locks.table)) + return (tableSize - 1) & uint32(hashCode) } -func (locks *Locks)Lock(key string) { - index := locks.spread(fnv32(key)) - mu := locks.table[index] - mu.Lock() +func (locks *Locks) Lock(key string) { + index := locks.spread(fnv32(key)) + mu := locks.table[index] + mu.Lock() } -func (locks *Locks)RLock(key string) { - index := locks.spread(fnv32(key)) - mu := locks.table[index] - mu.RLock() +func (locks *Locks) RLock(key string) { + index := locks.spread(fnv32(key)) + mu := locks.table[index] + mu.RLock() } -func (locks *Locks)UnLock(key string) { - index := locks.spread(fnv32(key)) - mu := locks.table[index] - mu.Unlock() +func (locks *Locks) UnLock(key string) { + index := locks.spread(fnv32(key)) + mu := locks.table[index] + mu.Unlock() } -func (locks *Locks)RUnLock(key string) { - index := locks.spread(fnv32(key)) - mu := locks.table[index] - mu.RUnlock() +func (locks *Locks) RUnLock(key string) { + index := locks.spread(fnv32(key)) + mu := locks.table[index] + mu.RUnlock() } func (locks *Locks) toLockIndices(keys []string, reverse bool) []uint32 { - indexMap := make(map[uint32]bool) - for _, key := range keys { - index := locks.spread(fnv32(key)) - indexMap[index] = true - } - indices := make([]uint32, 0, len(indexMap)) - for index := range indexMap { - indices = append(indices, index) - } - sort.Slice(indices, func(i, j int) bool { - if !reverse { - return indices[i] < indices[j] - } else { - return indices[i] > indices[j] - } - }) - return indices + indexMap := make(map[uint32]bool) + for _, key := range keys { + index := locks.spread(fnv32(key)) + indexMap[index] = true + } + indices := make([]uint32, 0, len(indexMap)) + for index := range indexMap { + indices = append(indices, index) + } + sort.Slice(indices, func(i, j int) bool { + if !reverse { + return indices[i] < indices[j] + } else { + return indices[i] > indices[j] + } + }) + return indices } -func (locks *Locks)Locks(keys ...string) { - indices := locks.toLockIndices(keys, false) - for _, index := range indices { - mu := locks.table[index] - mu.Lock() - } +func (locks *Locks) Locks(keys ...string) { + indices := locks.toLockIndices(keys, false) + for _, index := range indices { + mu := locks.table[index] + mu.Lock() + } } -func (locks *Locks)RLocks(keys ...string) { - indices := locks.toLockIndices(keys, false) - for _, index := range indices { - mu := locks.table[index] - mu.RLock() - } +func (locks *Locks) RLocks(keys ...string) { + indices := locks.toLockIndices(keys, false) + for _, index := range indices { + mu := locks.table[index] + mu.RLock() + } } - -func (locks *Locks)UnLocks(keys ...string) { - indices := locks.toLockIndices(keys, true) - for _, index := range indices { - mu := locks.table[index] - mu.Unlock() - } +func (locks *Locks) UnLocks(keys ...string) { + indices := locks.toLockIndices(keys, true) + for _, index := range indices { + mu := locks.table[index] + mu.Unlock() + } } -func (locks *Locks)RUnLocks(keys ...string) { - indices := locks.toLockIndices(keys, true) - for _, index := range indices { - mu := locks.table[index] - mu.RUnlock() - } +func (locks *Locks) RUnLocks(keys ...string) { + indices := locks.toLockIndices(keys, true) + for _, index := range indices { + mu := locks.table[index] + mu.RUnlock() + } } func GoID() int { - var buf [64]byte - n := runtime.Stack(buf[:], false) - idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0] - id, err := strconv.Atoi(idField) - if err != nil { - panic(fmt.Sprintf("cannot get goroutine id: %v", err)) - } - return id + var buf [64]byte + n := runtime.Stack(buf[:], false) + idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0] + id, err := strconv.Atoi(idField) + if err != nil { + panic(fmt.Sprintf("cannot get goroutine id: %v", err)) + } + return id } func debug(testing.T) { - lm := Locks{} - size := 10 - var wg sync.WaitGroup - wg.Add(size) - for i := 0; i < size; i++ { - go func(i int) { - lm.Locks("1", "2") - println("go: " + strconv.Itoa(GoID())) - time.Sleep(time.Second) - println("go: " + strconv.Itoa(GoID())) - lm.UnLocks("1", "2") - wg.Done() - }(i) - } - wg.Wait() -} \ No newline at end of file + lm := Locks{} + size := 10 + var wg sync.WaitGroup + wg.Add(size) + for i := 0; i < size; i++ { + go func(i int) { + lm.Locks("1", "2") + println("go: " + strconv.Itoa(GoID())) + time.Sleep(time.Second) + println("go: " + strconv.Itoa(GoID())) + lm.UnLocks("1", "2") + wg.Done() + }(i) + } + wg.Wait() +} diff --git a/src/datastruct/set/set_test.go b/src/datastruct/set/set_test.go index b8b85fb7..fd95aefc 100644 --- a/src/datastruct/set/set_test.go +++ b/src/datastruct/set/set_test.go @@ -1,32 +1,32 @@ package set import ( - "strconv" - "testing" + "strconv" + "testing" ) func TestSet(t *testing.T) { - size := 10 - set := Make() - for i := 0; i < size; i++ { - set.Add(strconv.Itoa(i)) - } - for i := 0; i < size; i++ { - ok := set.Has(strconv.Itoa(i)) - if !ok { - t.Error("expected true actual false, key: " + strconv.Itoa(i)) - } - } - for i := 0; i < size; i++ { - ok := set.Remove(strconv.Itoa(i)) - if ok != 1 { - t.Error("expected true actual false, key: " + strconv.Itoa(i)) - } - } - for i := 0; i < size; i++ { - ok := set.Has(strconv.Itoa(i)) - if ok { - t.Error("expected false actual true, key: " + strconv.Itoa(i)) - } - } -} \ No newline at end of file + size := 10 + set := Make() + for i := 0; i < size; i++ { + set.Add(strconv.Itoa(i)) + } + for i := 0; i < size; i++ { + ok := set.Has(strconv.Itoa(i)) + if !ok { + t.Error("expected true actual false, key: " + strconv.Itoa(i)) + } + } + for i := 0; i < size; i++ { + ok := set.Remove(strconv.Itoa(i)) + if ok != 1 { + t.Error("expected true actual false, key: " + strconv.Itoa(i)) + } + } + for i := 0; i < size; i++ { + ok := set.Has(strconv.Itoa(i)) + if ok { + t.Error("expected false actual true, key: " + strconv.Itoa(i)) + } + } +} diff --git a/src/datastruct/sortedset/border.go b/src/datastruct/sortedset/border.go index efa0741c..922ff375 100644 --- a/src/datastruct/sortedset/border.go +++ b/src/datastruct/sortedset/border.go @@ -1,8 +1,8 @@ package sortedset import ( - "errors" - "strconv" + "errors" + "strconv" ) /* @@ -14,78 +14,78 @@ import ( */ const ( - negativeInf int8 = -1 - positiveInf int8 = 1 + negativeInf int8 = -1 + positiveInf int8 = 1 ) type ScoreBorder struct { - Inf int8 - Value float64 - Exclude bool + Inf int8 + Value float64 + Exclude bool } // if max.greater(score) then the score is within the upper border // do not use min.greater() -func (border *ScoreBorder)greater(value float64)bool { - if border.Inf == negativeInf { - return false - } else if border.Inf == positiveInf { - return true - } - if border.Exclude { - return border.Value > value - } else { - return border.Value >= value - } +func (border *ScoreBorder) greater(value float64) bool { + if border.Inf == negativeInf { + return false + } else if border.Inf == positiveInf { + return true + } + if border.Exclude { + return border.Value > value + } else { + return border.Value >= value + } } -func (border *ScoreBorder)less(value float64)bool { - if border.Inf == negativeInf { - return true - } else if border.Inf == positiveInf { - return false - } - if border.Exclude { - return border.Value < value - } else { - return border.Value <= value - } +func (border *ScoreBorder) less(value float64) bool { + if border.Inf == negativeInf { + return true + } else if border.Inf == positiveInf { + return false + } + if border.Exclude { + return border.Value < value + } else { + return border.Value <= value + } } -var positiveInfBorder = &ScoreBorder { - Inf: positiveInf, +var positiveInfBorder = &ScoreBorder{ + Inf: positiveInf, } -var negativeInfBorder = &ScoreBorder { - Inf: negativeInf, +var negativeInfBorder = &ScoreBorder{ + Inf: negativeInf, } -func ParseScoreBorder(s string)(*ScoreBorder, error) { - if s == "inf" || s == "+inf" { - return positiveInfBorder, nil - } - if s == "-inf" { - return negativeInfBorder, nil - } - if s[0] == '(' { - value, err := strconv.ParseFloat(s[1:], 64) - if err != nil { - return nil, errors.New("ERR min or max is not a float") - } - return &ScoreBorder{ - Inf: 0, - Value: value, - Exclude: true, - }, nil - } else { - value, err := strconv.ParseFloat(s, 64) - if err != nil { - return nil, errors.New("ERR min or max is not a float") - } - return &ScoreBorder{ - Inf: 0, - Value: value, - Exclude: false, - }, nil - } -} \ No newline at end of file +func ParseScoreBorder(s string) (*ScoreBorder, error) { + if s == "inf" || s == "+inf" { + return positiveInfBorder, nil + } + if s == "-inf" { + return negativeInfBorder, nil + } + if s[0] == '(' { + value, err := strconv.ParseFloat(s[1:], 64) + if err != nil { + return nil, errors.New("ERR min or max is not a float") + } + return &ScoreBorder{ + Inf: 0, + Value: value, + Exclude: true, + }, nil + } else { + value, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil, errors.New("ERR min or max is not a float") + } + return &ScoreBorder{ + Inf: 0, + Value: value, + Exclude: false, + }, nil + } +} diff --git a/src/datastruct/sortedset/skiplist.go b/src/datastruct/sortedset/skiplist.go index 0304c51f..6f04bcb2 100644 --- a/src/datastruct/sortedset/skiplist.go +++ b/src/datastruct/sortedset/skiplist.go @@ -3,130 +3,129 @@ package sortedset import "math/rand" const ( - maxLevel = 16 + maxLevel = 16 ) - type Element struct { - Member string - Score float64 + Member string + Score float64 } // level aspect of a Node type Level struct { - forward *Node // forward node has greater score - span int64 + forward *Node // forward node has greater score + span int64 } type Node struct { - Element - backward *Node - level []*Level // level[0] is base level + Element + backward *Node + level []*Level // level[0] is base level } type skiplist struct { - header *Node - tail *Node - length int64 - level int16 + header *Node + tail *Node + length int64 + level int16 } -func makeNode(level int16, score float64, member string)*Node { - n := &Node{ - Element: Element{ - Score: score, - Member: member, - }, - level: make([]*Level, level), - } - for i := range n.level { - n.level[i] = new(Level) - } - return n +func makeNode(level int16, score float64, member string) *Node { + n := &Node{ + Element: Element{ + Score: score, + Member: member, + }, + level: make([]*Level, level), + } + for i := range n.level { + n.level[i] = new(Level) + } + return n } -func makeSkiplist()*skiplist { - return &skiplist{ - level: 1, - header: makeNode(maxLevel, 0, ""), - } +func makeSkiplist() *skiplist { + return &skiplist{ + level: 1, + header: makeNode(maxLevel, 0, ""), + } } func randomLevel() int16 { - level := int16(1) - for float32(rand.Int31()&0xFFFF) < (0.25 * 0xFFFF) { - level++ - } - if level < maxLevel { - return level - } - return maxLevel + level := int16(1) + for float32(rand.Int31()&0xFFFF) < (0.25 * 0xFFFF) { + level++ + } + if level < maxLevel { + return level + } + return maxLevel } -func (skiplist *skiplist)insert(member string, score float64)*Node { - update := make([]*Node, maxLevel) // link new node with node in `update` - rank := make([]int64, maxLevel) - - // find position to insert - node := skiplist.header - for i := skiplist.level - 1; i >= 0; i-- { - if i == skiplist.level - 1 { - rank[i] = 0 - } else { - rank[i] = rank[i + 1] // store rank that is crossed to reach the insert position - } - if node.level[i] != nil { - // traverse the skip list - for node.level[i].forward != nil && - (node.level[i].forward.Score < score || - (node.level[i].forward.Score == score && node.level[i].forward.Member < member)) { // same score, different key - rank[i] += node.level[i].span - node = node.level[i].forward - } - } - update[i] = node - } - - level := randomLevel() - // extend skiplist level - if level > skiplist.level { - for i := skiplist.level; i < level; i++ { - rank[i] = 0 - update[i] = skiplist.header - update[i].level[i].span = skiplist.length - } - skiplist.level = level - } - - // make node and link into skiplist - node = makeNode(level, score, member) - for i := int16(0); i < level; i++ { - node.level[i].forward = update[i].level[i].forward - update[i].level[i].forward = node - - // update span covered by update[i] as node is inserted here - node.level[i].span = update[i].level[i].span - (rank[0] - rank[i]) - update[i].level[i].span = (rank[0] - rank[i]) + 1 - } - - // increment span for untouched levels - for i := level; i < skiplist.level; i++ { - update[i].level[i].span++ - } - - // set backward node - if update[0] == skiplist.header { - node.backward = nil - } else { - node.backward = update[0] - } - if node.level[0].forward != nil { - node.level[0].forward.backward = node - } else { - skiplist.tail = node - } - skiplist.length++ - return node +func (skiplist *skiplist) insert(member string, score float64) *Node { + update := make([]*Node, maxLevel) // link new node with node in `update` + rank := make([]int64, maxLevel) + + // find position to insert + node := skiplist.header + for i := skiplist.level - 1; i >= 0; i-- { + if i == skiplist.level-1 { + rank[i] = 0 + } else { + rank[i] = rank[i+1] // store rank that is crossed to reach the insert position + } + if node.level[i] != nil { + // traverse the skip list + for node.level[i].forward != nil && + (node.level[i].forward.Score < score || + (node.level[i].forward.Score == score && node.level[i].forward.Member < member)) { // same score, different key + rank[i] += node.level[i].span + node = node.level[i].forward + } + } + update[i] = node + } + + level := randomLevel() + // extend skiplist level + if level > skiplist.level { + for i := skiplist.level; i < level; i++ { + rank[i] = 0 + update[i] = skiplist.header + update[i].level[i].span = skiplist.length + } + skiplist.level = level + } + + // make node and link into skiplist + node = makeNode(level, score, member) + for i := int16(0); i < level; i++ { + node.level[i].forward = update[i].level[i].forward + update[i].level[i].forward = node + + // update span covered by update[i] as node is inserted here + node.level[i].span = update[i].level[i].span - (rank[0] - rank[i]) + update[i].level[i].span = (rank[0] - rank[i]) + 1 + } + + // increment span for untouched levels + for i := level; i < skiplist.level; i++ { + update[i].level[i].span++ + } + + // set backward node + if update[0] == skiplist.header { + node.backward = nil + } else { + node.backward = update[0] + } + if node.level[0].forward != nil { + node.level[0].forward.backward = node + } else { + skiplist.tail = node + } + skiplist.length++ + return node } /* @@ -134,212 +133,212 @@ func (skiplist *skiplist)insert(member string, score float64)*Node { * param update: backward node (of target) */ func (skiplist *skiplist) removeNode(node *Node, update []*Node) { - for i := int16(0); i < skiplist.level; i++ { - if update[i].level[i].forward == node { - update[i].level[i].span += node.level[i].span - 1 - update[i].level[i].forward = node.level[i].forward - } else { - update[i].level[i].span-- - } - } - if node.level[0].forward != nil { - node.level[0].forward.backward = node.backward - } else { - skiplist.tail = node.backward - } - for skiplist.level > 1 && skiplist.header.level[skiplist.level-1].forward == nil { - skiplist.level-- - } - skiplist.length-- + for i := int16(0); i < skiplist.level; i++ { + if update[i].level[i].forward == node { + update[i].level[i].span += node.level[i].span - 1 + update[i].level[i].forward = node.level[i].forward + } else { + update[i].level[i].span-- + } + } + if node.level[0].forward != nil { + node.level[0].forward.backward = node.backward + } else { + skiplist.tail = node.backward + } + for skiplist.level > 1 && skiplist.header.level[skiplist.level-1].forward == nil { + skiplist.level-- + } + skiplist.length-- } /* * return: has found and removed node */ -func (skiplist *skiplist) remove(member string, score float64)bool { - /* - * find backward node (of target) or last node of each level - * their forward need to be updated - */ - update := make([]*Node, maxLevel) - node := skiplist.header - for i := skiplist.level - 1; i >= 0; i-- { - for node.level[i].forward != nil && - (node.level[i].forward.Score < score || - (node.level[i].forward.Score == score && - node.level[i].forward.Member < member)) { - node = node.level[i].forward - } - update[i] = node - } - node = node.level[0].forward - if node != nil && score == node.Score && node.Member == member { - skiplist.removeNode(node, update) - // free x - return true - } - return false +func (skiplist *skiplist) remove(member string, score float64) bool { + /* + * find backward node (of target) or last node of each level + * their forward need to be updated + */ + update := make([]*Node, maxLevel) + node := skiplist.header + for i := skiplist.level - 1; i >= 0; i-- { + for node.level[i].forward != nil && + (node.level[i].forward.Score < score || + (node.level[i].forward.Score == score && + node.level[i].forward.Member < member)) { + node = node.level[i].forward + } + update[i] = node + } + node = node.level[0].forward + if node != nil && score == node.Score && node.Member == member { + skiplist.removeNode(node, update) + // free x + return true + } + return false } /* * return: 1 based rank, 0 means member not found */ -func (skiplist *skiplist) getRank(member string, score float64)int64 { - var rank int64 = 0 - x := skiplist.header - for i := skiplist.level - 1; i >= 0; i-- { - for x.level[i].forward != nil && - (x.level[i].forward.Score < score || - (x.level[i].forward.Score == score && - x.level[i].forward.Member <= member)) { - rank += x.level[i].span - x = x.level[i].forward - } - - /* x might be equal to zsl->header, so test if obj is non-NULL */ - if x.Member == member { - return rank - } - } - return 0 +func (skiplist *skiplist) getRank(member string, score float64) int64 { + var rank int64 = 0 + x := skiplist.header + for i := skiplist.level - 1; i >= 0; i-- { + for x.level[i].forward != nil && + (x.level[i].forward.Score < score || + (x.level[i].forward.Score == score && + x.level[i].forward.Member <= member)) { + rank += x.level[i].span + x = x.level[i].forward + } + + /* x might be equal to zsl->header, so test if obj is non-NULL */ + if x.Member == member { + return rank + } + } + return 0 } /* * 1-based rank */ -func (skiplist *skiplist) getByRank(rank int64)*Node { - var i int64 = 0 - n := skiplist.header - // scan from top level - for level := skiplist.level - 1; level >= 0; level-- { - for n.level[level].forward != nil && (i+n.level[level].span) <= rank { - i += n.level[level].span - n = n.level[level].forward - } - if i == rank { - return n - } - } - return nil +func (skiplist *skiplist) getByRank(rank int64) *Node { + var i int64 = 0 + n := skiplist.header + // scan from top level + for level := skiplist.level - 1; level >= 0; level-- { + for n.level[level].forward != nil && (i+n.level[level].span) <= rank { + i += n.level[level].span + n = n.level[level].forward + } + if i == rank { + return n + } + } + return nil } func (skiplist *skiplist) hasInRange(min *ScoreBorder, max *ScoreBorder) bool { - // min & max = empty - if min.Value > max.Value || (min.Value == max.Value && (min.Exclude || max.Exclude)) { - return false - } - // min > tail - n := skiplist.tail - if n == nil || !min.less(n.Score) { - return false - } - // max < head - n = skiplist.header.level[0].forward - if n == nil || !max.greater(n.Score) { - return false - } - return true + // min & max = empty + if min.Value > max.Value || (min.Value == max.Value && (min.Exclude || max.Exclude)) { + return false + } + // min > tail + n := skiplist.tail + if n == nil || !min.less(n.Score) { + return false + } + // max < head + n = skiplist.header.level[0].forward + if n == nil || !max.greater(n.Score) { + return false + } + return true } func (skiplist *skiplist) getFirstInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node { - if !skiplist.hasInRange(min, max) { - return nil - } - n := skiplist.header - // scan from top level - for level := skiplist.level - 1; level >= 0; level-- { - // if forward is not in range than move forward - for n.level[level].forward != nil && !min.less(n.level[level].forward.Score) { - n = n.level[level].forward - } - } - /* This is an inner range, so the next node cannot be NULL. */ - n = n.level[0].forward - if !max.greater(n.Score) { - return nil - } - return n + if !skiplist.hasInRange(min, max) { + return nil + } + n := skiplist.header + // scan from top level + for level := skiplist.level - 1; level >= 0; level-- { + // if forward is not in range than move forward + for n.level[level].forward != nil && !min.less(n.level[level].forward.Score) { + n = n.level[level].forward + } + } + /* This is an inner range, so the next node cannot be NULL. */ + n = n.level[0].forward + if !max.greater(n.Score) { + return nil + } + return n } func (skiplist *skiplist) getLastInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node { - if !skiplist.hasInRange(min, max) { - return nil - } - n := skiplist.header - // scan from top level - for level := skiplist.level - 1; level >= 0; level-- { - for n.level[level].forward != nil && max.greater(n.level[level].forward.Score) { - n = n.level[level].forward - } - } - if !min.less(n.Score) { - return nil - } - return n + if !skiplist.hasInRange(min, max) { + return nil + } + n := skiplist.header + // scan from top level + for level := skiplist.level - 1; level >= 0; level-- { + for n.level[level].forward != nil && max.greater(n.level[level].forward.Score) { + n = n.level[level].forward + } + } + if !min.less(n.Score) { + return nil + } + return n } /* * return removed elements */ -func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder)(removed []*Element) { - update := make([]*Node, maxLevel) - removed = make([]*Element, 0) - // find backward nodes (of target range) or last node of each level - node := skiplist.header - for i := skiplist.level - 1; i >= 0; i-- { - for node.level[i].forward != nil { - if min.less(node.level[i].forward.Score) { // already in range - break - } - node = node.level[i].forward - } - update[i] = node - } - - // node is the first one within range - node = node.level[0].forward - - // remove nodes in range - for node != nil { - if !max.greater(node.Score) { // already out of range - break - } - next := node.level[0].forward - removedElement := node.Element - removed = append(removed, &removedElement) - skiplist.removeNode(node, update) - node = next - } - return removed +func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder) (removed []*Element) { + update := make([]*Node, maxLevel) + removed = make([]*Element, 0) + // find backward nodes (of target range) or last node of each level + node := skiplist.header + for i := skiplist.level - 1; i >= 0; i-- { + for node.level[i].forward != nil { + if min.less(node.level[i].forward.Score) { // already in range + break + } + node = node.level[i].forward + } + update[i] = node + } + + // node is the first one within range + node = node.level[0].forward + + // remove nodes in range + for node != nil { + if !max.greater(node.Score) { // already out of range + break + } + next := node.level[0].forward + removedElement := node.Element + removed = append(removed, &removedElement) + skiplist.removeNode(node, update) + node = next + } + return removed } // 1-based rank, including start, exclude stop -func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64)(removed []*Element) { - var i int64 = 0 // rank of iterator - update := make([]*Node, maxLevel) - removed = make([]*Element, 0) - - // scan from top level - node := skiplist.header - for level := skiplist.level - 1; level >= 0; level-- { - for node.level[level].forward != nil && (i+node.level[level].span) < start { - i += node.level[level].span - node = node.level[level].forward - } - update[level] = node - } - - i++ - node = node.level[0].forward // first node in range - - // remove nodes in range - for node != nil && i < stop { - next := node.level[0].forward - removedElement := node.Element - removed = append(removed, &removedElement) - skiplist.removeNode(node, update) - node = next - i++ - } - return removed +func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64) (removed []*Element) { + var i int64 = 0 // rank of iterator + update := make([]*Node, maxLevel) + removed = make([]*Element, 0) + + // scan from top level + node := skiplist.header + for level := skiplist.level - 1; level >= 0; level-- { + for node.level[level].forward != nil && (i+node.level[level].span) < start { + i += node.level[level].span + node = node.level[level].forward + } + update[level] = node + } + + i++ + node = node.level[0].forward // first node in range + + // remove nodes in range + for node != nil && i < stop { + next := node.level[0].forward + removedElement := node.Element + removed = append(removed, &removedElement) + skiplist.removeNode(node, update) + node = next + i++ + } + return removed } diff --git a/src/datastruct/sortedset/sortedset.go b/src/datastruct/sortedset/sortedset.go index 50e4bf73..e9d51beb 100644 --- a/src/datastruct/sortedset/sortedset.go +++ b/src/datastruct/sortedset/sortedset.go @@ -1,227 +1,226 @@ package sortedset import ( - "strconv" + "strconv" ) type SortedSet struct { - dict map[string]*Element - skiplist *skiplist + dict map[string]*Element + skiplist *skiplist } -func Make()*SortedSet { - return &SortedSet{ - dict: make(map[string]*Element), - skiplist: makeSkiplist(), - } +func Make() *SortedSet { + return &SortedSet{ + dict: make(map[string]*Element), + skiplist: makeSkiplist(), + } } /* * return: has inserted new node */ -func (sortedSet *SortedSet)Add(member string, score float64)bool { - element, ok := sortedSet.dict[member] - sortedSet.dict[member] = &Element{ - Member: member, - Score: score, - } - if ok { - if score != element.Score { - sortedSet.skiplist.remove(member, score) - sortedSet.skiplist.insert(member, score) - } - return false - } else { - sortedSet.skiplist.insert(member, score) - return true - } +func (sortedSet *SortedSet) Add(member string, score float64) bool { + element, ok := sortedSet.dict[member] + sortedSet.dict[member] = &Element{ + Member: member, + Score: score, + } + if ok { + if score != element.Score { + sortedSet.skiplist.remove(member, score) + sortedSet.skiplist.insert(member, score) + } + return false + } else { + sortedSet.skiplist.insert(member, score) + return true + } } -func (sortedSet *SortedSet) Len()int64 { - return int64(len(sortedSet.dict)) +func (sortedSet *SortedSet) Len() int64 { + return int64(len(sortedSet.dict)) } func (sortedSet *SortedSet) Get(member string) (element *Element, ok bool) { - element, ok = sortedSet.dict[member] - if !ok { - return nil, false - } - return element, true + element, ok = sortedSet.dict[member] + if !ok { + return nil, false + } + return element, true } -func (sortedSet *SortedSet) Remove(member string)bool { - v, ok := sortedSet.dict[member] - if ok { - sortedSet.skiplist.remove(member, v.Score) - delete(sortedSet.dict, member) - return true - } - return false +func (sortedSet *SortedSet) Remove(member string) bool { + v, ok := sortedSet.dict[member] + if ok { + sortedSet.skiplist.remove(member, v.Score) + delete(sortedSet.dict, member) + return true + } + return false } /** * get 0-based rank */ func (sortedSet *SortedSet) GetRank(member string, desc bool) (rank int64) { - element, ok := sortedSet.dict[member] - if !ok { - return -1 - } - r := sortedSet.skiplist.getRank(member, element.Score) - if desc { - r = sortedSet.skiplist.length - r - } else { - r-- - } - return r + element, ok := sortedSet.dict[member] + if !ok { + return -1 + } + r := sortedSet.skiplist.getRank(member, element.Score) + if desc { + r = sortedSet.skiplist.length - r + } else { + r-- + } + return r } /** * traverse [start, stop), 0-based rank */ -func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element)bool) { - size := int64(sortedSet.Len()) - if start < 0 || start >= size { - panic("illegal start " + strconv.FormatInt(start, 10)) - } - if stop < start || stop > size { - panic("illegal end " + strconv.FormatInt(stop, 10)) - } - - // find start node - var node *Node - if desc { - node = sortedSet.skiplist.tail - if start > 0 { - node = sortedSet.skiplist.getByRank(int64(size - start)) - } - } else { - node = sortedSet.skiplist.header.level[0].forward - if start > 0 { - node = sortedSet.skiplist.getByRank(int64(start + 1)) - } - } - - sliceSize := int(stop - start) - for i := 0; i < sliceSize; i++ { - if !consumer(&node.Element) { - break - } - if desc { - node = node.backward - } else { - node = node.level[0].forward - } - } +func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element) bool) { + size := int64(sortedSet.Len()) + if start < 0 || start >= size { + panic("illegal start " + strconv.FormatInt(start, 10)) + } + if stop < start || stop > size { + panic("illegal end " + strconv.FormatInt(stop, 10)) + } + + // find start node + var node *Node + if desc { + node = sortedSet.skiplist.tail + if start > 0 { + node = sortedSet.skiplist.getByRank(int64(size - start)) + } + } else { + node = sortedSet.skiplist.header.level[0].forward + if start > 0 { + node = sortedSet.skiplist.getByRank(int64(start + 1)) + } + } + + sliceSize := int(stop - start) + for i := 0; i < sliceSize; i++ { + if !consumer(&node.Element) { + break + } + if desc { + node = node.backward + } else { + node = node.level[0].forward + } + } } /** * return [start, stop), 0-based rank * assert start in [0, size), stop in [start, size] */ -func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool)[]*Element { - sliceSize := int(stop - start) - slice := make([]*Element, sliceSize) - i := 0 - sortedSet.ForEach(start, stop, desc, func(element *Element)bool { - slice[i] = element - i++ - return true - }) - return slice +func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool) []*Element { + sliceSize := int(stop - start) + slice := make([]*Element, sliceSize) + i := 0 + sortedSet.ForEach(start, stop, desc, func(element *Element) bool { + slice[i] = element + i++ + return true + }) + return slice } -func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder)int64 { - var i int64 = 0 - // ascending order - sortedSet.ForEach(0, sortedSet.Len(), false, func(element *Element) bool { - gtMin := min.less(element.Score) // greater than min - if !gtMin { - // has not into range, continue foreach - return true - } - ltMax := max.greater(element.Score) // less than max - if !ltMax { - // break through score border, break foreach - return false - } - // gtMin && ltMax - i++ - return true - }) - return i +func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder) int64 { + var i int64 = 0 + // ascending order + sortedSet.ForEach(0, sortedSet.Len(), false, func(element *Element) bool { + gtMin := min.less(element.Score) // greater than min + if !gtMin { + // has not into range, continue foreach + return true + } + ltMax := max.greater(element.Score) // less than max + if !ltMax { + // break through score border, break foreach + return false + } + // gtMin && ltMax + i++ + return true + }) + return i } func (sortedSet *SortedSet) ForEachByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool, consumer func(element *Element) bool) { - // find start node - var node *Node - if desc { - node = sortedSet.skiplist.getLastInScoreRange(min, max) - } else { - node = sortedSet.skiplist.getFirstInScoreRange(min, max) - } - - for node != nil && offset > 0 { - if desc { - node = node.backward - } else { - node = node.level[0].forward - } - offset-- - } - - // A negative limit returns all elements from the offset - for i := 0; (i < int(limit) || limit < 0) && node != nil; i++ { - if !consumer(&node.Element) { - break - } - if desc { - node = node.backward - } else { - node = node.level[0].forward - } - if node == nil { - break - } - gtMin := min.less(node.Element.Score) // greater than min - ltMax := max.greater(node.Element.Score) - if !gtMin || !ltMax { - break // break through score border - } - } + // find start node + var node *Node + if desc { + node = sortedSet.skiplist.getLastInScoreRange(min, max) + } else { + node = sortedSet.skiplist.getFirstInScoreRange(min, max) + } + + for node != nil && offset > 0 { + if desc { + node = node.backward + } else { + node = node.level[0].forward + } + offset-- + } + + // A negative limit returns all elements from the offset + for i := 0; (i < int(limit) || limit < 0) && node != nil; i++ { + if !consumer(&node.Element) { + break + } + if desc { + node = node.backward + } else { + node = node.level[0].forward + } + if node == nil { + break + } + gtMin := min.less(node.Element.Score) // greater than min + ltMax := max.greater(node.Element.Score) + if !gtMin || !ltMax { + break // break through score border + } + } } /* * param limit: <0 means no limit */ -func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool)[]*Element { - if limit == 0 || offset < 0{ - return make([]*Element, 0) - } - slice := make([]*Element, 0) - sortedSet.ForEachByScore(min, max, offset, limit, desc, func(element *Element) bool { - slice = append(slice, element) - return true - }) - return slice +func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool) []*Element { + if limit == 0 || offset < 0 { + return make([]*Element, 0) + } + slice := make([]*Element, 0) + sortedSet.ForEachByScore(min, max, offset, limit, desc, func(element *Element) bool { + slice = append(slice, element) + return true + }) + return slice } -func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder)int64 { - removed := sortedSet.skiplist.RemoveRangeByScore(min, max) - for _, element := range removed { - delete(sortedSet.dict, element.Member) - } - return int64(len(removed)) +func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder) int64 { + removed := sortedSet.skiplist.RemoveRangeByScore(min, max) + for _, element := range removed { + delete(sortedSet.dict, element.Member) + } + return int64(len(removed)) } - /* * 0-based rank, [start, stop) */ -func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64)int64 { - removed := sortedSet.skiplist.RemoveRangeByRank(start + 1, stop + 1) - for _, element := range removed { - delete(sortedSet.dict, element.Member) - } - return int64(len(removed)) -} \ No newline at end of file +func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64) int64 { + removed := sortedSet.skiplist.RemoveRangeByRank(start+1, stop+1) + for _, element := range removed { + delete(sortedSet.dict, element.Member) + } + return int64(len(removed)) +} diff --git a/src/datastruct/utils/utils.go b/src/datastruct/utils/utils.go index 5c6b5e50..68e520cd 100644 --- a/src/datastruct/utils/utils.go +++ b/src/datastruct/utils/utils.go @@ -1,28 +1,28 @@ package utils -func Equals(a interface{}, b interface{})bool { - sliceA, okA := a.([]byte) - sliceB, okB := b.([]byte) - if okA && okB { - return BytesEquals(sliceA, sliceB) - } - return a == b +func Equals(a interface{}, b interface{}) bool { + sliceA, okA := a.([]byte) + sliceB, okB := b.([]byte) + if okA && okB { + return BytesEquals(sliceA, sliceB) + } + return a == b } func BytesEquals(a []byte, b []byte) bool { - if (a == nil && b != nil) || (a != nil && b == nil) { - return false - } - if len(a) != len(b) { - return false - } - size := len(a) - for i := 0; i < size; i++ { - av := a[i] - bv := b[i] - if av != bv { - return false - } - } - return true + if (a == nil && b != nil) || (a != nil && b == nil) { + return false + } + if len(a) != len(b) { + return false + } + size := len(a) + for i := 0; i < size; i++ { + av := a[i] + bv := b[i] + if av != bv { + return false + } + } + return true } diff --git a/src/db/aof.go b/src/db/aof.go index 81d0e393..bfb1ce4c 100644 --- a/src/db/aof.go +++ b/src/db/aof.go @@ -2,7 +2,7 @@ package db import ( "bufio" - "github.com/HDT3213/godis/src/config" + "github.com/HDT3213/godis/src/config" "github.com/HDT3213/godis/src/datastruct/dict" List "github.com/HDT3213/godis/src/datastruct/list" "github.com/HDT3213/godis/src/datastruct/lock" @@ -29,18 +29,18 @@ func makeExpireCmd(key string, expireAt time.Time) *reply.MultiBulkReply { } func makeAofCmd(cmd string, args [][]byte) *reply.MultiBulkReply { - params := make([][]byte, len(args)+1) - copy(params[1:], args) - params[0] = []byte(cmd) - return reply.MakeMultiBulkReply(params) + params := make([][]byte, len(args)+1) + copy(params[1:], args) + params[0] = []byte(cmd) + return reply.MakeMultiBulkReply(params) } // send command to aof func (db *DB) AddAof(args *reply.MultiBulkReply) { - // aofChan == nil when loadAof - if config.Properties.AppendOnly && db.aofChan != nil { - db.aofChan <- args - } + // aofChan == nil when loadAof + if config.Properties.AppendOnly && db.aofChan != nil { + db.aofChan <- args + } } // listen aof channel and write into file @@ -72,12 +72,12 @@ func trim(msg []byte) string { // read aof file func (db *DB) loadAof(maxBytes int) { - // delete aofChan to prevent write again - aofChan := db.aofChan - db.aofChan = nil - defer func(aofChan chan *reply.MultiBulkReply) { - db.aofChan = aofChan - }(aofChan) + // delete aofChan to prevent write again + aofChan := db.aofChan + db.aofChan = nil + defer func(aofChan chan *reply.MultiBulkReply) { + db.aofChan = aofChan + }(aofChan) file, err := os.Open(db.aofFilename) if err != nil { diff --git a/src/db/db.go b/src/db/db.go index b84a230c..3de2f047 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -1,297 +1,297 @@ package db import ( - "fmt" - "github.com/HDT3213/godis/src/config" - "github.com/HDT3213/godis/src/datastruct/dict" - List "github.com/HDT3213/godis/src/datastruct/list" - "github.com/HDT3213/godis/src/datastruct/lock" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/lib/logger" - "github.com/HDT3213/godis/src/pubsub" - "github.com/HDT3213/godis/src/redis/reply" - "os" - "runtime/debug" - "strings" - "sync" - "time" + "fmt" + "github.com/HDT3213/godis/src/config" + "github.com/HDT3213/godis/src/datastruct/dict" + List "github.com/HDT3213/godis/src/datastruct/list" + "github.com/HDT3213/godis/src/datastruct/lock" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/lib/logger" + "github.com/HDT3213/godis/src/pubsub" + "github.com/HDT3213/godis/src/redis/reply" + "os" + "runtime/debug" + "strings" + "sync" + "time" ) type DataEntity struct { - Data interface{} + Data interface{} } const ( - dataDictSize = 1 << 16 - ttlDictSize = 1 << 10 - lockerSize = 128 - aofQueueSize = 1 << 16 + dataDictSize = 1 << 16 + ttlDictSize = 1 << 10 + lockerSize = 128 + aofQueueSize = 1 << 16 ) // args don't include cmd line type CmdFunc func(db *DB, args [][]byte) redis.Reply type DB struct { - // key -> DataEntity - Data dict.Dict - // key -> expireTime (time.Time) - TTLMap dict.Dict - // channel -> list<*client> - SubMap dict.Dict + // key -> DataEntity + Data dict.Dict + // key -> expireTime (time.Time) + TTLMap dict.Dict + // channel -> list<*client> + SubMap dict.Dict - // dict will ensure thread safety of its method - // use this mutex for complicated command only, eg. rpush, incr ... - Locker *lock.Locks + // dict will ensure thread safety of its method + // use this mutex for complicated command only, eg. rpush, incr ... + Locker *lock.Locks - // TimerTask interval - interval time.Duration + // TimerTask interval + interval time.Duration - stopWorld sync.WaitGroup + stopWorld sync.WaitGroup - hub *pubsub.Hub + hub *pubsub.Hub - // main goroutine send commands to aof goroutine through aofChan - aofChan chan *reply.MultiBulkReply - aofFile *os.File - aofFilename string + // main goroutine send commands to aof goroutine through aofChan + aofChan chan *reply.MultiBulkReply + aofFile *os.File + aofFilename string - aofRewriteChan chan *reply.MultiBulkReply - pausingAof sync.RWMutex + aofRewriteChan chan *reply.MultiBulkReply + pausingAof sync.RWMutex } var router = MakeRouter() func MakeDB() *DB { - db := &DB{ - Data: dict.MakeConcurrent(dataDictSize), - TTLMap: dict.MakeConcurrent(ttlDictSize), - Locker: lock.Make(lockerSize), - interval: 5 * time.Second, - hub: pubsub.MakeHub(), - } - - // aof - if config.Properties.AppendOnly { - db.aofFilename = config.Properties.AppendFilename - db.loadAof(0) - aofFile, err := os.OpenFile(db.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600) - if err != nil { - logger.Warn(err) - } else { - db.aofFile = aofFile - db.aofChan = make(chan *reply.MultiBulkReply, aofQueueSize) - } - go func() { - db.handleAof() - }() - } - - // start timer - db.TimerTask() - return db + db := &DB{ + Data: dict.MakeConcurrent(dataDictSize), + TTLMap: dict.MakeConcurrent(ttlDictSize), + Locker: lock.Make(lockerSize), + interval: 5 * time.Second, + hub: pubsub.MakeHub(), + } + + // aof + if config.Properties.AppendOnly { + db.aofFilename = config.Properties.AppendFilename + db.loadAof(0) + aofFile, err := os.OpenFile(db.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + logger.Warn(err) + } else { + db.aofFile = aofFile + db.aofChan = make(chan *reply.MultiBulkReply, aofQueueSize) + } + go func() { + db.handleAof() + }() + } + + // start timer + db.TimerTask() + return db } func (db *DB) Close() { - if db.aofFile != nil { - err := db.aofFile.Close() - if err != nil { - logger.Warn(err) - } - } + if db.aofFile != nil { + err := db.aofFile.Close() + if err != nil { + logger.Warn(err) + } + } } func (db *DB) Exec(c redis.Connection, args [][]byte) (result redis.Reply) { - defer func() { - if err := recover(); err != nil { - logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack()))) - result = &reply.UnknownErrReply{} - } - }() - - cmd := strings.ToLower(string(args[0])) - - // special commands - if cmd == "subscribe" { - if len(args) < 2 { - return &reply.ArgNumErrReply{Cmd: "subscribe"} - } - return pubsub.Subscribe(db.hub, c, args[1:]) - } else if cmd == "publish" { - return pubsub.Publish(db.hub, args[1:]) - } else if cmd == "unsubscribe" { - return pubsub.UnSubscribe(db.hub, c, args[1:]) - } else if cmd == "bgrewriteaof" { - // aof.go imports router.go, router.go cannot import BGRewriteAOF from aof.go - reply := BGRewriteAOF(db, args[1:]) - return reply - } - - // normal commands - cmdFunc, ok := router[cmd] - if !ok { - return reply.MakeErrReply("ERR unknown command '" + cmd + "'") - } - if len(args) > 1 { - result = cmdFunc(db, args[1:]) - } else { - result = cmdFunc(db, [][]byte{}) - } - - // aof - - return + defer func() { + if err := recover(); err != nil { + logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack()))) + result = &reply.UnknownErrReply{} + } + }() + + cmd := strings.ToLower(string(args[0])) + + // special commands + if cmd == "subscribe" { + if len(args) < 2 { + return &reply.ArgNumErrReply{Cmd: "subscribe"} + } + return pubsub.Subscribe(db.hub, c, args[1:]) + } else if cmd == "publish" { + return pubsub.Publish(db.hub, args[1:]) + } else if cmd == "unsubscribe" { + return pubsub.UnSubscribe(db.hub, c, args[1:]) + } else if cmd == "bgrewriteaof" { + // aof.go imports router.go, router.go cannot import BGRewriteAOF from aof.go + reply := BGRewriteAOF(db, args[1:]) + return reply + } + + // normal commands + cmdFunc, ok := router[cmd] + if !ok { + return reply.MakeErrReply("ERR unknown command '" + cmd + "'") + } + if len(args) > 1 { + result = cmdFunc(db, args[1:]) + } else { + result = cmdFunc(db, [][]byte{}) + } + + // aof + + return } /* ---- Data Access ----- */ func (db *DB) Get(key string) (*DataEntity, bool) { - db.stopWorld.Wait() - - raw, ok := db.Data.Get(key) - if !ok { - return nil, false - } - if db.IsExpired(key) { - return nil, false - } - entity, _ := raw.(*DataEntity) - return entity, true + db.stopWorld.Wait() + + raw, ok := db.Data.Get(key) + if !ok { + return nil, false + } + if db.IsExpired(key) { + return nil, false + } + entity, _ := raw.(*DataEntity) + return entity, true } func (db *DB) Put(key string, entity *DataEntity) int { - db.stopWorld.Wait() - return db.Data.Put(key, entity) + db.stopWorld.Wait() + return db.Data.Put(key, entity) } func (db *DB) PutIfExists(key string, entity *DataEntity) int { - db.stopWorld.Wait() - return db.Data.PutIfExists(key, entity) + db.stopWorld.Wait() + return db.Data.PutIfExists(key, entity) } func (db *DB) PutIfAbsent(key string, entity *DataEntity) int { - db.stopWorld.Wait() - return db.Data.PutIfAbsent(key, entity) + db.stopWorld.Wait() + return db.Data.PutIfAbsent(key, entity) } func (db *DB) Remove(key string) { - db.stopWorld.Wait() - db.Data.Remove(key) - db.TTLMap.Remove(key) + db.stopWorld.Wait() + db.Data.Remove(key) + db.TTLMap.Remove(key) } func (db *DB) Removes(keys ...string) (deleted int) { - db.stopWorld.Wait() - deleted = 0 - for _, key := range keys { - _, exists := db.Data.Get(key) - if exists { - db.Data.Remove(key) - db.TTLMap.Remove(key) - deleted++ - } - } - return deleted + db.stopWorld.Wait() + deleted = 0 + for _, key := range keys { + _, exists := db.Data.Get(key) + if exists { + db.Data.Remove(key) + db.TTLMap.Remove(key) + deleted++ + } + } + return deleted } func (db *DB) Flush() { - db.stopWorld.Add(1) - defer db.stopWorld.Done() + db.stopWorld.Add(1) + defer db.stopWorld.Done() - db.Data = dict.MakeConcurrent(dataDictSize) - db.TTLMap = dict.MakeConcurrent(ttlDictSize) - db.Locker = lock.Make(lockerSize) + db.Data = dict.MakeConcurrent(dataDictSize) + db.TTLMap = dict.MakeConcurrent(ttlDictSize) + db.Locker = lock.Make(lockerSize) } /* ---- Lock Function ----- */ func (db *DB) Lock(key string) { - db.Locker.Lock(key) + db.Locker.Lock(key) } func (db *DB) RLock(key string) { - db.Locker.RLock(key) + db.Locker.RLock(key) } func (db *DB) UnLock(key string) { - db.Locker.UnLock(key) + db.Locker.UnLock(key) } func (db *DB) RUnLock(key string) { - db.Locker.RUnLock(key) + db.Locker.RUnLock(key) } func (db *DB) Locks(keys ...string) { - db.Locker.Locks(keys...) + db.Locker.Locks(keys...) } func (db *DB) RLocks(keys ...string) { - db.Locker.RLocks(keys...) + db.Locker.RLocks(keys...) } func (db *DB) UnLocks(keys ...string) { - db.Locker.UnLocks(keys...) + db.Locker.UnLocks(keys...) } func (db *DB) RUnLocks(keys ...string) { - db.Locker.RUnLocks(keys...) + db.Locker.RUnLocks(keys...) } /* ---- TTL Functions ---- */ func (db *DB) Expire(key string, expireTime time.Time) { - db.stopWorld.Wait() - db.TTLMap.Put(key, expireTime) + db.stopWorld.Wait() + db.TTLMap.Put(key, expireTime) } func (db *DB) Persist(key string) { - db.stopWorld.Wait() - db.TTLMap.Remove(key) + db.stopWorld.Wait() + db.TTLMap.Remove(key) } func (db *DB) IsExpired(key string) bool { - rawExpireTime, ok := db.TTLMap.Get(key) - if !ok { - return false - } - expireTime, _ := rawExpireTime.(time.Time) - expired := time.Now().After(expireTime) - if expired { - db.Remove(key) - } - return expired + rawExpireTime, ok := db.TTLMap.Get(key) + if !ok { + return false + } + expireTime, _ := rawExpireTime.(time.Time) + expired := time.Now().After(expireTime) + if expired { + db.Remove(key) + } + return expired } func (db *DB) CleanExpired() { - now := time.Now() - toRemove := &List.LinkedList{} - db.TTLMap.ForEach(func(key string, val interface{}) bool { - expireTime, _ := val.(time.Time) - if now.After(expireTime) { - // expired - db.Data.Remove(key) - toRemove.Add(key) - } - return true - }) - toRemove.ForEach(func(i int, val interface{}) bool { - key, _ := val.(string) - db.TTLMap.Remove(key) - return true - }) + now := time.Now() + toRemove := &List.LinkedList{} + db.TTLMap.ForEach(func(key string, val interface{}) bool { + expireTime, _ := val.(time.Time) + if now.After(expireTime) { + // expired + db.Data.Remove(key) + toRemove.Add(key) + } + return true + }) + toRemove.ForEach(func(i int, val interface{}) bool { + key, _ := val.(string) + db.TTLMap.Remove(key) + return true + }) } func (db *DB) TimerTask() { - ticker := time.NewTicker(db.interval) - go func() { - for range ticker.C { - db.CleanExpired() - } - }() + ticker := time.NewTicker(db.interval) + go func() { + for range ticker.C { + db.CleanExpired() + } + }() } /* ---- Subscribe Functions ---- */ func (db *DB) AfterClientClose(c redis.Connection) { - pubsub.UnsubscribeAll(db.hub, c) + pubsub.UnsubscribeAll(db.hub, c) } diff --git a/src/db/hash.go b/src/db/hash.go index 1230117e..51ab85af 100644 --- a/src/db/hash.go +++ b/src/db/hash.go @@ -1,422 +1,422 @@ package db import ( - Dict "github.com/HDT3213/godis/src/datastruct/dict" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" - "github.com/shopspring/decimal" - "strconv" + Dict "github.com/HDT3213/godis/src/datastruct/dict" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" + "github.com/shopspring/decimal" + "strconv" ) func (db *DB) getAsDict(key string) (Dict.Dict, reply.ErrorReply) { - entity, exists := db.Get(key) - if !exists { - return nil, nil - } - dict, ok := entity.Data.(Dict.Dict) - if !ok { - return nil, &reply.WrongTypeErrReply{} - } - return dict, nil + entity, exists := db.Get(key) + if !exists { + return nil, nil + } + dict, ok := entity.Data.(Dict.Dict) + if !ok { + return nil, &reply.WrongTypeErrReply{} + } + return dict, nil } func (db *DB) getOrInitDict(key string) (dict Dict.Dict, inited bool, errReply reply.ErrorReply) { - dict, errReply = db.getAsDict(key) - if errReply != nil { - return nil, false, errReply - } - inited = false - if dict == nil { - dict = Dict.MakeSimple() - db.Put(key, &DataEntity{ - Data: dict, - }) - inited = true - } - return dict, inited, nil + dict, errReply = db.getAsDict(key) + if errReply != nil { + return nil, false, errReply + } + inited = false + if dict == nil { + dict = Dict.MakeSimple() + db.Put(key, &DataEntity{ + Data: dict, + }) + inited = true + } + return dict, inited, nil } func HSet(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'hset' command") - } - key := string(args[0]) - field := string(args[1]) - value := args[2] - - // lock - db.Lock(key) - defer db.UnLock(key) - - // get or init entity - dict, _, errReply := db.getOrInitDict(key) - if errReply != nil { - return errReply - } - - result := dict.Put(field, value) - db.AddAof(makeAofCmd("hset", args)) - return reply.MakeIntReply(int64(result)) + // parse args + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'hset' command") + } + key := string(args[0]) + field := string(args[1]) + value := args[2] + + // lock + db.Lock(key) + defer db.UnLock(key) + + // get or init entity + dict, _, errReply := db.getOrInitDict(key) + if errReply != nil { + return errReply + } + + result := dict.Put(field, value) + db.AddAof(makeAofCmd("hset", args)) + return reply.MakeIntReply(int64(result)) } func HSetNX(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'hsetnx' command") - } - key := string(args[0]) - field := string(args[1]) - value := args[2] - - db.Lock(key) - defer db.UnLock(key) - - dict, _, errReply := db.getOrInitDict(key) - if errReply != nil { - return errReply - } - - result := dict.PutIfAbsent(field, value) - if result > 0 { - db.AddAof(makeAofCmd("hsetnx", args)) - - } - return reply.MakeIntReply(int64(result)) + // parse args + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'hsetnx' command") + } + key := string(args[0]) + field := string(args[1]) + value := args[2] + + db.Lock(key) + defer db.UnLock(key) + + dict, _, errReply := db.getOrInitDict(key) + if errReply != nil { + return errReply + } + + result := dict.PutIfAbsent(field, value) + if result > 0 { + db.AddAof(makeAofCmd("hsetnx", args)) + + } + return reply.MakeIntReply(int64(result)) } func HGet(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'hget' command") - } - key := string(args[0]) - field := string(args[1]) - - // get entity - dict, errReply := db.getAsDict(key) - if errReply != nil { - return errReply - } - if dict == nil { - return &reply.NullBulkReply{} - } - - raw, exists := dict.Get(field) - if !exists { - return &reply.NullBulkReply{} - } - value, _ := raw.([]byte) - return reply.MakeBulkReply(value) + // parse args + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'hget' command") + } + key := string(args[0]) + field := string(args[1]) + + // get entity + dict, errReply := db.getAsDict(key) + if errReply != nil { + return errReply + } + if dict == nil { + return &reply.NullBulkReply{} + } + + raw, exists := dict.Get(field) + if !exists { + return &reply.NullBulkReply{} + } + value, _ := raw.([]byte) + return reply.MakeBulkReply(value) } func HExists(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'hexists' command") - } - key := string(args[0]) - field := string(args[1]) - - // get entity - dict, errReply := db.getAsDict(key) - if errReply != nil { - return errReply - } - if dict == nil { - return reply.MakeIntReply(0) - } - - _, exists := dict.Get(field) - if exists { - return reply.MakeIntReply(1) - } - return reply.MakeIntReply(0) + // parse args + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'hexists' command") + } + key := string(args[0]) + field := string(args[1]) + + // get entity + dict, errReply := db.getAsDict(key) + if errReply != nil { + return errReply + } + if dict == nil { + return reply.MakeIntReply(0) + } + + _, exists := dict.Get(field) + if exists { + return reply.MakeIntReply(1) + } + return reply.MakeIntReply(0) } func HDel(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'hdel' command") - } - key := string(args[0]) - fields := make([]string, len(args) - 1) - fieldArgs := args[1:] - for i, v := range fieldArgs { - fields[i] = string(v) - } - - db.Lock(key) - defer db.UnLock(key) - - // get entity - dict, errReply := db.getAsDict(key) - if errReply != nil { - return errReply - } - if dict == nil { - return reply.MakeIntReply(0) - } - - deleted := 0 - for _, field := range fields { - result := dict.Remove(field) - deleted += result - } - if dict.Len() == 0 { - db.Remove(key) - } - if deleted > 0 { - db.AddAof(makeAofCmd("hdel", args)) - } - - return reply.MakeIntReply(int64(deleted)) + // parse args + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'hdel' command") + } + key := string(args[0]) + fields := make([]string, len(args)-1) + fieldArgs := args[1:] + for i, v := range fieldArgs { + fields[i] = string(v) + } + + db.Lock(key) + defer db.UnLock(key) + + // get entity + dict, errReply := db.getAsDict(key) + if errReply != nil { + return errReply + } + if dict == nil { + return reply.MakeIntReply(0) + } + + deleted := 0 + for _, field := range fields { + result := dict.Remove(field) + deleted += result + } + if dict.Len() == 0 { + db.Remove(key) + } + if deleted > 0 { + db.AddAof(makeAofCmd("hdel", args)) + } + + return reply.MakeIntReply(int64(deleted)) } func HLen(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'hlen' command") - } - key := string(args[0]) - - dict, errReply := db.getAsDict(key) - if errReply != nil { - return errReply - } - if dict == nil { - return reply.MakeIntReply(0) - } - return reply.MakeIntReply(int64(dict.Len())) + // parse args + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'hlen' command") + } + key := string(args[0]) + + dict, errReply := db.getAsDict(key) + if errReply != nil { + return errReply + } + if dict == nil { + return reply.MakeIntReply(0) + } + return reply.MakeIntReply(int64(dict.Len())) } func HMSet(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) < 3 || len(args) % 2 != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'hmset' command") - } - key := string(args[0]) - size := (len(args) - 1) / 2 - fields := make([]string, size) - values := make([][]byte, size) - for i := 0; i < size; i++ { - fields[i] = string(args[2 * i + 1]) - values[i] = args[2 * i + 2] - } - - // lock key - db.Locker.Lock(key) - defer db.Locker.UnLock(key) - - // get or init entity - dict, _, errReply := db.getOrInitDict(key) - if errReply != nil { - return errReply - } - - // put data - for i, field := range fields { - value := values[i] - dict.Put(field, value) - } - db.AddAof(makeAofCmd("hmset", args)) - return &reply.OkReply{} + // parse args + if len(args) < 3 || len(args)%2 != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'hmset' command") + } + key := string(args[0]) + size := (len(args) - 1) / 2 + fields := make([]string, size) + values := make([][]byte, size) + for i := 0; i < size; i++ { + fields[i] = string(args[2*i+1]) + values[i] = args[2*i+2] + } + + // lock key + db.Locker.Lock(key) + defer db.Locker.UnLock(key) + + // get or init entity + dict, _, errReply := db.getOrInitDict(key) + if errReply != nil { + return errReply + } + + // put data + for i, field := range fields { + value := values[i] + dict.Put(field, value) + } + db.AddAof(makeAofCmd("hmset", args)) + return &reply.OkReply{} } func HMGet(db *DB, args [][]byte) redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'hmget' command") - } - key := string(args[0]) - size := len(args) - 1 - fields := make([]string, size) - for i := 0; i < size; i++ { - fields[i] = string(args[i + 1]) - } - - db.RLock(key) - defer db.RUnLock(key) - - // get entity - result := make([][]byte, size) - dict, errReply := db.getAsDict(key) - if errReply != nil { - return errReply - } - if dict == nil { - return reply.MakeMultiBulkReply(result) - } - - for i, field := range fields { - value, ok := dict.Get(field) - if !ok { - result[i] = nil - } else { - bytes, _ := value.([]byte) - result[i] = bytes - } - } - return reply.MakeMultiBulkReply(result) + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'hmget' command") + } + key := string(args[0]) + size := len(args) - 1 + fields := make([]string, size) + for i := 0; i < size; i++ { + fields[i] = string(args[i+1]) + } + + db.RLock(key) + defer db.RUnLock(key) + + // get entity + result := make([][]byte, size) + dict, errReply := db.getAsDict(key) + if errReply != nil { + return errReply + } + if dict == nil { + return reply.MakeMultiBulkReply(result) + } + + for i, field := range fields { + value, ok := dict.Get(field) + if !ok { + result[i] = nil + } else { + bytes, _ := value.([]byte) + result[i] = bytes + } + } + return reply.MakeMultiBulkReply(result) } func HKeys(db *DB, args [][]byte) redis.Reply { - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'hkeys' command") - } - key := string(args[0]) - - db.RLock(key) - defer db.RUnLock(key) - - dict, errReply := db.getAsDict(key) - if errReply != nil { - return errReply - } - if dict == nil { - return &reply.EmptyMultiBulkReply{} - } - - fields := make([][]byte, dict.Len()) - i := 0 - dict.ForEach(func(key string, val interface{})bool { - fields[i] = []byte(key) - i++ - return true - }) - return reply.MakeMultiBulkReply(fields[:i]) + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'hkeys' command") + } + key := string(args[0]) + + db.RLock(key) + defer db.RUnLock(key) + + dict, errReply := db.getAsDict(key) + if errReply != nil { + return errReply + } + if dict == nil { + return &reply.EmptyMultiBulkReply{} + } + + fields := make([][]byte, dict.Len()) + i := 0 + dict.ForEach(func(key string, val interface{}) bool { + fields[i] = []byte(key) + i++ + return true + }) + return reply.MakeMultiBulkReply(fields[:i]) } func HVals(db *DB, args [][]byte) redis.Reply { - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'hvals' command") - } - key := string(args[0]) - - db.RLock(key) - defer db.RUnLock(key) - - // get entity - dict, errReply := db.getAsDict(key) - if errReply != nil { - return errReply - } - if dict == nil { - return &reply.EmptyMultiBulkReply{} - } - - values := make([][]byte, dict.Len()) - i := 0 - dict.ForEach(func(key string, val interface{})bool { - values[i], _ = val.([]byte) - i++ - return true - }) - return reply.MakeMultiBulkReply(values[:i]) + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'hvals' command") + } + key := string(args[0]) + + db.RLock(key) + defer db.RUnLock(key) + + // get entity + dict, errReply := db.getAsDict(key) + if errReply != nil { + return errReply + } + if dict == nil { + return &reply.EmptyMultiBulkReply{} + } + + values := make([][]byte, dict.Len()) + i := 0 + dict.ForEach(func(key string, val interface{}) bool { + values[i], _ = val.([]byte) + i++ + return true + }) + return reply.MakeMultiBulkReply(values[:i]) } func HGetAll(db *DB, args [][]byte) redis.Reply { - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'hgetAll' command") - } - key := string(args[0]) - - db.RLock(key) - defer db.RUnLock(key) - - // get entity - dict, errReply := db.getAsDict(key) - if errReply != nil { - return errReply - } - if dict == nil { - return &reply.EmptyMultiBulkReply{} - } - - size := dict.Len() - result := make([][]byte, size * 2) - i := 0 - dict.ForEach(func(key string, val interface{})bool { - result[i] = []byte(key) - i++ - result[i], _ = val.([]byte) - i++ - return true - }) - return reply.MakeMultiBulkReply(result[:i]) + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'hgetAll' command") + } + key := string(args[0]) + + db.RLock(key) + defer db.RUnLock(key) + + // get entity + dict, errReply := db.getAsDict(key) + if errReply != nil { + return errReply + } + if dict == nil { + return &reply.EmptyMultiBulkReply{} + } + + size := dict.Len() + result := make([][]byte, size*2) + i := 0 + dict.ForEach(func(key string, val interface{}) bool { + result[i] = []byte(key) + i++ + result[i], _ = val.([]byte) + i++ + return true + }) + return reply.MakeMultiBulkReply(result[:i]) } func HIncrBy(db *DB, args [][]byte) redis.Reply { - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'hincrby' command") - } - key := string(args[0]) - field := string(args[1]) - rawDelta := string(args[2]) - delta, err := strconv.ParseInt(rawDelta, 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - - db.Locker.Lock(key) - defer db.Locker.UnLock(key) - - dict, _, errReply := db.getOrInitDict(key) - if errReply != nil { - return errReply - } - - value, exists := dict.Get(field) - if !exists { - dict.Put(field, args[2]) - db.AddAof(makeAofCmd("hincrby", args)) - return reply.MakeBulkReply(args[2]) - } else { - val, err := strconv.ParseInt(string(value.([]byte)), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR hash value is not an integer") - } - val += delta - bytes := []byte(strconv.FormatInt(val, 10)) - dict.Put(field, bytes) - db.AddAof(makeAofCmd("hincrby", args)) - return reply.MakeBulkReply(bytes) - } + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'hincrby' command") + } + key := string(args[0]) + field := string(args[1]) + rawDelta := string(args[2]) + delta, err := strconv.ParseInt(rawDelta, 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + + db.Locker.Lock(key) + defer db.Locker.UnLock(key) + + dict, _, errReply := db.getOrInitDict(key) + if errReply != nil { + return errReply + } + + value, exists := dict.Get(field) + if !exists { + dict.Put(field, args[2]) + db.AddAof(makeAofCmd("hincrby", args)) + return reply.MakeBulkReply(args[2]) + } else { + val, err := strconv.ParseInt(string(value.([]byte)), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR hash value is not an integer") + } + val += delta + bytes := []byte(strconv.FormatInt(val, 10)) + dict.Put(field, bytes) + db.AddAof(makeAofCmd("hincrby", args)) + return reply.MakeBulkReply(bytes) + } } func HIncrByFloat(db *DB, args [][]byte) redis.Reply { - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'hincrbyfloat' command") - } - key := string(args[0]) - field := string(args[1]) - rawDelta := string(args[2]) - delta, err := decimal.NewFromString(rawDelta) - if err != nil { - return reply.MakeErrReply("ERR value is not a valid float") - } - - db.Lock(key) - defer db.UnLock(key) - - // get or init entity - dict, _, errReply := db.getOrInitDict(key) - if errReply != nil { - return errReply - } - - value, exists := dict.Get(field) - if !exists { - dict.Put(field, args[2]) - return reply.MakeBulkReply(args[2]) - } else { - val, err := decimal.NewFromString(string(value.([]byte))) - if err != nil { - return reply.MakeErrReply("ERR hash value is not a float") - } - result := val.Add(delta) - resultBytes:= []byte(result.String()) - dict.Put(field, resultBytes) - db.AddAof(makeAofCmd("hincrbyfloat", args)) - return reply.MakeBulkReply(resultBytes) - } -} \ No newline at end of file + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'hincrbyfloat' command") + } + key := string(args[0]) + field := string(args[1]) + rawDelta := string(args[2]) + delta, err := decimal.NewFromString(rawDelta) + if err != nil { + return reply.MakeErrReply("ERR value is not a valid float") + } + + db.Lock(key) + defer db.UnLock(key) + + // get or init entity + dict, _, errReply := db.getOrInitDict(key) + if errReply != nil { + return errReply + } + + value, exists := dict.Get(field) + if !exists { + dict.Put(field, args[2]) + return reply.MakeBulkReply(args[2]) + } else { + val, err := decimal.NewFromString(string(value.([]byte))) + if err != nil { + return reply.MakeErrReply("ERR hash value is not a float") + } + result := val.Add(delta) + resultBytes := []byte(result.String()) + dict.Put(field, resultBytes) + db.AddAof(makeAofCmd("hincrbyfloat", args)) + return reply.MakeBulkReply(resultBytes) + } +} diff --git a/src/db/list.go b/src/db/list.go index 913a2988..9cfa1973 100644 --- a/src/db/list.go +++ b/src/db/list.go @@ -1,439 +1,439 @@ package db import ( - List "github.com/HDT3213/godis/src/datastruct/list" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" - "strconv" + List "github.com/HDT3213/godis/src/datastruct/list" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" + "strconv" ) -func (db *DB) getAsList(key string)(*List.LinkedList, reply.ErrorReply) { - entity, ok := db.Get(key) - if !ok { - return nil, nil - } - bytes, ok := entity.Data.(*List.LinkedList) - if !ok { - return nil, &reply.WrongTypeErrReply{} - } - return bytes, nil +func (db *DB) getAsList(key string) (*List.LinkedList, reply.ErrorReply) { + entity, ok := db.Get(key) + if !ok { + return nil, nil + } + bytes, ok := entity.Data.(*List.LinkedList) + if !ok { + return nil, &reply.WrongTypeErrReply{} + } + return bytes, nil } -func (db *DB) getOrInitList(key string)(list *List.LinkedList, inited bool, errReply reply.ErrorReply) { - list, errReply = db.getAsList(key) - if errReply != nil { - return nil, false, errReply - } - inited = false - if list == nil { - list = &List.LinkedList{} - db.Put(key, &DataEntity{ - Data: list, - }) - inited = true - } - return list, inited, nil +func (db *DB) getOrInitList(key string) (list *List.LinkedList, inited bool, errReply reply.ErrorReply) { + list, errReply = db.getAsList(key) + if errReply != nil { + return nil, false, errReply + } + inited = false + if list == nil { + list = &List.LinkedList{} + db.Put(key, &DataEntity{ + Data: list, + }) + inited = true + } + return list, inited, nil } func LIndex(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command") - } - key := string(args[0]) - index64, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - index := int(index64) - - // get entity - list, errReply := db.getAsList(key) - if errReply != nil { - return errReply - } - if list == nil { - return &reply.NullBulkReply{} - } - - size := list.Len() // assert: size > 0 - if index < -1 * size { - return &reply.NullBulkReply{} - } else if index < 0 { - index = size + index - } else if index >= size { - return &reply.NullBulkReply{} - } - - val, _ := list.Get(index).([]byte) - return reply.MakeBulkReply(val) + // parse args + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command") + } + key := string(args[0]) + index64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + index := int(index64) + + // get entity + list, errReply := db.getAsList(key) + if errReply != nil { + return errReply + } + if list == nil { + return &reply.NullBulkReply{} + } + + size := list.Len() // assert: size > 0 + if index < -1*size { + return &reply.NullBulkReply{} + } else if index < 0 { + index = size + index + } else if index >= size { + return &reply.NullBulkReply{} + } + + val, _ := list.Get(index).([]byte) + return reply.MakeBulkReply(val) } func LLen(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'llen' command") - } - key := string(args[0]) - - list, errReply := db.getAsList(key) - if errReply != nil { - return errReply - } - if list == nil { - return reply.MakeIntReply(0) - } - - size := int64(list.Len()) - return reply.MakeIntReply(size) + // parse args + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'llen' command") + } + key := string(args[0]) + + list, errReply := db.getAsList(key) + if errReply != nil { + return errReply + } + if list == nil { + return reply.MakeIntReply(0) + } + + size := int64(list.Len()) + return reply.MakeIntReply(size) } func LPop(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command") - } - key := string(args[0]) - - // lock - db.Lock(key) - defer db.UnLock(key) - - // get data - list, errReply := db.getAsList(key) - if errReply != nil { - return errReply - } - if list == nil { - return &reply.NullBulkReply{} - } - - val, _ := list.Remove(0).([]byte) - if list.Len() == 0 { - db.Remove(key) - } - db.AddAof(makeAofCmd("lpop", args)) - return reply.MakeBulkReply(val) + // parse args + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command") + } + key := string(args[0]) + + // lock + db.Lock(key) + defer db.UnLock(key) + + // get data + list, errReply := db.getAsList(key) + if errReply != nil { + return errReply + } + if list == nil { + return &reply.NullBulkReply{} + } + + val, _ := list.Remove(0).([]byte) + if list.Len() == 0 { + db.Remove(key) + } + db.AddAof(makeAofCmd("lpop", args)) + return reply.MakeBulkReply(val) } func LPush(db *DB, args [][]byte) redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command") - } - key := string(args[0]) - values := args[1:] - - // lock - db.Locker.Lock(key) - defer db.Locker.UnLock(key) - - // get or init entity - list, _, errReply := db.getOrInitList(key) - if errReply != nil { - return errReply - } - - // insert - for _, value := range values { - list.Insert(0, value) - } - - db.AddAof(makeAofCmd("lpush", args)) - return reply.MakeIntReply(int64(list.Len())) + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command") + } + key := string(args[0]) + values := args[1:] + + // lock + db.Locker.Lock(key) + defer db.Locker.UnLock(key) + + // get or init entity + list, _, errReply := db.getOrInitList(key) + if errReply != nil { + return errReply + } + + // insert + for _, value := range values { + list.Insert(0, value) + } + + db.AddAof(makeAofCmd("lpush", args)) + return reply.MakeIntReply(int64(list.Len())) } func LPushX(db *DB, args [][]byte) redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lpushx' command") - } - key := string(args[0]) - values := args[1:] - - // lock - db.Locker.Lock(key) - defer db.Locker.UnLock(key) - - // get or init entity - list, errReply := db.getAsList(key) - if errReply != nil { - return errReply - } - if list == nil { - return reply.MakeIntReply(0) - } - - // insert - for _, value := range values { - list.Insert(0, value) - } - db.AddAof(makeAofCmd("lpushx", args)) - return reply.MakeIntReply(int64(list.Len())) + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lpushx' command") + } + key := string(args[0]) + values := args[1:] + + // lock + db.Locker.Lock(key) + defer db.Locker.UnLock(key) + + // get or init entity + list, errReply := db.getAsList(key) + if errReply != nil { + return errReply + } + if list == nil { + return reply.MakeIntReply(0) + } + + // insert + for _, value := range values { + list.Insert(0, value) + } + db.AddAof(makeAofCmd("lpushx", args)) + return reply.MakeIntReply(int64(list.Len())) } func LRange(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lrange' command") - } - key := string(args[0]) - start64, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - start := int(start64) - stop64, err := strconv.ParseInt(string(args[2]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - stop := int(stop64) - - // lock key - db.RLock(key) - defer db.RUnLock(key) - - // get data - list, errReply := db.getAsList(key) - if errReply != nil { - return errReply - } - if list == nil { - return &reply.EmptyMultiBulkReply{} - } - - // compute index - size := list.Len() // assert: size > 0 - if start < -1 * size { - start = 0 - } else if start < 0 { - start = size + start - } else if start >= size { - return &reply.EmptyMultiBulkReply{} - } - if stop < -1 * size { - stop = 0 - } else if stop < 0 { - stop = size + stop + 1 - } else if stop < size { - stop = stop + 1 - } else { - stop = size - } - if stop < start { - stop = start - } - - // assert: start in [0, size - 1], stop in [start, size] - slice := list.Range(start, stop) - result := make([][]byte, len(slice)) - for i, raw := range slice { - bytes, _ := raw.([]byte) - result[i] = bytes - } - return reply.MakeMultiBulkReply(result) + // parse args + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lrange' command") + } + key := string(args[0]) + start64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + start := int(start64) + stop64, err := strconv.ParseInt(string(args[2]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + stop := int(stop64) + + // lock key + db.RLock(key) + defer db.RUnLock(key) + + // get data + list, errReply := db.getAsList(key) + if errReply != nil { + return errReply + } + if list == nil { + return &reply.EmptyMultiBulkReply{} + } + + // compute index + size := list.Len() // assert: size > 0 + if start < -1*size { + start = 0 + } else if start < 0 { + start = size + start + } else if start >= size { + return &reply.EmptyMultiBulkReply{} + } + if stop < -1*size { + stop = 0 + } else if stop < 0 { + stop = size + stop + 1 + } else if stop < size { + stop = stop + 1 + } else { + stop = size + } + if stop < start { + stop = start + } + + // assert: start in [0, size - 1], stop in [start, size] + slice := list.Range(start, stop) + result := make([][]byte, len(slice)) + for i, raw := range slice { + bytes, _ := raw.([]byte) + result[i] = bytes + } + return reply.MakeMultiBulkReply(result) } func LRem(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lrem' command") - } - key := string(args[0]) - count64, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - count := int(count64) - value := args[2] - - // lock - db.Lock(key) - defer db.UnLock(key) - - // get data entity - list, errReply := db.getAsList(key) - if errReply != nil { - return errReply - } - if list == nil { - return reply.MakeIntReply(0) - } - - var removed int - if count == 0 { - removed = list.RemoveAllByVal(value) - } else if count > 0 { - removed = list.RemoveByVal(value, count) - } else { - removed = list.ReverseRemoveByVal(value, -count) - } - - if list.Len() == 0 { - db.Remove(key) - } - if removed > 0 { - db.AddAof(makeAofCmd("lrem", args)) - } - - return reply.MakeIntReply(int64(removed)) + // parse args + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lrem' command") + } + key := string(args[0]) + count64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + count := int(count64) + value := args[2] + + // lock + db.Lock(key) + defer db.UnLock(key) + + // get data entity + list, errReply := db.getAsList(key) + if errReply != nil { + return errReply + } + if list == nil { + return reply.MakeIntReply(0) + } + + var removed int + if count == 0 { + removed = list.RemoveAllByVal(value) + } else if count > 0 { + removed = list.RemoveByVal(value, count) + } else { + removed = list.ReverseRemoveByVal(value, -count) + } + + if list.Len() == 0 { + db.Remove(key) + } + if removed > 0 { + db.AddAof(makeAofCmd("lrem", args)) + } + + return reply.MakeIntReply(int64(removed)) } func LSet(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lset' command") - } - key := string(args[0]) - index64, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - index := int(index64) - value := args[2] - - // lock - db.Locker.Lock(key) - defer db.Locker.UnLock(key) - - // get data - list, errReply := db.getAsList(key) - if errReply != nil { - return errReply - } - if list == nil { - return reply.MakeErrReply("ERR no such key") - } - - size := list.Len() // assert: size > 0 - if index < -1 * size { - return reply.MakeErrReply("ERR index out of range") - } else if index < 0 { - index = size + index - } else if index >= size { - return reply.MakeErrReply("ERR index out of range") - } - - list.Set(index, value) - db.AddAof(makeAofCmd("lset", args)) - return &reply.OkReply{} + // parse args + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lset' command") + } + key := string(args[0]) + index64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + index := int(index64) + value := args[2] + + // lock + db.Locker.Lock(key) + defer db.Locker.UnLock(key) + + // get data + list, errReply := db.getAsList(key) + if errReply != nil { + return errReply + } + if list == nil { + return reply.MakeErrReply("ERR no such key") + } + + size := list.Len() // assert: size > 0 + if index < -1*size { + return reply.MakeErrReply("ERR index out of range") + } else if index < 0 { + index = size + index + } else if index >= size { + return reply.MakeErrReply("ERR index out of range") + } + + list.Set(index, value) + db.AddAof(makeAofCmd("lset", args)) + return &reply.OkReply{} } func RPop(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'rpop' command") - } - key := string(args[0]) - - // lock - db.Lock(key) - defer db.UnLock(key) - - // get data - list, errReply := db.getAsList(key) - if errReply != nil { - return errReply - } - if list == nil { - return &reply.NullBulkReply{} - } - - val, _ := list.RemoveLast().([]byte) - if list.Len() == 0 { - db.Remove(key) - } - db.AddAof(makeAofCmd("rpop", args)) - return reply.MakeBulkReply(val) + // parse args + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'rpop' command") + } + key := string(args[0]) + + // lock + db.Lock(key) + defer db.UnLock(key) + + // get data + list, errReply := db.getAsList(key) + if errReply != nil { + return errReply + } + if list == nil { + return &reply.NullBulkReply{} + } + + val, _ := list.RemoveLast().([]byte) + if list.Len() == 0 { + db.Remove(key) + } + db.AddAof(makeAofCmd("rpop", args)) + return reply.MakeBulkReply(val) } func RPopLPush(db *DB, args [][]byte) redis.Reply { - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'rpoplpush' command") - } - sourceKey := string(args[0]) - destKey := string(args[1]) - - // lock - db.Locks(sourceKey, destKey) - defer db.UnLocks(sourceKey, destKey) - - // get source entity - sourceList, errReply := db.getAsList(sourceKey) - if errReply != nil { - return errReply - } - if sourceList == nil { - return &reply.NullBulkReply{} - } - - // get dest entity - destList, _, errReply := db.getOrInitList(destKey) - if errReply != nil { - return errReply - } - - // pop and push - val, _ := sourceList.RemoveLast().([]byte) - destList.Insert(0, val) - - if sourceList.Len() == 0 { - db.Remove(sourceKey) - } - - db.AddAof(makeAofCmd("rpoplpush", args)) - return reply.MakeBulkReply(val) + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'rpoplpush' command") + } + sourceKey := string(args[0]) + destKey := string(args[1]) + + // lock + db.Locks(sourceKey, destKey) + defer db.UnLocks(sourceKey, destKey) + + // get source entity + sourceList, errReply := db.getAsList(sourceKey) + if errReply != nil { + return errReply + } + if sourceList == nil { + return &reply.NullBulkReply{} + } + + // get dest entity + destList, _, errReply := db.getOrInitList(destKey) + if errReply != nil { + return errReply + } + + // pop and push + val, _ := sourceList.RemoveLast().([]byte) + destList.Insert(0, val) + + if sourceList.Len() == 0 { + db.Remove(sourceKey) + } + + db.AddAof(makeAofCmd("rpoplpush", args)) + return reply.MakeBulkReply(val) } func RPush(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command") - } - key := string(args[0]) - values := args[1:] - - // lock - db.Lock(key) - defer db.UnLock(key) - - // get or init entity - list, _, errReply := db.getOrInitList(key) - if errReply != nil { - return errReply - } - - // put list - for _, value := range values { - list.Add(value) - } - db.AddAof(makeAofCmd("rpush", args)) - return reply.MakeIntReply(int64(list.Len())) + // parse args + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command") + } + key := string(args[0]) + values := args[1:] + + // lock + db.Lock(key) + defer db.UnLock(key) + + // get or init entity + list, _, errReply := db.getOrInitList(key) + if errReply != nil { + return errReply + } + + // put list + for _, value := range values { + list.Add(value) + } + db.AddAof(makeAofCmd("rpush", args)) + return reply.MakeIntReply(int64(list.Len())) } func RPushX(db *DB, args [][]byte) redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command") - } - key := string(args[0]) - values := args[1:] - - // lock - db.Lock(key) - defer db.UnLock(key) - - // get or init entity - list, errReply := db.getAsList(key) - if errReply != nil { - return errReply - } - if list == nil { - return reply.MakeIntReply(0) - } - - // put list - for _, value := range values { - list.Add(value) - } - db.AddAof(makeAofCmd("rpushx", args)) - - return reply.MakeIntReply(int64(list.Len())) -} \ No newline at end of file + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command") + } + key := string(args[0]) + values := args[1:] + + // lock + db.Lock(key) + defer db.UnLock(key) + + // get or init entity + list, errReply := db.getAsList(key) + if errReply != nil { + return errReply + } + if list == nil { + return reply.MakeIntReply(0) + } + + // put list + for _, value := range values { + list.Add(value) + } + db.AddAof(makeAofCmd("rpushx", args)) + + return reply.MakeIntReply(int64(list.Len())) +} diff --git a/src/db/router.go b/src/db/router.go index 512ca542..dfed2b4b 100644 --- a/src/db/router.go +++ b/src/db/router.go @@ -1,102 +1,102 @@ package db -func MakeRouter()map[string]CmdFunc { - routerMap := make(map[string]CmdFunc) - routerMap["ping"] = Ping +func MakeRouter() map[string]CmdFunc { + routerMap := make(map[string]CmdFunc) + routerMap["ping"] = Ping - routerMap["del"] = Del - routerMap["expire"] = Expire - routerMap["expireat"] = ExpireAt - routerMap["pexpire"] = PExpire - routerMap["pexpireat"] = PExpireAt - routerMap["ttl"] = TTL - routerMap["pttl"] = PTTL - routerMap["persist"] = Persist - routerMap["exists"] = Exists - routerMap["type"] = Type - routerMap["rename"] = Rename - routerMap["renamenx"] = RenameNx + routerMap["del"] = Del + routerMap["expire"] = Expire + routerMap["expireat"] = ExpireAt + routerMap["pexpire"] = PExpire + routerMap["pexpireat"] = PExpireAt + routerMap["ttl"] = TTL + routerMap["pttl"] = PTTL + routerMap["persist"] = Persist + routerMap["exists"] = Exists + routerMap["type"] = Type + routerMap["rename"] = Rename + routerMap["renamenx"] = RenameNx - routerMap["set"] = Set - routerMap["setnx"] = SetNX - routerMap["setex"] = SetEX - routerMap["psetex"] = PSetEX - routerMap["mset"] = MSet - routerMap["mget"] = MGet - routerMap["msetnx"] = MSetNX - routerMap["get"] = Get - routerMap["getset"] = GetSet - routerMap["incr"] = Incr - routerMap["incrby"] = IncrBy - routerMap["incrbyfloat"] = IncrByFloat - routerMap["decr"] = Decr - routerMap["decrby"] = DecrBy + routerMap["set"] = Set + routerMap["setnx"] = SetNX + routerMap["setex"] = SetEX + routerMap["psetex"] = PSetEX + routerMap["mset"] = MSet + routerMap["mget"] = MGet + routerMap["msetnx"] = MSetNX + routerMap["get"] = Get + routerMap["getset"] = GetSet + routerMap["incr"] = Incr + routerMap["incrby"] = IncrBy + routerMap["incrbyfloat"] = IncrByFloat + routerMap["decr"] = Decr + routerMap["decrby"] = DecrBy - routerMap["lpush"] = LPush - routerMap["lpushx"] = LPushX - routerMap["rpush"] = RPush - routerMap["rpushx"] = RPushX - routerMap["lpop"] = LPop - routerMap["rpop"] = RPop - routerMap["rpoplpush"] = RPopLPush - routerMap["lrem"] = LRem - routerMap["llen"] = LLen - routerMap["lindex"] = LIndex - routerMap["lset"] = LSet - routerMap["lrange"] = LRange + routerMap["lpush"] = LPush + routerMap["lpushx"] = LPushX + routerMap["rpush"] = RPush + routerMap["rpushx"] = RPushX + routerMap["lpop"] = LPop + routerMap["rpop"] = RPop + routerMap["rpoplpush"] = RPopLPush + routerMap["lrem"] = LRem + routerMap["llen"] = LLen + routerMap["lindex"] = LIndex + routerMap["lset"] = LSet + routerMap["lrange"] = LRange - routerMap["hset"] = HSet - routerMap["hsetnx"] = HSetNX - routerMap["hget"] = HGet - routerMap["hexists"] = HExists - routerMap["hdel"] = HDel - routerMap["hlen"] = HLen - routerMap["hmget"] = HMGet - routerMap["hmset"] = HMSet - routerMap["hkeys"] = HKeys - routerMap["hvals"] = HVals - routerMap["hgetall"] = HGetAll - routerMap["hincrby"] = HIncrBy - routerMap["hincrbyfloat"] = HIncrByFloat + routerMap["hset"] = HSet + routerMap["hsetnx"] = HSetNX + routerMap["hget"] = HGet + routerMap["hexists"] = HExists + routerMap["hdel"] = HDel + routerMap["hlen"] = HLen + routerMap["hmget"] = HMGet + routerMap["hmset"] = HMSet + routerMap["hkeys"] = HKeys + routerMap["hvals"] = HVals + routerMap["hgetall"] = HGetAll + routerMap["hincrby"] = HIncrBy + routerMap["hincrbyfloat"] = HIncrByFloat - routerMap["sadd"] = SAdd - routerMap["sismember"] = SIsMember - routerMap["srem"] = SRem - routerMap["scard"] = SCard - routerMap["smembers"] = SMembers - routerMap["sinter"] = SInter - routerMap["sinterstore"] = SInterStore - routerMap["sunion"] = SUnion - routerMap["sunionstore"] = SUnionStore - routerMap["sdiff"] = SDiff - routerMap["sdiffstore"] = SDiffStore - routerMap["srandmember"] = SRandMember + routerMap["sadd"] = SAdd + routerMap["sismember"] = SIsMember + routerMap["srem"] = SRem + routerMap["scard"] = SCard + routerMap["smembers"] = SMembers + routerMap["sinter"] = SInter + routerMap["sinterstore"] = SInterStore + routerMap["sunion"] = SUnion + routerMap["sunionstore"] = SUnionStore + routerMap["sdiff"] = SDiff + routerMap["sdiffstore"] = SDiffStore + routerMap["srandmember"] = SRandMember - routerMap["zadd"] = ZAdd - routerMap["zscore"] = ZScore - routerMap["zincrby"] = ZIncrBy - routerMap["zrank"] = ZRank - routerMap["zcount"] = ZCount - routerMap["zrevrank"] = ZRevRank - routerMap["zcard"] = ZCard - routerMap["zrange"] = ZRange - routerMap["zrevrange"] = ZRevRange - routerMap["zrangebyscore"] = ZRangeByScore - routerMap["zrevrangebyscore"] = ZRevRangeByScore - routerMap["zrem"] = ZRem - routerMap["zremrangebyscore"] = ZRemRangeByScore - routerMap["zremrangebyrank"] = ZRemRangeByRank + routerMap["zadd"] = ZAdd + routerMap["zscore"] = ZScore + routerMap["zincrby"] = ZIncrBy + routerMap["zrank"] = ZRank + routerMap["zcount"] = ZCount + routerMap["zrevrank"] = ZRevRank + routerMap["zcard"] = ZCard + routerMap["zrange"] = ZRange + routerMap["zrevrange"] = ZRevRange + routerMap["zrangebyscore"] = ZRangeByScore + routerMap["zrevrangebyscore"] = ZRevRangeByScore + routerMap["zrem"] = ZRem + routerMap["zremrangebyscore"] = ZRemRangeByScore + routerMap["zremrangebyrank"] = ZRemRangeByRank - routerMap["geoadd"] = GeoAdd - routerMap["geopos"] = GeoPos - routerMap["geodist"] = GeoDist - routerMap["geohash"] = GeoHash - routerMap["georadius"] = GeoRadius - routerMap["georadiusbymember"] = GeoRadiusByMember + routerMap["geoadd"] = GeoAdd + routerMap["geopos"] = GeoPos + routerMap["geodist"] = GeoDist + routerMap["geohash"] = GeoHash + routerMap["georadius"] = GeoRadius + routerMap["georadiusbymember"] = GeoRadiusByMember - routerMap["flushdb"] = FlushDB - routerMap["flushall"] = FlushAll - routerMap["keys"] = Keys + routerMap["flushdb"] = FlushDB + routerMap["flushall"] = FlushAll + routerMap["keys"] = Keys - return routerMap + return routerMap } diff --git a/src/db/server.go b/src/db/server.go index cd34726e..6422efed 100644 --- a/src/db/server.go +++ b/src/db/server.go @@ -1,16 +1,16 @@ package db import ( - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" ) func Ping(db *DB, args [][]byte) redis.Reply { - if len(args) == 0 { - return &reply.PongReply{} - } else if len(args) == 1 { - return reply.MakeStatusReply("\"" + string(args[0]) + "\"") - } else { - return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command") - } -} \ No newline at end of file + if len(args) == 0 { + return &reply.PongReply{} + } else if len(args) == 1 { + return reply.MakeStatusReply("\"" + string(args[0]) + "\"") + } else { + return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command") + } +} diff --git a/src/db/sortedset.go b/src/db/sortedset.go index bd7bd74c..60edf25e 100644 --- a/src/db/sortedset.go +++ b/src/db/sortedset.go @@ -1,605 +1,605 @@ package db import ( - SortedSet "github.com/HDT3213/godis/src/datastruct/sortedset" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" - "strconv" - "strings" + SortedSet "github.com/HDT3213/godis/src/datastruct/sortedset" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" + "strconv" + "strings" ) -func (db *DB)getAsSortedSet(key string)(*SortedSet.SortedSet, reply.ErrorReply) { - entity, exists := db.Get(key) - if !exists { - return nil, nil - } - sortedSet, ok := entity.Data.(*SortedSet.SortedSet) - if !ok { - return nil, &reply.WrongTypeErrReply{} - } - return sortedSet, nil +func (db *DB) getAsSortedSet(key string) (*SortedSet.SortedSet, reply.ErrorReply) { + entity, exists := db.Get(key) + if !exists { + return nil, nil + } + sortedSet, ok := entity.Data.(*SortedSet.SortedSet) + if !ok { + return nil, &reply.WrongTypeErrReply{} + } + return sortedSet, nil } -func (db *DB) getOrInitSortedSet(key string)(sortedSet *SortedSet.SortedSet, inited bool, errReply reply.ErrorReply) { - sortedSet, errReply = db.getAsSortedSet(key) - if errReply != nil { - return nil, false, errReply - } - inited = false - if sortedSet == nil { - sortedSet = SortedSet.Make() - db.Put(key, &DataEntity{ - Data: sortedSet, - }) - inited = true - } - return sortedSet, inited, nil +func (db *DB) getOrInitSortedSet(key string) (sortedSet *SortedSet.SortedSet, inited bool, errReply reply.ErrorReply) { + sortedSet, errReply = db.getAsSortedSet(key) + if errReply != nil { + return nil, false, errReply + } + inited = false + if sortedSet == nil { + sortedSet = SortedSet.Make() + db.Put(key, &DataEntity{ + Data: sortedSet, + }) + inited = true + } + return sortedSet, inited, nil } func ZAdd(db *DB, args [][]byte) redis.Reply { - if len(args) < 3 || len(args) % 2 != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zadd' command") - } - key := string(args[0]) - size := (len(args) - 1) / 2 - elements := make([]*SortedSet.Element, size) - for i := 0; i < size; i++ { - scoreValue := args[2 * i + 1] - member := string(args[2 * i + 2]) - score, err := strconv.ParseFloat(string(scoreValue), 64) - if err != nil { - return reply.MakeErrReply("ERR value is not a valid float") - } - elements[i] = &SortedSet.Element{ - Member:member, - Score:score, - } - } - - // lock - db.Lock(key) - defer db.UnLock(key) - - // get or init entity - sortedSet, _, errReply := db.getOrInitSortedSet(key) - if errReply != nil { - return errReply - } - - i := 0 - for _, e := range elements { - if sortedSet.Add(e.Member, e.Score) { - i++ - } - } - - db.AddAof(makeAofCmd("sdiffstore", args)) - - return reply.MakeIntReply(int64(i)) + if len(args) < 3 || len(args)%2 != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zadd' command") + } + key := string(args[0]) + size := (len(args) - 1) / 2 + elements := make([]*SortedSet.Element, size) + for i := 0; i < size; i++ { + scoreValue := args[2*i+1] + member := string(args[2*i+2]) + score, err := strconv.ParseFloat(string(scoreValue), 64) + if err != nil { + return reply.MakeErrReply("ERR value is not a valid float") + } + elements[i] = &SortedSet.Element{ + Member: member, + Score: score, + } + } + + // lock + db.Lock(key) + defer db.UnLock(key) + + // get or init entity + sortedSet, _, errReply := db.getOrInitSortedSet(key) + if errReply != nil { + return errReply + } + + i := 0 + for _, e := range elements { + if sortedSet.Add(e.Member, e.Score) { + i++ + } + } + + db.AddAof(makeAofCmd("sdiffstore", args)) + + return reply.MakeIntReply(int64(i)) } func ZScore(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zscore' command") - } - key := string(args[0]) - member := string(args[1]) - - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return &reply.NullBulkReply{} - } - - element, exists := sortedSet.Get(member) - if !exists { - return &reply.NullBulkReply{} - } - value := strconv.FormatFloat(element.Score, 'f', -1, 64) - return reply.MakeBulkReply([]byte(value)) + // parse args + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zscore' command") + } + key := string(args[0]) + member := string(args[1]) + + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return &reply.NullBulkReply{} + } + + element, exists := sortedSet.Get(member) + if !exists { + return &reply.NullBulkReply{} + } + value := strconv.FormatFloat(element.Score, 'f', -1, 64) + return reply.MakeBulkReply([]byte(value)) } func ZRank(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zrank' command") - } - key := string(args[0]) - member := string(args[1]) - - // get entity - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return &reply.NullBulkReply{} - } - - rank := sortedSet.GetRank(member, false) - if rank < 0 { - return &reply.NullBulkReply{} - } - return reply.MakeIntReply(rank) + // parse args + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrank' command") + } + key := string(args[0]) + member := string(args[1]) + + // get entity + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return &reply.NullBulkReply{} + } + + rank := sortedSet.GetRank(member, false) + if rank < 0 { + return &reply.NullBulkReply{} + } + return reply.MakeIntReply(rank) } func ZRevRank(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zrevrank' command") - } - key := string(args[0]) - member := string(args[1]) - - // get entity - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return &reply.NullBulkReply{} - } - - rank := sortedSet.GetRank(member, true) - if rank < 0 { - return &reply.NullBulkReply{} - } - return reply.MakeIntReply(rank) + // parse args + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrevrank' command") + } + key := string(args[0]) + member := string(args[1]) + + // get entity + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return &reply.NullBulkReply{} + } + + rank := sortedSet.GetRank(member, true) + if rank < 0 { + return &reply.NullBulkReply{} + } + return reply.MakeIntReply(rank) } func ZCard(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zcard' command") - } - key := string(args[0]) - - // get entity - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return reply.MakeIntReply(0) - } - - return reply.MakeIntReply(int64(sortedSet.Len())) + // parse args + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zcard' command") + } + key := string(args[0]) + + // get entity + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return reply.MakeIntReply(0) + } + + return reply.MakeIntReply(int64(sortedSet.Len())) } func ZRange(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 3 && len(args) != 4 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zrange' command") - } - withScores := false - if len(args) == 4 { - if strings.ToUpper(string(args[3])) != "WITHSCORES" { - return reply.MakeErrReply("syntax error") - } else { - withScores = true - } - } - key := string(args[0]) - start, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - stop, err := strconv.ParseInt(string(args[2]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - return range0(db, key, start, stop, withScores, false) + // parse args + if len(args) != 3 && len(args) != 4 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrange' command") + } + withScores := false + if len(args) == 4 { + if strings.ToUpper(string(args[3])) != "WITHSCORES" { + return reply.MakeErrReply("syntax error") + } else { + withScores = true + } + } + key := string(args[0]) + start, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + stop, err := strconv.ParseInt(string(args[2]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + return range0(db, key, start, stop, withScores, false) } func ZRevRange(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) != 3 && len(args) != 4 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zrevrange' command") - } - withScores := false - if len(args) == 4 { - if string(args[3]) != "WITHSCORES" { - return reply.MakeErrReply("syntax error") - } else { - withScores = true - } - } - key := string(args[0]) - start, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - stop, err := strconv.ParseInt(string(args[2]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - return range0(db, key, start, stop, withScores, true) + // parse args + if len(args) != 3 && len(args) != 4 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrevrange' command") + } + withScores := false + if len(args) == 4 { + if string(args[3]) != "WITHSCORES" { + return reply.MakeErrReply("syntax error") + } else { + withScores = true + } + } + key := string(args[0]) + start, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + stop, err := strconv.ParseInt(string(args[2]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + return range0(db, key, start, stop, withScores, true) } -func range0(db *DB, key string, start int64, stop int64, withScores bool, desc bool)redis.Reply { - // lock key - db.Locker.RLock(key) - defer db.Locker.RUnLock(key) - - // get data - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return &reply.EmptyMultiBulkReply{} - } - - // compute index - size := sortedSet.Len() // assert: size > 0 - if start < -1 * size { - start = 0 - } else if start < 0 { - start = size + start - } else if start >= size { - return &reply.EmptyMultiBulkReply{} - } - if stop < -1 * size { - stop = 0 - } else if stop < 0 { - stop = size + stop + 1 - } else if stop < size { - stop = stop + 1 - } else { - stop = size - } - if stop < start { - stop = start - } - - // assert: start in [0, size - 1], stop in [start, size] - slice := sortedSet.Range(start, stop, desc) - if withScores { - result := make([][]byte, len(slice) * 2) - i := 0 - for _, element := range slice { - result[i] = []byte(element.Member) - i++ - scoreStr := strconv.FormatFloat(element.Score, 'f', -1, 64) - result[i] = []byte(scoreStr) - i++ - } - return reply.MakeMultiBulkReply(result) - } else { - result := make([][]byte, len(slice)) - i := 0 - for _, element := range slice { - result[i] = []byte(element.Member) - i++ - } - return reply.MakeMultiBulkReply(result) - } +func range0(db *DB, key string, start int64, stop int64, withScores bool, desc bool) redis.Reply { + // lock key + db.Locker.RLock(key) + defer db.Locker.RUnLock(key) + + // get data + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return &reply.EmptyMultiBulkReply{} + } + + // compute index + size := sortedSet.Len() // assert: size > 0 + if start < -1*size { + start = 0 + } else if start < 0 { + start = size + start + } else if start >= size { + return &reply.EmptyMultiBulkReply{} + } + if stop < -1*size { + stop = 0 + } else if stop < 0 { + stop = size + stop + 1 + } else if stop < size { + stop = stop + 1 + } else { + stop = size + } + if stop < start { + stop = start + } + + // assert: start in [0, size - 1], stop in [start, size] + slice := sortedSet.Range(start, stop, desc) + if withScores { + result := make([][]byte, len(slice)*2) + i := 0 + for _, element := range slice { + result[i] = []byte(element.Member) + i++ + scoreStr := strconv.FormatFloat(element.Score, 'f', -1, 64) + result[i] = []byte(scoreStr) + i++ + } + return reply.MakeMultiBulkReply(result) + } else { + result := make([][]byte, len(slice)) + i := 0 + for _, element := range slice { + result[i] = []byte(element.Member) + i++ + } + return reply.MakeMultiBulkReply(result) + } } func ZCount(db *DB, args [][]byte) redis.Reply { - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zcount' command") - } - key := string(args[0]) - - min, err := SortedSet.ParseScoreBorder(string(args[1])) - if err != nil { - return reply.MakeErrReply(err.Error()) - } - - max, err := SortedSet.ParseScoreBorder(string(args[2])) - if err != nil { - return reply.MakeErrReply(err.Error()) - } - - db.Locker.RLock(key) - defer db.Locker.RUnLock(key) - - // get data - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return reply.MakeIntReply(0) - } - - return reply.MakeIntReply(sortedSet.Count(min, max)) + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zcount' command") + } + key := string(args[0]) + + min, err := SortedSet.ParseScoreBorder(string(args[1])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + max, err := SortedSet.ParseScoreBorder(string(args[2])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + db.Locker.RLock(key) + defer db.Locker.RUnLock(key) + + // get data + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return reply.MakeIntReply(0) + } + + return reply.MakeIntReply(sortedSet.Count(min, max)) } /* * param limit: limit < 0 means no limit */ -func rangeByScore0(db *DB, key string, min *SortedSet.ScoreBorder, max *SortedSet.ScoreBorder, offset int64, limit int64, withScores bool, desc bool)redis.Reply { - // lock key - db.Locker.RLock(key) - defer db.Locker.RUnLock(key) - - // get data - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return &reply.EmptyMultiBulkReply{} - } - - slice := sortedSet.RangeByScore(min, max, offset, limit, desc) - if withScores { - result := make([][]byte, len(slice) * 2) - i := 0 - for _, element := range slice { - result[i] = []byte(element.Member) - i++ - scoreStr := strconv.FormatFloat(element.Score, 'f', -1, 64) - result[i] = []byte(scoreStr) - i++ - } - return reply.MakeMultiBulkReply(result) - } else { - result := make([][]byte, len(slice)) - i := 0 - for _, element := range slice { - result[i] = []byte(element.Member) - i++ - } - return reply.MakeMultiBulkReply(result) - } +func rangeByScore0(db *DB, key string, min *SortedSet.ScoreBorder, max *SortedSet.ScoreBorder, offset int64, limit int64, withScores bool, desc bool) redis.Reply { + // lock key + db.Locker.RLock(key) + defer db.Locker.RUnLock(key) + + // get data + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return &reply.EmptyMultiBulkReply{} + } + + slice := sortedSet.RangeByScore(min, max, offset, limit, desc) + if withScores { + result := make([][]byte, len(slice)*2) + i := 0 + for _, element := range slice { + result[i] = []byte(element.Member) + i++ + scoreStr := strconv.FormatFloat(element.Score, 'f', -1, 64) + result[i] = []byte(scoreStr) + i++ + } + return reply.MakeMultiBulkReply(result) + } else { + result := make([][]byte, len(slice)) + i := 0 + for _, element := range slice { + result[i] = []byte(element.Member) + i++ + } + return reply.MakeMultiBulkReply(result) + } } func ZRangeByScore(db *DB, args [][]byte) redis.Reply { - if len(args) < 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zrangebyscore' command") - } - key := string(args[0]) - - min, err := SortedSet.ParseScoreBorder(string(args[1])) - if err != nil { - return reply.MakeErrReply(err.Error()) - } - - max, err := SortedSet.ParseScoreBorder(string(args[2])) - if err != nil { - return reply.MakeErrReply(err.Error()) - } - - withScores := false - var offset int64 = 0 - var limit int64 = -1 - if len(args) > 3 { - for i := 3; i < len(args); { - s := string(args[i]) - if strings.ToUpper(s) == "WITHSCORES" { - withScores = true - i++ - } else if strings.ToUpper(s) == "LIMIT" { - if len(args) < i+3 { - return reply.MakeErrReply("ERR syntax error") - } - offset, err = strconv.ParseInt(string(args[i+1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - limit, err = strconv.ParseInt(string(args[i+2]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - i += 3 - } else { - return reply.MakeErrReply("ERR syntax error") - } - } - } - return rangeByScore0(db, key, min, max, offset, limit, withScores, false) + if len(args) < 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrangebyscore' command") + } + key := string(args[0]) + + min, err := SortedSet.ParseScoreBorder(string(args[1])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + max, err := SortedSet.ParseScoreBorder(string(args[2])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + withScores := false + var offset int64 = 0 + var limit int64 = -1 + if len(args) > 3 { + for i := 3; i < len(args); { + s := string(args[i]) + if strings.ToUpper(s) == "WITHSCORES" { + withScores = true + i++ + } else if strings.ToUpper(s) == "LIMIT" { + if len(args) < i+3 { + return reply.MakeErrReply("ERR syntax error") + } + offset, err = strconv.ParseInt(string(args[i+1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + limit, err = strconv.ParseInt(string(args[i+2]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + i += 3 + } else { + return reply.MakeErrReply("ERR syntax error") + } + } + } + return rangeByScore0(db, key, min, max, offset, limit, withScores, false) } func ZRevRangeByScore(db *DB, args [][]byte) redis.Reply { - if len(args) < 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zrangebyscore' command") - } - key := string(args[0]) - - min, err := SortedSet.ParseScoreBorder(string(args[2])) - if err != nil { - return reply.MakeErrReply(err.Error()) - } - - max, err := SortedSet.ParseScoreBorder(string(args[1])) - if err != nil { - return reply.MakeErrReply(err.Error()) - } - - withScores := false - var offset int64 = 0 - var limit int64 = -1 - if len(args) > 3 { - for i := 3; i < len(args); { - s := string(args[i]) - if strings.ToUpper(s) == "WITHSCORES" { - withScores = true - i++ - } else if strings.ToUpper(s) == "LIMIT" { - if len(args) < i+3 { - return reply.MakeErrReply("ERR syntax error") - } - offset, err = strconv.ParseInt(string(args[i+1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - limit, err = strconv.ParseInt(string(args[i+2]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - i += 3 - } else { - return reply.MakeErrReply("ERR syntax error") - } - } - } - return rangeByScore0(db, key, min, max, offset, limit, withScores, true) + if len(args) < 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrangebyscore' command") + } + key := string(args[0]) + + min, err := SortedSet.ParseScoreBorder(string(args[2])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + max, err := SortedSet.ParseScoreBorder(string(args[1])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + withScores := false + var offset int64 = 0 + var limit int64 = -1 + if len(args) > 3 { + for i := 3; i < len(args); { + s := string(args[i]) + if strings.ToUpper(s) == "WITHSCORES" { + withScores = true + i++ + } else if strings.ToUpper(s) == "LIMIT" { + if len(args) < i+3 { + return reply.MakeErrReply("ERR syntax error") + } + offset, err = strconv.ParseInt(string(args[i+1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + limit, err = strconv.ParseInt(string(args[i+2]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + i += 3 + } else { + return reply.MakeErrReply("ERR syntax error") + } + } + } + return rangeByScore0(db, key, min, max, offset, limit, withScores, true) } func ZRemRangeByScore(db *DB, args [][]byte) redis.Reply { - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zremrangebyscore' command") - } - key := string(args[0]) - - min, err := SortedSet.ParseScoreBorder(string(args[1])) - if err != nil { - return reply.MakeErrReply(err.Error()) - } - - max, err := SortedSet.ParseScoreBorder(string(args[2])) - if err != nil { - return reply.MakeErrReply(err.Error()) - } - - db.Locker.Lock(key) - defer db.Locker.UnLock(key) - - // get data - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return &reply.EmptyMultiBulkReply{} - } - - removed := sortedSet.RemoveByScore(min, max) - if removed > 0 { - db.AddAof(makeAofCmd("zremrangebyscore", args)) - } - return reply.MakeIntReply(removed) + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zremrangebyscore' command") + } + key := string(args[0]) + + min, err := SortedSet.ParseScoreBorder(string(args[1])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + max, err := SortedSet.ParseScoreBorder(string(args[2])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + db.Locker.Lock(key) + defer db.Locker.UnLock(key) + + // get data + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return &reply.EmptyMultiBulkReply{} + } + + removed := sortedSet.RemoveByScore(min, max) + if removed > 0 { + db.AddAof(makeAofCmd("zremrangebyscore", args)) + } + return reply.MakeIntReply(removed) } func ZRemRangeByRank(db *DB, args [][]byte) redis.Reply { - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zremrangebyrank' command") - } - key := string(args[0]) - start, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - stop, err := strconv.ParseInt(string(args[2]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - - db.Locker.Lock(key) - defer db.Locker.UnLock(key) - - // get data - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return reply.MakeIntReply(0) - } - - // compute index - size := sortedSet.Len() // assert: size > 0 - if start < -1 * size { - start = 0 - } else if start < 0 { - start = size + start - } else if start >= size { - return reply.MakeIntReply(0) - } - if stop < -1 * size { - stop = 0 - } else if stop < 0 { - stop = size + stop + 1 - } else if stop < size { - stop = stop + 1 - } else { - stop = size - } - if stop < start { - stop = start - } - - // assert: start in [0, size - 1], stop in [start, size] - removed := sortedSet.RemoveByRank(start, stop) - if removed > 0 { - db.AddAof(makeAofCmd("zremrangebyrank", args)) - } - return reply.MakeIntReply(removed) + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zremrangebyrank' command") + } + key := string(args[0]) + start, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + stop, err := strconv.ParseInt(string(args[2]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + + db.Locker.Lock(key) + defer db.Locker.UnLock(key) + + // get data + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return reply.MakeIntReply(0) + } + + // compute index + size := sortedSet.Len() // assert: size > 0 + if start < -1*size { + start = 0 + } else if start < 0 { + start = size + start + } else if start >= size { + return reply.MakeIntReply(0) + } + if stop < -1*size { + stop = 0 + } else if stop < 0 { + stop = size + stop + 1 + } else if stop < size { + stop = stop + 1 + } else { + stop = size + } + if stop < start { + stop = start + } + + // assert: start in [0, size - 1], stop in [start, size] + removed := sortedSet.RemoveByRank(start, stop) + if removed > 0 { + db.AddAof(makeAofCmd("zremrangebyrank", args)) + } + return reply.MakeIntReply(removed) } func ZRem(db *DB, args [][]byte) redis.Reply { - // parse args - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zrem' command") - } - key := string(args[0]) - fields := make([]string, len(args)-1) - fieldArgs := args[1:] - for i, v := range fieldArgs { - fields[i] = string(v) - } - - db.Lock(key) - defer db.UnLock(key) - - // get entity - sortedSet, errReply := db.getAsSortedSet(key) - if errReply != nil { - return errReply - } - if sortedSet == nil { - return reply.MakeIntReply(0) - } - - var deleted int64 = 0 - for _, field := range fields { - if sortedSet.Remove(field) { - deleted++ - } - } - if deleted > 0 { - db.AddAof(makeAofCmd("zrem", args)) - } - return reply.MakeIntReply(deleted) + // parse args + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrem' command") + } + key := string(args[0]) + fields := make([]string, len(args)-1) + fieldArgs := args[1:] + for i, v := range fieldArgs { + fields[i] = string(v) + } + + db.Lock(key) + defer db.UnLock(key) + + // get entity + sortedSet, errReply := db.getAsSortedSet(key) + if errReply != nil { + return errReply + } + if sortedSet == nil { + return reply.MakeIntReply(0) + } + + var deleted int64 = 0 + for _, field := range fields { + if sortedSet.Remove(field) { + deleted++ + } + } + if deleted > 0 { + db.AddAof(makeAofCmd("zrem", args)) + } + return reply.MakeIntReply(deleted) } func ZIncrBy(db *DB, args [][]byte) redis.Reply { - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'zincrby' command") - } - key := string(args[0]) - rawDelta := string(args[1]) - field := string(args[2]) - delta, err := strconv.ParseFloat(rawDelta, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not a valid float") - } - - db.Locker.Lock(key) - defer db.Locker.UnLock(key) - - // get or init entity - sortedSet, _, errReply := db.getOrInitSortedSet(key) - if errReply != nil { - return errReply - } - - element, exists := sortedSet.Get(field) - if !exists { - sortedSet.Add(field, delta) - db.AddAof(makeAofCmd("zincrby", args)) - return reply.MakeBulkReply(args[1]) - } else { - score := element.Score + delta - sortedSet.Add(field, score) - bytes := []byte(strconv.FormatFloat(score, 'f', -1, 64)) - db.AddAof(makeAofCmd("zincrby", args)) - return reply.MakeBulkReply(bytes) - } -} \ No newline at end of file + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zincrby' command") + } + key := string(args[0]) + rawDelta := string(args[1]) + field := string(args[2]) + delta, err := strconv.ParseFloat(rawDelta, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not a valid float") + } + + db.Locker.Lock(key) + defer db.Locker.UnLock(key) + + // get or init entity + sortedSet, _, errReply := db.getOrInitSortedSet(key) + if errReply != nil { + return errReply + } + + element, exists := sortedSet.Get(field) + if !exists { + sortedSet.Add(field, delta) + db.AddAof(makeAofCmd("zincrby", args)) + return reply.MakeBulkReply(args[1]) + } else { + score := element.Score + delta + sortedSet.Add(field, score) + bytes := []byte(strconv.FormatFloat(score, 'f', -1, 64)) + db.AddAof(makeAofCmd("zincrby", args)) + return reply.MakeBulkReply(bytes) + } +} diff --git a/src/interface/db/db.go b/src/interface/db/db.go index ed6853c9..3334f1c5 100644 --- a/src/interface/db/db.go +++ b/src/interface/db/db.go @@ -3,7 +3,7 @@ package db import "github.com/HDT3213/godis/src/interface/redis" type DB interface { - Exec(client redis.Connection, args [][]byte) redis.Reply - AfterClientClose(c redis.Connection) - Close() + Exec(client redis.Connection, args [][]byte) redis.Reply + AfterClientClose(c redis.Connection) + Close() } diff --git a/src/interface/redis/client.go b/src/interface/redis/client.go index 44415166..d6f70d59 100644 --- a/src/interface/redis/client.go +++ b/src/interface/redis/client.go @@ -1,11 +1,11 @@ package redis type Connection interface { - Write([]byte) error + Write([]byte) error - // client should keep its subscribing channels - SubsChannel(channel string) - UnSubsChannel(channel string) - SubsCount()int - GetChannels()[]string + // client should keep its subscribing channels + SubsChannel(channel string) + UnSubsChannel(channel string) + SubsCount() int + GetChannels() []string } diff --git a/src/interface/redis/reply.go b/src/interface/redis/reply.go index e1dea806..bd7888d6 100644 --- a/src/interface/redis/reply.go +++ b/src/interface/redis/reply.go @@ -1,5 +1,5 @@ package redis type Reply interface { - ToBytes()[]byte + ToBytes() []byte } diff --git a/src/interface/tcp/handler.go b/src/interface/tcp/handler.go index ff0463b6..cd2885a4 100644 --- a/src/interface/tcp/handler.go +++ b/src/interface/tcp/handler.go @@ -1,13 +1,13 @@ package tcp import ( - "net" - "context" + "context" + "net" ) type HandleFunc func(ctx context.Context, conn net.Conn) type Handler interface { - Handle(ctx context.Context, conn net.Conn) - Close()error -} \ No newline at end of file + Handle(ctx context.Context, conn net.Conn) + Close() error +} diff --git a/src/lib/consistenthash/consistenthash.go b/src/lib/consistenthash/consistenthash.go index fc9c345b..86675b75 100644 --- a/src/lib/consistenthash/consistenthash.go +++ b/src/lib/consistenthash/consistenthash.go @@ -1,80 +1,80 @@ package consistenthash import ( - "hash/crc32" - "sort" - "strconv" - "strings" + "hash/crc32" + "sort" + "strconv" + "strings" ) type HashFunc func(data []byte) uint32 type Map struct { - hashFunc HashFunc - replicas int - keys []int // sorted - hashMap map[int]string + hashFunc HashFunc + replicas int + keys []int // sorted + hashMap map[int]string } func New(replicas int, fn HashFunc) *Map { - m := &Map{ - replicas: replicas, - hashFunc: fn, - hashMap: make(map[int]string), - } - if m.hashFunc == nil { - m.hashFunc = crc32.ChecksumIEEE - } - return m + m := &Map{ + replicas: replicas, + hashFunc: fn, + hashMap: make(map[int]string), + } + if m.hashFunc == nil { + m.hashFunc = crc32.ChecksumIEEE + } + return m } func (m *Map) IsEmpty() bool { - return len(m.keys) == 0 + return len(m.keys) == 0 } func (m *Map) Add(keys ...string) { - for _, key := range keys { - if key == "" { - continue - } - for i := 0; i < m.replicas; i++ { - hash := int(m.hashFunc([]byte(strconv.Itoa(i) + key))) - m.keys = append(m.keys, hash) - m.hashMap[hash] = key - } - } - sort.Ints(m.keys) + for _, key := range keys { + if key == "" { + continue + } + for i := 0; i < m.replicas; i++ { + hash := int(m.hashFunc([]byte(strconv.Itoa(i) + key))) + m.keys = append(m.keys, hash) + m.hashMap[hash] = key + } + } + sort.Ints(m.keys) } // support hash tag func getPartitionKey(key string) string { - beg := strings.Index(key, "{") - if beg == -1 { - return key - } - end := strings.Index(key, "}") - if end == -1 || end == beg+1 { - return key - } - return key[beg+1 : end] + beg := strings.Index(key, "{") + if beg == -1 { + return key + } + end := strings.Index(key, "}") + if end == -1 || end == beg+1 { + return key + } + return key[beg+1 : end] } // Get gets the closest item in the hash to the provided key. func (m *Map) Get(key string) string { - if m.IsEmpty() { - return "" - } + if m.IsEmpty() { + return "" + } - partitionKey := getPartitionKey(key) - hash := int(m.hashFunc([]byte(partitionKey))) + partitionKey := getPartitionKey(key) + hash := int(m.hashFunc([]byte(partitionKey))) - // Binary search for appropriate replica. - idx := sort.Search(len(m.keys), func(i int) bool { return m.keys[i] >= hash }) + // Binary search for appropriate replica. + idx := sort.Search(len(m.keys), func(i int) bool { return m.keys[i] >= hash }) - // Means we have cycled back to the first replica. - if idx == len(m.keys) { - idx = 0 - } + // Means we have cycled back to the first replica. + if idx == len(m.keys) { + idx = 0 + } - return m.hashMap[m.keys[idx]] + return m.hashMap[m.keys[idx]] } diff --git a/src/lib/files/files.go b/src/lib/files/files.go index 20c705b8..9cbb5dbe 100644 --- a/src/lib/files/files.go +++ b/src/lib/files/files.go @@ -1,78 +1,78 @@ package files import ( - "mime/multipart" - "io/ioutil" - "path" - "os" - "fmt" + "fmt" + "io/ioutil" + "mime/multipart" + "os" + "path" ) func GetSize(f multipart.File) (int, error) { - content, err := ioutil.ReadAll(f) + content, err := ioutil.ReadAll(f) - return len(content), err + return len(content), err } func GetExt(fileName string) string { - return path.Ext(fileName) + return path.Ext(fileName) } func CheckNotExist(src string) bool { - _, err := os.Stat(src) + _, err := os.Stat(src) - return os.IsNotExist(err) + return os.IsNotExist(err) } func CheckPermission(src string) bool { - _, err := os.Stat(src) + _, err := os.Stat(src) - return os.IsPermission(err) + return os.IsPermission(err) } func IsNotExistMkDir(src string) error { - if notExist := CheckNotExist(src); notExist == true { - if err := MkDir(src); err != nil { - return err - } - } + if notExist := CheckNotExist(src); notExist == true { + if err := MkDir(src); err != nil { + return err + } + } - return nil + return nil } func MkDir(src string) error { - err := os.MkdirAll(src, os.ModePerm) - if err != nil { - return err - } + err := os.MkdirAll(src, os.ModePerm) + if err != nil { + return err + } - return nil + return nil } func Open(name string, flag int, perm os.FileMode) (*os.File, error) { - f, err := os.OpenFile(name, flag, perm) - if err != nil { - return nil, err - } + f, err := os.OpenFile(name, flag, perm) + if err != nil { + return nil, err + } - return f, nil + return f, nil } func MustOpen(fileName, dir string) (*os.File, error) { - perm := CheckPermission(dir) - if perm == true { - return nil, fmt.Errorf("permission denied dir: %s", dir) - } - - err := IsNotExistMkDir(dir) - if err != nil { - return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err) - } - - f, err := Open(dir + string(os.PathSeparator) + fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644) - if err != nil { - return nil, fmt.Errorf("fail to open file, err: %s", err) - } - - return f, nil + perm := CheckPermission(dir) + if perm == true { + return nil, fmt.Errorf("permission denied dir: %s", dir) + } + + err := IsNotExistMkDir(dir) + if err != nil { + return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err) + } + + f, err := Open(dir+string(os.PathSeparator)+fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644) + if err != nil { + return nil, fmt.Errorf("fail to open file, err: %s", err) + } + + return f, nil } diff --git a/src/lib/geohash/geohash.go b/src/lib/geohash/geohash.go index a5833379..647756a0 100644 --- a/src/lib/geohash/geohash.go +++ b/src/lib/geohash/geohash.go @@ -1,9 +1,9 @@ package geohash import ( - "bytes" - "encoding/base32" - "encoding/binary" + "bytes" + "encoding/base32" + "encoding/binary" ) var bits = []uint8{128, 64, 32, 16, 8, 4, 2, 1} @@ -13,95 +13,95 @@ const defaultBitSize = 64 // 32 bits for latitude, another 32 bits for longitude // return: geohash, box func encode0(latitude, longitude float64, bitSize uint) ([]byte, [2][2]float64) { - box := [2][2]float64{ - {-180, 180}, // lng - {-90, 90}, // lat - } - pos := [2]float64{longitude, latitude} - hash := &bytes.Buffer{} - bit := 0 - var precision uint = 0 - code := uint8(0) - for precision < bitSize { - for direction, val := range pos { - mid := (box[direction][0] + box[direction][1]) / 2 - if val < mid { - box[direction][1] = mid - } else { - box[direction][0] = mid - code |= bits[bit] - } - bit++ - if bit == 8 { - hash.WriteByte(code) - bit = 0 - code = 0 - } - precision++ - if precision == bitSize { - break - } - } - } - // precision%8 > 0 - if code > 0 { - hash.WriteByte(code) - } - return hash.Bytes(), box + box := [2][2]float64{ + {-180, 180}, // lng + {-90, 90}, // lat + } + pos := [2]float64{longitude, latitude} + hash := &bytes.Buffer{} + bit := 0 + var precision uint = 0 + code := uint8(0) + for precision < bitSize { + for direction, val := range pos { + mid := (box[direction][0] + box[direction][1]) / 2 + if val < mid { + box[direction][1] = mid + } else { + box[direction][0] = mid + code |= bits[bit] + } + bit++ + if bit == 8 { + hash.WriteByte(code) + bit = 0 + code = 0 + } + precision++ + if precision == bitSize { + break + } + } + } + // precision%8 > 0 + if code > 0 { + hash.WriteByte(code) + } + return hash.Bytes(), box } func Encode(latitude, longitude float64) uint64 { - buf, _ := encode0(latitude, longitude, defaultBitSize) - return binary.BigEndian.Uint64(buf) + buf, _ := encode0(latitude, longitude, defaultBitSize) + return binary.BigEndian.Uint64(buf) } func decode0(hash []byte) [][]float64 { - box := [][]float64{ - {-180, 180}, - {-90, 90}, - } - direction := 0 - for i := 0; i < len(hash); i++ { - code := hash[i] - for j := 0; j < len(bits); j++ { - mid := (box[direction][0] + box[direction][1]) / 2 - mask := bits[j] - if mask&code > 0 { - box[direction][0] = mid - } else { - box[direction][1] = mid - } - direction = (direction + 1) % 2 - } - } - return box + box := [][]float64{ + {-180, 180}, + {-90, 90}, + } + direction := 0 + for i := 0; i < len(hash); i++ { + code := hash[i] + for j := 0; j < len(bits); j++ { + mid := (box[direction][0] + box[direction][1]) / 2 + mask := bits[j] + if mask&code > 0 { + box[direction][0] = mid + } else { + box[direction][1] = mid + } + direction = (direction + 1) % 2 + } + } + return box } func Decode(code uint64) (float64, float64) { - buf := make([]byte, 8) - binary.BigEndian.PutUint64(buf, code) - box := decode0(buf) - lng := float64(box[0][0]+box[0][1]) / 2 - lat := float64(box[1][0]+box[1][1]) / 2 - return lat, lng + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, code) + box := decode0(buf) + lng := float64(box[0][0]+box[0][1]) / 2 + lat := float64(box[1][0]+box[1][1]) / 2 + return lat, lng } func ToString(buf []byte) string { - return enc.EncodeToString(buf) + return enc.EncodeToString(buf) } func ToInt(buf []byte) uint64 { - // padding - if len(buf) < 8 { - buf2 := make([]byte, 8) - copy(buf2, buf) - return binary.BigEndian.Uint64(buf2) - } - return binary.BigEndian.Uint64(buf) + // padding + if len(buf) < 8 { + buf2 := make([]byte, 8) + copy(buf2, buf) + return binary.BigEndian.Uint64(buf2) + } + return binary.BigEndian.Uint64(buf) } func FromInt(code uint64) []byte { - buf := make([]byte, 8) - binary.BigEndian.PutUint64(buf, code) - return buf -} \ No newline at end of file + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, code) + return buf +} diff --git a/src/lib/geohash/geohash_test.go b/src/lib/geohash/geohash_test.go index 1046d69c..166d0206 100644 --- a/src/lib/geohash/geohash_test.go +++ b/src/lib/geohash/geohash_test.go @@ -1,39 +1,39 @@ package geohash import ( - "fmt" - "math" - "testing" + "fmt" + "math" + "testing" ) func TestToRange(t *testing.T) { - neighbor := []byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00} - range_ := ToRange(neighbor, 36) - expectedLower := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00}) - expectedUpper := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xF0, 0x00, 0x00, 0x00}) - if expectedLower != range_[0] { - t.Error("incorrect lower") - } - if expectedUpper != range_[1] { - t.Error("incorrect upper") - } + neighbor := []byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00} + range_ := ToRange(neighbor, 36) + expectedLower := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00}) + expectedUpper := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xF0, 0x00, 0x00, 0x00}) + if expectedLower != range_[0] { + t.Error("incorrect lower") + } + if expectedUpper != range_[1] { + t.Error("incorrect upper") + } } func TestEncode(t *testing.T) { - lat0 := 48.669 - lng0 := -4.32913 - hash := Encode(lat0, lng0) - str := ToString(FromInt(hash)) - if str != "gbsuv7zt7zntw" { - t.Error("encode error") - } - lat, lng := Decode(hash) - if math.Abs(lat-lat0) > 1e-6 || math.Abs(lng-lng0) > 1e-6 { - t.Error("decode error") - } + lat0 := 48.669 + lng0 := -4.32913 + hash := Encode(lat0, lng0) + str := ToString(FromInt(hash)) + if str != "gbsuv7zt7zntw" { + t.Error("encode error") + } + lat, lng := Decode(hash) + if math.Abs(lat-lat0) > 1e-6 || math.Abs(lng-lng0) > 1e-6 { + t.Error("decode error") + } } func TestGetNeighbours(t *testing.T) { - ranges := GetNeighbours(90, 180, 630*1000) - fmt.Printf("%#v", ranges) -} \ No newline at end of file + ranges := GetNeighbours(90, 180, 630*1000) + fmt.Printf("%#v", ranges) +} diff --git a/src/lib/geohash/neighbor.go b/src/lib/geohash/neighbor.go index f5e799f6..bbc4336d 100644 --- a/src/lib/geohash/neighbor.go +++ b/src/lib/geohash/neighbor.go @@ -3,134 +3,134 @@ package geohash import "math" const ( - DR = math.Pi / 180.0 - EarthRadius = 6372797.560856 - MercatorMax = 20037726.37 // pi * EarthRadius - MercatorMin = -20037726.37 + DR = math.Pi / 180.0 + EarthRadius = 6372797.560856 + MercatorMax = 20037726.37 // pi * EarthRadius + MercatorMin = -20037726.37 ) func degRad(ang float64) float64 { - return ang * DR + return ang * DR } func radDeg(ang float64) float64 { - return ang / DR + return ang / DR } func getBoundingBox(latitude float64, longitude float64, radiusMeters float64) ( - minLat, maxLat, minLng, maxLng float64) { - minLng = longitude - radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude))) - if minLng < -180 { - minLng = -180 - } - maxLng = longitude + radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude))) - if maxLng > 180 { - maxLng = 180 - } - minLat = latitude - radDeg(radiusMeters/EarthRadius) - if minLat < -90 { - minLat = -90 - } - maxLat = latitude + radDeg(radiusMeters/EarthRadius) - if maxLat > 90 { - maxLat = 90 - } - return + minLat, maxLat, minLng, maxLng float64) { + minLng = longitude - radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude))) + if minLng < -180 { + minLng = -180 + } + maxLng = longitude + radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude))) + if maxLng > 180 { + maxLng = 180 + } + minLat = latitude - radDeg(radiusMeters/EarthRadius) + if minLat < -90 { + minLat = -90 + } + maxLat = latitude + radDeg(radiusMeters/EarthRadius) + if maxLat > 90 { + maxLat = 90 + } + return } func estimatePrecisionByRadius(radiusMeters float64, latitude float64) uint { - if radiusMeters == 0 { - return defaultBitSize - 1 - } - var precision uint = 1 - for radiusMeters < MercatorMax { - radiusMeters *= 2 - precision++ - } - /* Make sure range is included in most of the base cases. */ - precision -= 2 - if latitude > 66 || latitude < -66 { - precision-- - if latitude > 80 || latitude < -80 { - precision-- - } - } - if precision < 1 { - precision = 1 - } - if precision > 32 { - precision = 32 - } - return precision*2 - 1 + if radiusMeters == 0 { + return defaultBitSize - 1 + } + var precision uint = 1 + for radiusMeters < MercatorMax { + radiusMeters *= 2 + precision++ + } + /* Make sure range is included in most of the base cases. */ + precision -= 2 + if latitude > 66 || latitude < -66 { + precision-- + if latitude > 80 || latitude < -80 { + precision-- + } + } + if precision < 1 { + precision = 1 + } + if precision > 32 { + precision = 32 + } + return precision*2 - 1 } func Distance(latitude1, longitude1, latitude2, longitude2 float64) float64 { - radLat1 := degRad(latitude1) - radLat2 := degRad(latitude2) - a := radLat1 - radLat2 - b := degRad(longitude1) - degRad(longitude2) - return 2 * EarthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2) + - math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2))) + radLat1 := degRad(latitude1) + radLat2 := degRad(latitude2) + a := radLat1 - radLat2 + b := degRad(longitude1) - degRad(longitude2) + return 2 * EarthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2)+ + math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2))) } func ToRange(scope []byte, precision uint) [2]uint64 { - lower := ToInt(scope) - radius := uint64(1 << (64 - precision)) - upper := lower + radius - return [2]uint64{lower, upper} + lower := ToInt(scope) + radius := uint64(1 << (64 - precision)) + upper := lower + radius + return [2]uint64{lower, upper} } func ensureValidLat(lat float64) float64 { - if lat > 90 { - return 90 - } - if lat < -90 { - return -90 - } - return lat + if lat > 90 { + return 90 + } + if lat < -90 { + return -90 + } + return lat } func ensureValidLng(lng float64) float64 { - if lng > 180 { - return -360 + lng - } - if lng < -180 { - return 360 + lng - } - return lng + if lng > 180 { + return -360 + lng + } + if lng < -180 { + return 360 + lng + } + return lng } func GetNeighbours(latitude, longitude, radiusMeters float64) [][2]uint64 { - precision := estimatePrecisionByRadius(radiusMeters, latitude) + precision := estimatePrecisionByRadius(radiusMeters, latitude) - center, box := encode0(latitude, longitude, precision) - height := box[0][1] - box[0][0] - width := box[1][1] - box[1][0] - centerLng := (box[0][1] + box[0][0]) / 2 - centerLat := (box[1][1] + box[1][0]) / 2 - maxLat := ensureValidLat(centerLat + height) - minLat := ensureValidLat(centerLat - height) - maxLng := ensureValidLng(centerLng + width) - minLng := ensureValidLng(centerLng - width) + center, box := encode0(latitude, longitude, precision) + height := box[0][1] - box[0][0] + width := box[1][1] - box[1][0] + centerLng := (box[0][1] + box[0][0]) / 2 + centerLat := (box[1][1] + box[1][0]) / 2 + maxLat := ensureValidLat(centerLat + height) + minLat := ensureValidLat(centerLat - height) + maxLng := ensureValidLng(centerLng + width) + minLng := ensureValidLng(centerLng - width) - var result [10][2]uint64 - leftUpper, _ := encode0(maxLat, minLng, precision) - result[1] = ToRange(leftUpper, precision) - upper, _ := encode0(maxLat, centerLng, precision) - result[2] = ToRange(upper, precision) - rightUpper, _ := encode0(maxLat, maxLng, precision) - result[3] = ToRange(rightUpper, precision) - left, _ := encode0(centerLat, minLng, precision) - result[4] = ToRange(left, precision) - result[5] = ToRange(center, precision) - right, _ := encode0(centerLat, maxLng, precision) - result[6] = ToRange(right, precision) - leftDown, _ := encode0(minLat, minLng, precision) - result[7] = ToRange(leftDown, precision) - down, _ := encode0(minLat, centerLng, precision) - result[8] = ToRange(down, precision) - rightDown, _ := encode0(minLat, maxLng, precision) - result[9] = ToRange(rightDown, precision) + var result [10][2]uint64 + leftUpper, _ := encode0(maxLat, minLng, precision) + result[1] = ToRange(leftUpper, precision) + upper, _ := encode0(maxLat, centerLng, precision) + result[2] = ToRange(upper, precision) + rightUpper, _ := encode0(maxLat, maxLng, precision) + result[3] = ToRange(rightUpper, precision) + left, _ := encode0(centerLat, minLng, precision) + result[4] = ToRange(left, precision) + result[5] = ToRange(center, precision) + right, _ := encode0(centerLat, maxLng, precision) + result[6] = ToRange(right, precision) + leftDown, _ := encode0(minLat, minLng, precision) + result[7] = ToRange(leftDown, precision) + down, _ := encode0(minLat, centerLng, precision) + result[8] = ToRange(down, precision) + rightDown, _ := encode0(minLat, maxLng, precision) + result[9] = ToRange(rightDown, precision) - return result[1:] + return result[1:] } diff --git a/src/lib/marshal/gob/gob.go b/src/lib/marshal/gob/gob.go index 124169c8..9ff67975 100644 --- a/src/lib/marshal/gob/gob.go +++ b/src/lib/marshal/gob/gob.go @@ -1,21 +1,21 @@ package gob import ( - "bytes" - "encoding/gob" + "bytes" + "encoding/gob" ) func Marshal(obj interface{}) ([]byte, error) { - buf := new(bytes.Buffer) - enc := gob.NewEncoder(buf) - err := enc.Encode(obj) - if err != nil { - return nil, err - } - return buf.Bytes(), nil + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + err := enc.Encode(obj) + if err != nil { + return nil, err + } + return buf.Bytes(), nil } func UnMarshal(data []byte, obj interface{}) error { - dec := gob.NewDecoder(bytes.NewBuffer(data)) - return dec.Decode(obj) + dec := gob.NewDecoder(bytes.NewBuffer(data)) + return dec.Decode(obj) } diff --git a/src/lib/sync/atomic/bool.go b/src/lib/sync/atomic/bool.go index 57f34720..4f72ace3 100644 --- a/src/lib/sync/atomic/bool.go +++ b/src/lib/sync/atomic/bool.go @@ -4,16 +4,14 @@ import "sync/atomic" type AtomicBool uint32 -func (b *AtomicBool)Get()bool { - return atomic.LoadUint32((*uint32)(b)) != 0 +func (b *AtomicBool) Get() bool { + return atomic.LoadUint32((*uint32)(b)) != 0 } -func (b *AtomicBool)Set(v bool) { - if v { - atomic.StoreUint32((*uint32)(b), 1) - } else { - atomic.StoreUint32((*uint32)(b), 0) - } +func (b *AtomicBool) Set(v bool) { + if v { + atomic.StoreUint32((*uint32)(b), 1) + } else { + atomic.StoreUint32((*uint32)(b), 0) + } } - - diff --git a/src/lib/sync/wait/wait.go b/src/lib/sync/wait/wait.go index 3abca816..99669f32 100644 --- a/src/lib/sync/wait/wait.go +++ b/src/lib/sync/wait/wait.go @@ -1,38 +1,38 @@ package wait import ( - "sync" - "time" + "sync" + "time" ) type Wait struct { - wg sync.WaitGroup + wg sync.WaitGroup } -func (w *Wait)Add(delta int) { - w.wg.Add(delta) +func (w *Wait) Add(delta int) { + w.wg.Add(delta) } -func (w *Wait)Done() { - w.wg.Done() +func (w *Wait) Done() { + w.wg.Done() } -func (w *Wait)Wait() { - w.wg.Wait() +func (w *Wait) Wait() { + w.wg.Wait() } // return isTimeout -func (w *Wait)WaitWithTimeout(timeout time.Duration)bool { - c := make(chan bool) - go func() { - defer close(c) - w.wg.Wait() - c <- true - }() - select { - case <-c: - return false // completed normally - case <-time.After(timeout): - return true // timed out - } -} \ No newline at end of file +func (w *Wait) WaitWithTimeout(timeout time.Duration) bool { + c := make(chan bool) + go func() { + defer close(c) + w.wg.Wait() + c <- true + }() + select { + case <-c: + return false // completed normally + case <-time.After(timeout): + return true // timed out + } +} diff --git a/src/lib/wildcard/wildcard.go b/src/lib/wildcard/wildcard.go index ef9193ab..64b30135 100644 --- a/src/lib/wildcard/wildcard.go +++ b/src/lib/wildcard/wildcard.go @@ -1,95 +1,95 @@ package wildcard const ( - normal = iota - all // * - any // ? - set_ // [] + normal = iota + all // * + any // ? + set_ // [] ) type item struct { - character byte - set map[byte]bool - typeCode int + character byte + set map[byte]bool + typeCode int } func (i *item) contains(c byte) bool { - _, ok := i.set[c] - return ok + _, ok := i.set[c] + return ok } type Pattern struct { - items []*item + items []*item } func CompilePattern(src string) *Pattern { - items := make([]*item, 0) - escape := false - inSet := false - var set map[byte]bool - for _, v := range src { - c := byte(v) - if escape { - items = append(items, &item{typeCode: normal, character: c}) - escape = false - } else if c == '*' { - items = append(items, &item{typeCode: all}) - } else if c == '?' { - items = append(items, &item{typeCode: any}) - } else if c == '\\' { - escape = true - } else if c == '[' { - if !inSet { - inSet = true - set = make(map[byte]bool) - } else { - set[c] = true - } - } else if c == ']' { - if inSet { - inSet = false - items = append(items, &item{typeCode: set_, set: set}) - } else { - items = append(items, &item{typeCode: normal, character: c}) - } - } else { - if inSet { - set[c] = true - } else { - items = append(items, &item{typeCode: normal, character: c}) - } - } - } - return &Pattern{ - items: items, - } + items := make([]*item, 0) + escape := false + inSet := false + var set map[byte]bool + for _, v := range src { + c := byte(v) + if escape { + items = append(items, &item{typeCode: normal, character: c}) + escape = false + } else if c == '*' { + items = append(items, &item{typeCode: all}) + } else if c == '?' { + items = append(items, &item{typeCode: any}) + } else if c == '\\' { + escape = true + } else if c == '[' { + if !inSet { + inSet = true + set = make(map[byte]bool) + } else { + set[c] = true + } + } else if c == ']' { + if inSet { + inSet = false + items = append(items, &item{typeCode: set_, set: set}) + } else { + items = append(items, &item{typeCode: normal, character: c}) + } + } else { + if inSet { + set[c] = true + } else { + items = append(items, &item{typeCode: normal, character: c}) + } + } + } + return &Pattern{ + items: items, + } } func (p *Pattern) IsMatch(s string) bool { - if len(p.items) == 0 { - return len(s) == 0 - } - m := len(s) - n := len(p.items) - table := make([][]bool, m+1) - for i := 0; i < m+1; i++ { - table[i] = make([]bool, n+1) - } - table[0][0] = true - for j := 1; j < n+1; j++ { - table[0][j] = table[0][j-1] && p.items[j-1].typeCode == all - } - for i := 1; i < m+1; i++ { - for j := 1; j < n+1; j++ { - if p.items[j-1].typeCode == all { - table[i][j] = table[i-1][j] || table[i][j-1] - } else { - table[i][j] = table[i-1][j-1] && - (p.items[j-1].typeCode == any || - (p.items[j-1].typeCode == normal && uint8(s[i-1]) == p.items[j-1].character) || - (p.items[j-1].typeCode == set_ && p.items[j-1].contains(s[i-1]))) - } - } - } - return table[m][n] + if len(p.items) == 0 { + return len(s) == 0 + } + m := len(s) + n := len(p.items) + table := make([][]bool, m+1) + for i := 0; i < m+1; i++ { + table[i] = make([]bool, n+1) + } + table[0][0] = true + for j := 1; j < n+1; j++ { + table[0][j] = table[0][j-1] && p.items[j-1].typeCode == all + } + for i := 1; i < m+1; i++ { + for j := 1; j < n+1; j++ { + if p.items[j-1].typeCode == all { + table[i][j] = table[i-1][j] || table[i][j-1] + } else { + table[i][j] = table[i-1][j-1] && + (p.items[j-1].typeCode == any || + (p.items[j-1].typeCode == normal && uint8(s[i-1]) == p.items[j-1].character) || + (p.items[j-1].typeCode == set_ && p.items[j-1].contains(s[i-1]))) + } + } + } + return table[m][n] } diff --git a/src/lib/wildcard/wildcard_test.go b/src/lib/wildcard/wildcard_test.go index f696fe6c..190840a2 100644 --- a/src/lib/wildcard/wildcard_test.go +++ b/src/lib/wildcard/wildcard_test.go @@ -3,73 +3,73 @@ package wildcard import "testing" func TestWildCard(t *testing.T) { - p := CompilePattern("a") - if !p.IsMatch("a") { - t.Error("expect true actually false") - } - if p.IsMatch("b") { - t.Error("expect false actually true") - } + p := CompilePattern("a") + if !p.IsMatch("a") { + t.Error("expect true actually false") + } + if p.IsMatch("b") { + t.Error("expect false actually true") + } - // test '?' - p = CompilePattern("a?") - if !p.IsMatch("ab") { - t.Error("expect true actually false") - } - if p.IsMatch("a") { - t.Error("expect false actually true") - } - if p.IsMatch("abb") { - t.Error("expect false actually true") - } - if p.IsMatch("bb") { - t.Error("expect false actually true") - } + // test '?' + p = CompilePattern("a?") + if !p.IsMatch("ab") { + t.Error("expect true actually false") + } + if p.IsMatch("a") { + t.Error("expect false actually true") + } + if p.IsMatch("abb") { + t.Error("expect false actually true") + } + if p.IsMatch("bb") { + t.Error("expect false actually true") + } - // test * - p = CompilePattern("a*") - if !p.IsMatch("ab") { - t.Error("expect true actually false") - } - if !p.IsMatch("a") { - t.Error("expect true actually false") - } - if !p.IsMatch("abb") { - t.Error("expect true actually false") - } - if p.IsMatch("bb") { - t.Error("expect false actually true") - } + // test * + p = CompilePattern("a*") + if !p.IsMatch("ab") { + t.Error("expect true actually false") + } + if !p.IsMatch("a") { + t.Error("expect true actually false") + } + if !p.IsMatch("abb") { + t.Error("expect true actually false") + } + if p.IsMatch("bb") { + t.Error("expect false actually true") + } - // test [] - p = CompilePattern("a[ab[]") - if !p.IsMatch("ab") { - t.Error("expect true actually false") - } - if !p.IsMatch("aa") { - t.Error("expect true actually false") - } - if !p.IsMatch("a[") { - t.Error("expect true actually false") - } - if p.IsMatch("abb") { - t.Error("expect false actually true") - } - if p.IsMatch("bb") { - t.Error("expect false actually true") - } + // test [] + p = CompilePattern("a[ab[]") + if !p.IsMatch("ab") { + t.Error("expect true actually false") + } + if !p.IsMatch("aa") { + t.Error("expect true actually false") + } + if !p.IsMatch("a[") { + t.Error("expect true actually false") + } + if p.IsMatch("abb") { + t.Error("expect false actually true") + } + if p.IsMatch("bb") { + t.Error("expect false actually true") + } - // test escape - p = CompilePattern("\\\\") // pattern: \\ - if !p.IsMatch("\\") { - t.Error("expect true actually false") - } + // test escape + p = CompilePattern("\\\\") // pattern: \\ + if !p.IsMatch("\\") { + t.Error("expect true actually false") + } - p = CompilePattern("\\*") - if !p.IsMatch("*") { - t.Error("expect true actually false") - } - if p.IsMatch("a") { - t.Error("expect false actually true") - } + p = CompilePattern("\\*") + if !p.IsMatch("*") { + t.Error("expect true actually false") + } + if p.IsMatch("a") { + t.Error("expect false actually true") + } } diff --git a/src/pubsub/hub.go b/src/pubsub/hub.go index 9265c420..f6fdc95d 100644 --- a/src/pubsub/hub.go +++ b/src/pubsub/hub.go @@ -1,20 +1,20 @@ package pubsub import ( - "github.com/HDT3213/godis/src/datastruct/dict" - "github.com/HDT3213/godis/src/datastruct/lock" + "github.com/HDT3213/godis/src/datastruct/dict" + "github.com/HDT3213/godis/src/datastruct/lock" ) type Hub struct { - // channel -> list(*Client) - subs dict.Dict - // lock channel - subsLocker *lock.Locks + // channel -> list(*Client) + subs dict.Dict + // lock channel + subsLocker *lock.Locks } func MakeHub() *Hub { - return &Hub{ - subs: dict.MakeConcurrent(4), - subsLocker: lock.Make(16), - } + return &Hub{ + subs: dict.MakeConcurrent(4), + subsLocker: lock.Make(16), + } } diff --git a/src/pubsub/pubsub.go b/src/pubsub/pubsub.go index 3c51746f..3de5014c 100644 --- a/src/pubsub/pubsub.go +++ b/src/pubsub/pubsub.go @@ -1,23 +1,23 @@ package pubsub import ( - "github.com/HDT3213/godis/src/datastruct/list" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" - "strconv" + "github.com/HDT3213/godis/src/datastruct/list" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" + "strconv" ) var ( - _subscribe = "subscribe" - _unsubscribe = "unsubscribe" - messageBytes = []byte("message") - unSubscribeNothing = []byte("*3\r\n$11\r\nunsubscribe\r\n$-1\n:0\r\n") + _subscribe = "subscribe" + _unsubscribe = "unsubscribe" + messageBytes = []byte("message") + unSubscribeNothing = []byte("*3\r\n$11\r\nunsubscribe\r\n$-1\n:0\r\n") ) func makeMsg(t string, channel string, code int64) []byte { - return []byte("*3\r\n$" + strconv.FormatInt(int64(len(t)), 10) + reply.CRLF + t + reply.CRLF + - "$" + strconv.FormatInt(int64(len(channel)), 10) + reply.CRLF + channel + reply.CRLF + - ":" + strconv.FormatInt(code, 10) + reply.CRLF) + return []byte("*3\r\n$" + strconv.FormatInt(int64(len(t)), 10) + reply.CRLF + t + reply.CRLF + + "$" + strconv.FormatInt(int64(len(channel)), 10) + reply.CRLF + channel + reply.CRLF + + ":" + strconv.FormatInt(code, 10) + reply.CRLF) } /* @@ -25,22 +25,22 @@ func makeMsg(t string, channel string, code int64) []byte { * return: is new subscribed */ func subscribe0(hub *Hub, channel string, client redis.Connection) bool { - client.SubsChannel(channel) - - // add into hub.subs - raw, ok := hub.subs.Get(channel) - var subscribers *list.LinkedList - if ok { - subscribers, _ = raw.(*list.LinkedList) - } else { - subscribers = list.Make() - hub.subs.Put(channel, subscribers) - } - if subscribers.Contains(client) { - return false - } - subscribers.Add(client) - return true + client.SubsChannel(channel) + + // add into hub.subs + raw, ok := hub.subs.Get(channel) + var subscribers *list.LinkedList + if ok { + subscribers, _ = raw.(*list.LinkedList) + } else { + subscribers = list.Make() + hub.subs.Put(channel, subscribers) + } + if subscribers.Contains(client) { + return false + } + subscribers.Add(client) + return true } /* @@ -48,102 +48,102 @@ func subscribe0(hub *Hub, channel string, client redis.Connection) bool { * return: is actually un-subscribe */ func unsubscribe0(hub *Hub, channel string, client redis.Connection) bool { - client.UnSubsChannel(channel) - - // remove from hub.subs - raw, ok := hub.subs.Get(channel) - if ok { - subscribers, _ := raw.(*list.LinkedList) - subscribers.RemoveAllByVal(client) - - if subscribers.Len() == 0 { - // clean - hub.subs.Remove(channel) - } - return true - } - return false + client.UnSubsChannel(channel) + + // remove from hub.subs + raw, ok := hub.subs.Get(channel) + if ok { + subscribers, _ := raw.(*list.LinkedList) + subscribers.RemoveAllByVal(client) + + if subscribers.Len() == 0 { + // clean + hub.subs.Remove(channel) + } + return true + } + return false } func Subscribe(hub *Hub, c redis.Connection, args [][]byte) redis.Reply { - channels := make([]string, len(args)) - for i, b := range args { - channels[i] = string(b) - } - - hub.subsLocker.Locks(channels...) - defer hub.subsLocker.UnLocks(channels...) - - for _, channel := range channels { - if subscribe0(hub, channel, c) { - _ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount()))) - } - } - return &reply.NoReply{} + channels := make([]string, len(args)) + for i, b := range args { + channels[i] = string(b) + } + + hub.subsLocker.Locks(channels...) + defer hub.subsLocker.UnLocks(channels...) + + for _, channel := range channels { + if subscribe0(hub, channel, c) { + _ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount()))) + } + } + return &reply.NoReply{} } func UnsubscribeAll(hub *Hub, c redis.Connection) { - channels := c.GetChannels() + channels := c.GetChannels() - hub.subsLocker.Locks(channels...) - defer hub.subsLocker.UnLocks(channels...) + hub.subsLocker.Locks(channels...) + defer hub.subsLocker.UnLocks(channels...) - for _, channel := range channels { - unsubscribe0(hub, channel, c) - } + for _, channel := range channels { + unsubscribe0(hub, channel, c) + } } func UnSubscribe(db *Hub, c redis.Connection, args [][]byte) redis.Reply { - var channels []string - if len(args) > 0 { - channels = make([]string, len(args)) - for i, b := range args { - channels[i] = string(b) - } - } else { - channels = c.GetChannels() - } - - db.subsLocker.Locks(channels...) - defer db.subsLocker.UnLocks(channels...) - - if len(channels) == 0 { - _ = c.Write(unSubscribeNothing) - return &reply.NoReply{} - } - - for _, channel := range channels { - if unsubscribe0(db, channel, c) { - _ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount()))) - } - } - return &reply.NoReply{} + var channels []string + if len(args) > 0 { + channels = make([]string, len(args)) + for i, b := range args { + channels[i] = string(b) + } + } else { + channels = c.GetChannels() + } + + db.subsLocker.Locks(channels...) + defer db.subsLocker.UnLocks(channels...) + + if len(channels) == 0 { + _ = c.Write(unSubscribeNothing) + return &reply.NoReply{} + } + + for _, channel := range channels { + if unsubscribe0(db, channel, c) { + _ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount()))) + } + } + return &reply.NoReply{} } func Publish(hub *Hub, args [][]byte) redis.Reply { - if len(args) != 2 { - return &reply.ArgNumErrReply{Cmd: "publish"} - } - channel := string(args[0]) - message := args[1] - - hub.subsLocker.Lock(channel) - defer hub.subsLocker.UnLock(channel) - - raw, ok := hub.subs.Get(channel) - if !ok { - return reply.MakeIntReply(0) - } - subscribers, _ := raw.(*list.LinkedList) - subscribers.ForEach(func(i int, c interface{}) bool { - client, _ := c.(redis.Connection) - replyArgs := make([][]byte, 3) - replyArgs[0] = messageBytes - replyArgs[1] = []byte(channel) - replyArgs[2] = message - _ = client.Write(reply.MakeMultiBulkReply(replyArgs).ToBytes()) - return true - }) - return reply.MakeIntReply(int64(subscribers.Len())) + if len(args) != 2 { + return &reply.ArgNumErrReply{Cmd: "publish"} + } + channel := string(args[0]) + message := args[1] + + hub.subsLocker.Lock(channel) + defer hub.subsLocker.UnLock(channel) + + raw, ok := hub.subs.Get(channel) + if !ok { + return reply.MakeIntReply(0) + } + subscribers, _ := raw.(*list.LinkedList) + subscribers.ForEach(func(i int, c interface{}) bool { + client, _ := c.(redis.Connection) + replyArgs := make([][]byte, 3) + replyArgs[0] = messageBytes + replyArgs[1] = []byte(channel) + replyArgs[2] = message + _ = client.Write(reply.MakeMultiBulkReply(replyArgs).ToBytes()) + return true + }) + return reply.MakeIntReply(int64(subscribers.Len())) } diff --git a/src/redis/client/client.go b/src/redis/client/client.go index f997787a..c141358a 100644 --- a/src/redis/client/client.go +++ b/src/redis/client/client.go @@ -1,322 +1,321 @@ package client import ( - "bufio" - "context" - "errors" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/lib/logger" - "github.com/HDT3213/godis/src/lib/sync/wait" - "github.com/HDT3213/godis/src/redis/reply" - "io" - "net" - "strconv" - "strings" - "sync" - "time" + "bufio" + "context" + "errors" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/lib/logger" + "github.com/HDT3213/godis/src/lib/sync/wait" + "github.com/HDT3213/godis/src/redis/reply" + "io" + "net" + "strconv" + "strings" + "sync" + "time" ) type Client struct { - conn net.Conn - sendingReqs chan *Request // waiting sending - waitingReqs chan *Request // waiting response - ticker *time.Ticker - addr string + conn net.Conn + sendingReqs chan *Request // waiting sending + waitingReqs chan *Request // waiting response + ticker *time.Ticker + addr string - ctx context.Context - cancelFunc context.CancelFunc - writing *sync.WaitGroup + ctx context.Context + cancelFunc context.CancelFunc + writing *sync.WaitGroup } type Request struct { - id uint64 - args [][]byte - reply redis.Reply - heartbeat bool - waiting *wait.Wait - err error + id uint64 + args [][]byte + reply redis.Reply + heartbeat bool + waiting *wait.Wait + err error } const ( - chanSize = 256 - maxWait = 3 * time.Second + chanSize = 256 + maxWait = 3 * time.Second ) func MakeClient(addr string) (*Client, error) { - conn, err := net.Dial("tcp", addr) - if err != nil { - return nil, err - } - ctx, cancel := context.WithCancel(context.Background()) - return &Client{ - addr: addr, - conn: conn, - sendingReqs: make(chan *Request, chanSize), - waitingReqs: make(chan *Request, chanSize), - ctx: ctx, - cancelFunc: cancel, - writing: &sync.WaitGroup{}, - }, nil + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(context.Background()) + return &Client{ + addr: addr, + conn: conn, + sendingReqs: make(chan *Request, chanSize), + waitingReqs: make(chan *Request, chanSize), + ctx: ctx, + cancelFunc: cancel, + writing: &sync.WaitGroup{}, + }, nil } func (client *Client) Start() { - client.ticker = time.NewTicker(10 * time.Second) - go client.handleWrite() - go func() { - err := client.handleRead() - logger.Warn(err) - }() - go client.heartbeat() + client.ticker = time.NewTicker(10 * time.Second) + go client.handleWrite() + go func() { + err := client.handleRead() + logger.Warn(err) + }() + go client.heartbeat() } func (client *Client) Close() { - // stop new request - close(client.sendingReqs) + // stop new request + close(client.sendingReqs) - // wait stop process - client.writing.Wait() + // wait stop process + client.writing.Wait() - // clean - client.cancelFunc() - _ = client.conn.Close() - close(client.waitingReqs) + // clean + client.cancelFunc() + _ = client.conn.Close() + close(client.waitingReqs) } func (client *Client) handleConnectionError(err error) error { - err1 := client.conn.Close() - if err1 != nil { - if opErr, ok := err1.(*net.OpError); ok { - if opErr.Err.Error() != "use of closed network connection" { - return err1 - } - } else { - return err1 - } - } - conn, err1 := net.Dial("tcp", client.addr) - if err1 != nil { - logger.Error(err1) - return err1 - } - client.conn = conn - go func() { - _ = client.handleRead() - }() - return nil + err1 := client.conn.Close() + if err1 != nil { + if opErr, ok := err1.(*net.OpError); ok { + if opErr.Err.Error() != "use of closed network connection" { + return err1 + } + } else { + return err1 + } + } + conn, err1 := net.Dial("tcp", client.addr) + if err1 != nil { + logger.Error(err1) + return err1 + } + client.conn = conn + go func() { + _ = client.handleRead() + }() + return nil } func (client *Client) heartbeat() { loop: - for { - select { - case <-client.ticker.C: - client.sendingReqs <- &Request{ - args: [][]byte{[]byte("PING")}, - heartbeat: true, - } - case <-client.ctx.Done(): - break loop - } - } + for { + select { + case <-client.ticker.C: + client.sendingReqs <- &Request{ + args: [][]byte{[]byte("PING")}, + heartbeat: true, + } + case <-client.ctx.Done(): + break loop + } + } } func (client *Client) handleWrite() { loop: - for { - select { - case req := <-client.sendingReqs: - client.writing.Add(1) - client.doRequest(req) - case <-client.ctx.Done(): - break loop - } - } + for { + select { + case req := <-client.sendingReqs: + client.writing.Add(1) + client.doRequest(req) + case <-client.ctx.Done(): + break loop + } + } } // todo: wait with timeout func (client *Client) Send(args [][]byte) redis.Reply { - request := &Request{ - args: args, - heartbeat: false, - waiting: &wait.Wait{}, - } - request.waiting.Add(1) - client.sendingReqs <- request - timeout := request.waiting.WaitWithTimeout(maxWait) - if timeout { - return reply.MakeErrReply("server time out") - } - if request.err != nil { - return reply.MakeErrReply("request failed") - } - return request.reply + request := &Request{ + args: args, + heartbeat: false, + waiting: &wait.Wait{}, + } + request.waiting.Add(1) + client.sendingReqs <- request + timeout := request.waiting.WaitWithTimeout(maxWait) + if timeout { + return reply.MakeErrReply("server time out") + } + if request.err != nil { + return reply.MakeErrReply("request failed") + } + return request.reply } func (client *Client) doRequest(req *Request) { - bytes := reply.MakeMultiBulkReply(req.args).ToBytes() - _, err := client.conn.Write(bytes) - i := 0 - for err != nil && i < 3 { - err = client.handleConnectionError(err) - if err == nil { - _, err = client.conn.Write(bytes) - } - i++ - } - if err == nil { - client.waitingReqs <- req - } else { - req.err = err - req.waiting.Done() - client.writing.Done() - } + bytes := reply.MakeMultiBulkReply(req.args).ToBytes() + _, err := client.conn.Write(bytes) + i := 0 + for err != nil && i < 3 { + err = client.handleConnectionError(err) + if err == nil { + _, err = client.conn.Write(bytes) + } + i++ + } + if err == nil { + client.waitingReqs <- req + } else { + req.err = err + req.waiting.Done() + client.writing.Done() + } } func (client *Client) finishRequest(reply redis.Reply) { - request := <-client.waitingReqs - request.reply = reply - if request.waiting != nil { - request.waiting.Done() - } - client.writing.Done() + request := <-client.waitingReqs + request.reply = reply + if request.waiting != nil { + request.waiting.Done() + } + client.writing.Done() } func (client *Client) handleRead() error { - reader := bufio.NewReader(client.conn) - downloading := false - expectedArgsCount := 0 - receivedCount := 0 - msgType := byte(0) // first char of msg - var args [][]byte - var fixedLen int64 = 0 - var err error - var msg []byte - for { - // read line - if fixedLen == 0 { // read normal line - msg, err = reader.ReadBytes('\n') - if err != nil { - if err == io.EOF || err == io.ErrUnexpectedEOF { - logger.Info("connection close") - } else { - logger.Warn(err) - } + reader := bufio.NewReader(client.conn) + downloading := false + expectedArgsCount := 0 + receivedCount := 0 + msgType := byte(0) // first char of msg + var args [][]byte + var fixedLen int64 = 0 + var err error + var msg []byte + for { + // read line + if fixedLen == 0 { // read normal line + msg, err = reader.ReadBytes('\n') + if err != nil { + if err == io.EOF || err == io.ErrUnexpectedEOF { + logger.Info("connection close") + } else { + logger.Warn(err) + } - return errors.New("connection closed") - } - if len(msg) == 0 || msg[len(msg)-2] != '\r' { - return errors.New("protocol error") - } - } else { // read bulk line (binary safe) - msg = make([]byte, fixedLen+2) - _, err = io.ReadFull(reader, msg) - if err != nil { - if err == io.EOF || err == io.ErrUnexpectedEOF { - return errors.New("connection closed") - } else { - return err - } - } - if len(msg) == 0 || - msg[len(msg)-2] != '\r' || - msg[len(msg)-1] != '\n' { - return errors.New("protocol error") - } - fixedLen = 0 - } + return errors.New("connection closed") + } + if len(msg) == 0 || msg[len(msg)-2] != '\r' { + return errors.New("protocol error") + } + } else { // read bulk line (binary safe) + msg = make([]byte, fixedLen+2) + _, err = io.ReadFull(reader, msg) + if err != nil { + if err == io.EOF || err == io.ErrUnexpectedEOF { + return errors.New("connection closed") + } else { + return err + } + } + if len(msg) == 0 || + msg[len(msg)-2] != '\r' || + msg[len(msg)-1] != '\n' { + return errors.New("protocol error") + } + fixedLen = 0 + } - // parse line - if !downloading { - // receive new response - if msg[0] == '*' { // multi bulk response - // bulk multi msg - expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32) - if err != nil { - return errors.New("protocol error: " + err.Error()) - } - if expectedLine == 0 { - client.finishRequest(&reply.EmptyMultiBulkReply{}) - } else if expectedLine > 0 { - msgType = msg[0] - downloading = true - expectedArgsCount = int(expectedLine) - receivedCount = 0 - args = make([][]byte, expectedLine) - } else { - return errors.New("protocol error") - } - } else if msg[0] == '$' { // bulk response - fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64) - if err != nil { - return err - } - if fixedLen == -1 { // null bulk - client.finishRequest(&reply.NullBulkReply{}) - fixedLen = 0 - } else if fixedLen > 0 { - msgType = msg[0] - downloading = true - expectedArgsCount = 1 - receivedCount = 0 - args = make([][]byte, 1) - } else { - return errors.New("protocol error") - } - } else { // single line response - str := strings.TrimSuffix(string(msg), "\n") - str = strings.TrimSuffix(str, "\r") - var result redis.Reply - switch msg[0] { - case '+': - result = reply.MakeStatusReply(str[1:]) - case '-': - result = reply.MakeErrReply(str[1:]) - case ':': - val, err := strconv.ParseInt(str[1:], 10, 64) - if err != nil { - return errors.New("protocol error") - } - result = reply.MakeIntReply(val) - } - client.finishRequest(result) - } - } else { - // receive following part of a request - line := msg[0 : len(msg)-2] - if line[0] == '$' { - fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64) - if err != nil { - return err - } - if fixedLen <= 0 { // null bulk in multi bulks - args[receivedCount] = []byte{} - receivedCount++ - fixedLen = 0 - } - } else { - args[receivedCount] = line - receivedCount++ - } + // parse line + if !downloading { + // receive new response + if msg[0] == '*' { // multi bulk response + // bulk multi msg + expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32) + if err != nil { + return errors.New("protocol error: " + err.Error()) + } + if expectedLine == 0 { + client.finishRequest(&reply.EmptyMultiBulkReply{}) + } else if expectedLine > 0 { + msgType = msg[0] + downloading = true + expectedArgsCount = int(expectedLine) + receivedCount = 0 + args = make([][]byte, expectedLine) + } else { + return errors.New("protocol error") + } + } else if msg[0] == '$' { // bulk response + fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64) + if err != nil { + return err + } + if fixedLen == -1 { // null bulk + client.finishRequest(&reply.NullBulkReply{}) + fixedLen = 0 + } else if fixedLen > 0 { + msgType = msg[0] + downloading = true + expectedArgsCount = 1 + receivedCount = 0 + args = make([][]byte, 1) + } else { + return errors.New("protocol error") + } + } else { // single line response + str := strings.TrimSuffix(string(msg), "\n") + str = strings.TrimSuffix(str, "\r") + var result redis.Reply + switch msg[0] { + case '+': + result = reply.MakeStatusReply(str[1:]) + case '-': + result = reply.MakeErrReply(str[1:]) + case ':': + val, err := strconv.ParseInt(str[1:], 10, 64) + if err != nil { + return errors.New("protocol error") + } + result = reply.MakeIntReply(val) + } + client.finishRequest(result) + } + } else { + // receive following part of a request + line := msg[0 : len(msg)-2] + if line[0] == '$' { + fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64) + if err != nil { + return err + } + if fixedLen <= 0 { // null bulk in multi bulks + args[receivedCount] = []byte{} + receivedCount++ + fixedLen = 0 + } + } else { + args[receivedCount] = line + receivedCount++ + } - // if sending finished - if receivedCount == expectedArgsCount { - downloading = false // finish downloading progress + // if sending finished + if receivedCount == expectedArgsCount { + downloading = false // finish downloading progress - if msgType == '*' { - reply := reply.MakeMultiBulkReply(args) - client.finishRequest(reply) - } else if msgType == '$' { - reply := reply.MakeBulkReply(args[0]) - client.finishRequest(reply) - } + if msgType == '*' { + reply := reply.MakeMultiBulkReply(args) + client.finishRequest(reply) + } else if msgType == '$' { + reply := reply.MakeBulkReply(args[0]) + client.finishRequest(reply) + } - - // finish reply - expectedArgsCount = 0 - receivedCount = 0 - args = nil - msgType = byte(0) - } - } - } + // finish reply + expectedArgsCount = 0 + receivedCount = 0 + args = nil + msgType = byte(0) + } + } + } } diff --git a/src/redis/client/client_test.go b/src/redis/client/client_test.go index b1eeee44..087ef248 100644 --- a/src/redis/client/client_test.go +++ b/src/redis/client/client_test.go @@ -1,104 +1,104 @@ package client import ( - "github.com/HDT3213/godis/src/lib/logger" - "github.com/HDT3213/godis/src/redis/reply" - "testing" + "github.com/HDT3213/godis/src/lib/logger" + "github.com/HDT3213/godis/src/redis/reply" + "testing" ) func TestClient(t *testing.T) { - logger.Setup(&logger.Settings{ - Path: "logs", - Name: "godis", - Ext: ".log", - TimeFormat: "2006-01-02", - }) - client, err := MakeClient("localhost:6379") - if err != nil { - t.Error(err) - } - client.Start() + logger.Setup(&logger.Settings{ + Path: "logs", + Name: "godis", + Ext: ".log", + TimeFormat: "2006-01-02", + }) + client, err := MakeClient("localhost:6379") + if err != nil { + t.Error(err) + } + client.Start() - result := client.Send([][]byte{ - []byte("PING"), - }) - if statusRet, ok := result.(*reply.StatusReply); ok { - if statusRet.Status != "PONG" { - t.Error("`ping` failed, result: " + statusRet.Status) - } - } + result := client.Send([][]byte{ + []byte("PING"), + }) + if statusRet, ok := result.(*reply.StatusReply); ok { + if statusRet.Status != "PONG" { + t.Error("`ping` failed, result: " + statusRet.Status) + } + } - result = client.Send([][]byte{ - []byte("SET"), - []byte("a"), - []byte("a"), - }) - if statusRet, ok := result.(*reply.StatusReply); ok { - if statusRet.Status != "OK" { - t.Error("`set` failed, result: " + statusRet.Status) - } - } + result = client.Send([][]byte{ + []byte("SET"), + []byte("a"), + []byte("a"), + }) + if statusRet, ok := result.(*reply.StatusReply); ok { + if statusRet.Status != "OK" { + t.Error("`set` failed, result: " + statusRet.Status) + } + } - result = client.Send([][]byte{ - []byte("GET"), - []byte("a"), - }) - if bulkRet, ok := result.(*reply.BulkReply); ok { - if string(bulkRet.Arg) != "a" { - t.Error("`get` failed, result: " + string(bulkRet.Arg)) - } - } + result = client.Send([][]byte{ + []byte("GET"), + []byte("a"), + }) + if bulkRet, ok := result.(*reply.BulkReply); ok { + if string(bulkRet.Arg) != "a" { + t.Error("`get` failed, result: " + string(bulkRet.Arg)) + } + } - result = client.Send([][]byte{ - []byte("DEL"), - []byte("a"), - }) - if intRet, ok := result.(*reply.IntReply); ok { - if intRet.Code != 1 { - t.Error("`del` failed, result: " + string(intRet.Code)) - } - } + result = client.Send([][]byte{ + []byte("DEL"), + []byte("a"), + }) + if intRet, ok := result.(*reply.IntReply); ok { + if intRet.Code != 1 { + t.Error("`del` failed, result: " + string(intRet.Code)) + } + } - result = client.Send([][]byte{ - []byte("GET"), - []byte("a"), - }) - if _, ok := result.(*reply.NullBulkReply); !ok { - t.Error("`get` failed, result: " + string(result.ToBytes())) - } + result = client.Send([][]byte{ + []byte("GET"), + []byte("a"), + }) + if _, ok := result.(*reply.NullBulkReply); !ok { + t.Error("`get` failed, result: " + string(result.ToBytes())) + } - result = client.Send([][]byte{ - []byte("DEL"), - []byte("arr"), - }) + result = client.Send([][]byte{ + []byte("DEL"), + []byte("arr"), + }) - result = client.Send([][]byte{ - []byte("RPUSH"), - []byte("arr"), - []byte("1"), - []byte("2"), - []byte("c"), - }) - if intRet, ok := result.(*reply.IntReply); ok { - if intRet.Code != 3 { - t.Error("`rpush` failed, result: " + string(intRet.Code)) - } - } + result = client.Send([][]byte{ + []byte("RPUSH"), + []byte("arr"), + []byte("1"), + []byte("2"), + []byte("c"), + }) + if intRet, ok := result.(*reply.IntReply); ok { + if intRet.Code != 3 { + t.Error("`rpush` failed, result: " + string(intRet.Code)) + } + } - result = client.Send([][]byte{ - []byte("LRANGE"), - []byte("arr"), - []byte("0"), - []byte("-1"), - }) - if multiBulkRet, ok := result.(*reply.MultiBulkReply); ok { - if len(multiBulkRet.Args) != 3 || - string(multiBulkRet.Args[0]) != "1" || - string(multiBulkRet.Args[1]) != "2" || - string(multiBulkRet.Args[2]) != "c" { - t.Error("`lrange` failed, result: " + string(multiBulkRet.ToBytes())) - } - } + result = client.Send([][]byte{ + []byte("LRANGE"), + []byte("arr"), + []byte("0"), + []byte("-1"), + }) + if multiBulkRet, ok := result.(*reply.MultiBulkReply); ok { + if len(multiBulkRet.Args) != 3 || + string(multiBulkRet.Args[0]) != "1" || + string(multiBulkRet.Args[1]) != "2" || + string(multiBulkRet.Args[2]) != "c" { + t.Error("`lrange` failed, result: " + string(multiBulkRet.ToBytes())) + } + } - client.Close() + client.Close() } diff --git a/src/redis/reply/consts.go b/src/redis/reply/consts.go index 28f0f311..8668aff2 100644 --- a/src/redis/reply/consts.go +++ b/src/redis/reply/consts.go @@ -1,42 +1,42 @@ package reply -type PongReply struct {} +type PongReply struct{} var PongBytes = []byte("+PONG\r\n") -func (r *PongReply)ToBytes()[]byte { - return PongBytes +func (r *PongReply) ToBytes() []byte { + return PongBytes } -type OkReply struct {} +type OkReply struct{} var okBytes = []byte("+OK\r\n") -func (r *OkReply)ToBytes()[]byte { - return okBytes +func (r *OkReply) ToBytes() []byte { + return okBytes } var nullBulkBytes = []byte("$-1\r\n") -type NullBulkReply struct {} +type NullBulkReply struct{} -func (r *NullBulkReply)ToBytes()[]byte { - return nullBulkBytes +func (r *NullBulkReply) ToBytes() []byte { + return nullBulkBytes } var emptyMultiBulkBytes = []byte("*0\r\n") -type EmptyMultiBulkReply struct {} +type EmptyMultiBulkReply struct{} -func (r *EmptyMultiBulkReply)ToBytes()[]byte { - return emptyMultiBulkBytes +func (r *EmptyMultiBulkReply) ToBytes() []byte { + return emptyMultiBulkBytes } // reply nothing, for commands like subscribe -type NoReply struct {} +type NoReply struct{} var NoBytes = []byte("") -func (r *NoReply)ToBytes()[]byte { - return NoBytes -} \ No newline at end of file +func (r *NoReply) ToBytes() []byte { + return NoBytes +} diff --git a/src/redis/reply/errors.go b/src/redis/reply/errors.go index 35f2bb7a..bcfefe79 100644 --- a/src/redis/reply/errors.go +++ b/src/redis/reply/errors.go @@ -1,67 +1,67 @@ package reply // UnknownErr -type UnknownErrReply struct {} +type UnknownErrReply struct{} var unknownErrBytes = []byte("-Err unknown\r\n") -func (r *UnknownErrReply)ToBytes()[]byte { - return unknownErrBytes +func (r *UnknownErrReply) ToBytes() []byte { + return unknownErrBytes } -func (r *UnknownErrReply) Error()string { - return "Err unknown" +func (r *UnknownErrReply) Error() string { + return "Err unknown" } // ArgNumErr type ArgNumErrReply struct { - Cmd string + Cmd string } -func (r *ArgNumErrReply)ToBytes()[]byte { - return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n") +func (r *ArgNumErrReply) ToBytes() []byte { + return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n") } -func (r *ArgNumErrReply) Error()string { - return "ERR wrong number of arguments for '" + r.Cmd + "' command" +func (r *ArgNumErrReply) Error() string { + return "ERR wrong number of arguments for '" + r.Cmd + "' command" } // SyntaxErr -type SyntaxErrReply struct {} +type SyntaxErrReply struct{} var syntaxErrBytes = []byte("-Err syntax error\r\n") -func (r *SyntaxErrReply)ToBytes()[]byte { - return syntaxErrBytes +func (r *SyntaxErrReply) ToBytes() []byte { + return syntaxErrBytes } -func (r *SyntaxErrReply)Error()string { - return "Err syntax error" +func (r *SyntaxErrReply) Error() string { + return "Err syntax error" } // WrongTypeErr -type WrongTypeErrReply struct {} +type WrongTypeErrReply struct{} var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n") -func (r *WrongTypeErrReply)ToBytes()[]byte { - return wrongTypeErrBytes +func (r *WrongTypeErrReply) ToBytes() []byte { + return wrongTypeErrBytes } -func (r *WrongTypeErrReply)Error()string { - return "WRONGTYPE Operation against a key holding the wrong kind of value" +func (r *WrongTypeErrReply) Error() string { + return "WRONGTYPE Operation against a key holding the wrong kind of value" } // ProtocolErr type ProtocolErrReply struct { - Msg string + Msg string } -func (r *ProtocolErrReply)ToBytes()[]byte { - return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n") +func (r *ProtocolErrReply) ToBytes() []byte { + return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n") } -func (r *ProtocolErrReply) Error()string { - return "ERR Protocol error: '" + r.Msg -} \ No newline at end of file +func (r *ProtocolErrReply) Error() string { + return "ERR Protocol error: '" + r.Msg +} diff --git a/src/redis/reply/reply.go b/src/redis/reply/reply.go index 57934809..922f89fe 100644 --- a/src/redis/reply/reply.go +++ b/src/redis/reply/reply.go @@ -1,141 +1,140 @@ package reply import ( - "bytes" - "github.com/HDT3213/godis/src/interface/redis" - "strconv" + "bytes" + "github.com/HDT3213/godis/src/interface/redis" + "strconv" ) var ( - nullBulkReplyBytes = []byte("$-1") - CRLF = "\r\n" + nullBulkReplyBytes = []byte("$-1") + CRLF = "\r\n" ) /* ---- Bulk Reply ---- */ type BulkReply struct { - Arg []byte + Arg []byte } func MakeBulkReply(arg []byte) *BulkReply { - return &BulkReply{ - Arg: arg, - } + return &BulkReply{ + Arg: arg, + } } func (r *BulkReply) ToBytes() []byte { - if len(r.Arg) == 0 { - return nullBulkReplyBytes - } - return []byte("$" + strconv.Itoa(len(r.Arg)) + CRLF + string(r.Arg) + CRLF) + if len(r.Arg) == 0 { + return nullBulkReplyBytes + } + return []byte("$" + strconv.Itoa(len(r.Arg)) + CRLF + string(r.Arg) + CRLF) } /* ---- Multi Bulk Reply ---- */ type MultiBulkReply struct { - Args [][]byte + Args [][]byte } func MakeMultiBulkReply(args [][]byte) *MultiBulkReply { - return &MultiBulkReply{ - Args: args, - } + return &MultiBulkReply{ + Args: args, + } } func (r *MultiBulkReply) ToBytes() []byte { - argLen := len(r.Args) - var buf bytes.Buffer - buf.WriteString("*" + strconv.Itoa(argLen) + CRLF) - for _, arg := range r.Args { - if arg == nil { - buf.WriteString("$-1" + CRLF) - } else { - buf.WriteString("$" + strconv.Itoa(len(arg)) + CRLF + string(arg) + CRLF) - } - } - return buf.Bytes() + argLen := len(r.Args) + var buf bytes.Buffer + buf.WriteString("*" + strconv.Itoa(argLen) + CRLF) + for _, arg := range r.Args { + if arg == nil { + buf.WriteString("$-1" + CRLF) + } else { + buf.WriteString("$" + strconv.Itoa(len(arg)) + CRLF + string(arg) + CRLF) + } + } + return buf.Bytes() } /* ---- Multi Raw Reply ---- */ type MultiRawReply struct { - Args [][]byte + Args [][]byte } func MakeMultiRawReply(args [][]byte) *MultiRawReply { - return &MultiRawReply{ - Args: args, - } + return &MultiRawReply{ + Args: args, + } } func (r *MultiRawReply) ToBytes() []byte { - argLen := len(r.Args) - var buf bytes.Buffer - buf.WriteString("*" + strconv.Itoa(argLen) + CRLF) - for _, arg := range r.Args { - buf.Write(arg) - } - return buf.Bytes() + argLen := len(r.Args) + var buf bytes.Buffer + buf.WriteString("*" + strconv.Itoa(argLen) + CRLF) + for _, arg := range r.Args { + buf.Write(arg) + } + return buf.Bytes() } /* ---- Status Reply ---- */ type StatusReply struct { - Status string + Status string } func MakeStatusReply(status string) *StatusReply { - return &StatusReply{ - Status: status, - } + return &StatusReply{ + Status: status, + } } func (r *StatusReply) ToBytes() []byte { - return []byte("+" + r.Status + "\r\n") + return []byte("+" + r.Status + "\r\n") } /* ---- Int Reply ---- */ type IntReply struct { - Code int64 + Code int64 } func MakeIntReply(code int64) *IntReply { - return &IntReply{ - Code: code, - } + return &IntReply{ + Code: code, + } } func (r *IntReply) ToBytes() []byte { - return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF) + return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF) } - /* ---- Error Reply ---- */ type ErrorReply interface { - Error() string - ToBytes() []byte + Error() string + ToBytes() []byte } type StandardErrReply struct { - Status string + Status string } func MakeErrReply(status string) *StandardErrReply { - return &StandardErrReply{ - Status: status, - } + return &StandardErrReply{ + Status: status, + } } func IsErrorReply(reply redis.Reply) bool { - return reply.ToBytes()[0] == '-' + return reply.ToBytes()[0] == '-' } func (r *StandardErrReply) ToBytes() []byte { - return []byte("-" + r.Status + "\r\n") + return []byte("-" + r.Status + "\r\n") } func (r *StandardErrReply) Error() string { - return r.Status -} \ No newline at end of file + return r.Status +} diff --git a/src/redis/server/client.go b/src/redis/server/client.go index 8be64a51..d16efff3 100644 --- a/src/redis/server/client.go +++ b/src/redis/server/client.go @@ -1,95 +1,95 @@ package server import ( - "github.com/HDT3213/godis/src/lib/sync/atomic" - "github.com/HDT3213/godis/src/lib/sync/wait" - "net" - "sync" - "time" + "github.com/HDT3213/godis/src/lib/sync/atomic" + "github.com/HDT3213/godis/src/lib/sync/wait" + "net" + "sync" + "time" ) // abstract of active client type Client struct { - conn net.Conn + conn net.Conn - // waiting util reply finished - waitingReply wait.Wait + // waiting util reply finished + waitingReply wait.Wait - // is sending request in progress - uploading atomic.AtomicBool - // multi bulk msg lineCount - 1(first line) - expectedArgsCount uint32 - // sent line count, exclude first line - receivedCount uint32 - // sent lines, exclude first line - args [][]byte + // is sending request in progress + uploading atomic.AtomicBool + // multi bulk msg lineCount - 1(first line) + expectedArgsCount uint32 + // sent line count, exclude first line + receivedCount uint32 + // sent lines, exclude first line + args [][]byte - // lock while server sending response - mu sync.Mutex + // lock while server sending response + mu sync.Mutex - // subscribing channels - subs map[string]bool + // subscribing channels + subs map[string]bool } -func (c *Client)Close()error { - c.waitingReply.WaitWithTimeout(10 * time.Second) - _ = c.conn.Close() - return nil +func (c *Client) Close() error { + c.waitingReply.WaitWithTimeout(10 * time.Second) + _ = c.conn.Close() + return nil } func MakeClient(conn net.Conn) *Client { - return &Client{ - conn: conn, - } + return &Client{ + conn: conn, + } } -func (c *Client)Write(b []byte)error { - if b == nil || len(b) == 0 { - return nil - } - c.mu.Lock() - defer c.mu.Unlock() +func (c *Client) Write(b []byte) error { + if b == nil || len(b) == 0 { + return nil + } + c.mu.Lock() + defer c.mu.Unlock() - _, err := c.conn.Write(b) - return err + _, err := c.conn.Write(b) + return err } -func (c *Client)SubsChannel(channel string) { - c.mu.Lock() - defer c.mu.Unlock() +func (c *Client) SubsChannel(channel string) { + c.mu.Lock() + defer c.mu.Unlock() - if c.subs == nil { - c.subs = make(map[string]bool) - } - c.subs[channel] = true + if c.subs == nil { + c.subs = make(map[string]bool) + } + c.subs[channel] = true } -func (c *Client)UnSubsChannel(channel string) { - c.mu.Lock() - defer c.mu.Unlock() +func (c *Client) UnSubsChannel(channel string) { + c.mu.Lock() + defer c.mu.Unlock() - if c.subs == nil { - return - } - delete(c.subs, channel) + if c.subs == nil { + return + } + delete(c.subs, channel) } -func (c *Client)SubsCount()int { - if c.subs == nil { - return 0 - } - return len(c.subs) +func (c *Client) SubsCount() int { + if c.subs == nil { + return 0 + } + return len(c.subs) } -func (c *Client)GetChannels()[]string { - if c.subs == nil { - return make([]string, 0) - } - channels := make([]string, len(c.subs)) - i := 0 - for channel := range c.subs { - channels[i] = channel - i++ - } - return channels -} \ No newline at end of file +func (c *Client) GetChannels() []string { + if c.subs == nil { + return make([]string, 0) + } + channels := make([]string, len(c.subs)) + i := 0 + for channel := range c.subs { + channels[i] = channel + i++ + } + return channels +} diff --git a/src/redis/server/handler.go b/src/redis/server/handler.go index 25fc644e..382e9fab 100644 --- a/src/redis/server/handler.go +++ b/src/redis/server/handler.go @@ -5,192 +5,192 @@ package server */ import ( - "bufio" - "context" - "github.com/HDT3213/godis/src/cluster" - "github.com/HDT3213/godis/src/config" - DBImpl "github.com/HDT3213/godis/src/db" - "github.com/HDT3213/godis/src/interface/db" - "github.com/HDT3213/godis/src/lib/logger" - "github.com/HDT3213/godis/src/lib/sync/atomic" - "github.com/HDT3213/godis/src/redis/reply" - "io" - "net" - "strconv" - "strings" - "sync" + "bufio" + "context" + "github.com/HDT3213/godis/src/cluster" + "github.com/HDT3213/godis/src/config" + DBImpl "github.com/HDT3213/godis/src/db" + "github.com/HDT3213/godis/src/interface/db" + "github.com/HDT3213/godis/src/lib/logger" + "github.com/HDT3213/godis/src/lib/sync/atomic" + "github.com/HDT3213/godis/src/redis/reply" + "io" + "net" + "strconv" + "strings" + "sync" ) var ( - UnknownErrReplyBytes = []byte("-ERR unknown\r\n") + UnknownErrReplyBytes = []byte("-ERR unknown\r\n") ) type Handler struct { - activeConn sync.Map // *client -> placeholder - db db.DB - closing atomic.AtomicBool // refusing new client and new request + activeConn sync.Map // *client -> placeholder + db db.DB + closing atomic.AtomicBool // refusing new client and new request } func MakeHandler() *Handler { - var db db.DB - if config.Properties.Peers != nil && - len(config.Properties.Peers) > 0 { - db = cluster.MakeCluster() - } else { - db = DBImpl.MakeDB() - } - return &Handler{ - db: db, - } + var db db.DB + if config.Properties.Peers != nil && + len(config.Properties.Peers) > 0 { + db = cluster.MakeCluster() + } else { + db = DBImpl.MakeDB() + } + return &Handler{ + db: db, + } } func (h *Handler) closeClient(client *Client) { - _ = client.Close() - h.db.AfterClientClose(client) - h.activeConn.Delete(client) + _ = client.Close() + h.db.AfterClientClose(client) + h.activeConn.Delete(client) } func (h *Handler) Handle(ctx context.Context, conn net.Conn) { - if h.closing.Get() { - // closing handler refuse new connection - _ = conn.Close() - } - - client := MakeClient(conn) - h.activeConn.Store(client, 1) - - reader := bufio.NewReader(conn) - var fixedLen int64 = 0 - var err error - var msg []byte - for { - if fixedLen == 0 { - msg, err = reader.ReadBytes('\n') - if err != nil { - if err == io.EOF || - err == io.ErrUnexpectedEOF || - strings.Contains(err.Error(), "use of closed network connection") { - logger.Info("connection close") - } else { - logger.Warn(err) - } - - // after client close - h.closeClient(client) - return // io error, disconnect with client - } - if len(msg) == 0 || msg[len(msg)-2] != '\r' { - errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"} - _, _ = client.conn.Write(errReply.ToBytes()) - } - } else { - msg = make([]byte, fixedLen+2) - _, err = io.ReadFull(reader, msg) - if err != nil { - if err == io.EOF || - err == io.ErrUnexpectedEOF || - strings.Contains(err.Error(), "use of closed network connection") { - logger.Info("connection close") - } else { - logger.Warn(err) - } - - // after client close - h.closeClient(client) - return // io error, disconnect with client - } - if len(msg) == 0 || - msg[len(msg)-2] != '\r' || - msg[len(msg)-1] != '\n' { - errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"} - _, _ = client.conn.Write(errReply.ToBytes()) - } - fixedLen = 0 - } - - if !client.uploading.Get() { - // new request - if msg[0] == '*' { - // bulk multi msg - expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32) - if err != nil { - _, _ = client.conn.Write(UnknownErrReplyBytes) - continue - } - client.waitingReply.Add(1) - client.uploading.Set(true) - client.expectedArgsCount = uint32(expectedLine) - client.receivedCount = 0 - client.args = make([][]byte, expectedLine) - } else { - // text protocol - // remove \r or \n or \r\n in the end of line - str := strings.TrimSuffix(string(msg), "\n") - str = strings.TrimSuffix(str, "\r") - strs := strings.Split(str, " ") - args := make([][]byte, len(strs)) - for i, s := range strs { - args[i] = []byte(s) - } - - // send reply - result := h.db.Exec(client, args) - if result != nil { - _ = client.Write(result.ToBytes()) - } else { - _ = client.Write(UnknownErrReplyBytes) - } - } - } else { - // receive following part of a request - line := msg[0 : len(msg)-2] - if line[0] == '$' { - fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64) - if err != nil { - errReply := &reply.ProtocolErrReply{Msg: err.Error()} - _, _ = client.conn.Write(errReply.ToBytes()) - } - if fixedLen <= 0 { - errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"} - _, _ = client.conn.Write(errReply.ToBytes()) - } - } else { - client.args[client.receivedCount] = line - client.receivedCount++ - } - - // if sending finished - if client.receivedCount == client.expectedArgsCount { - client.uploading.Set(false) // finish sending progress - - // send reply - result := h.db.Exec(client, client.args) - if result != nil { - _ = client.Write(result.ToBytes()) - } else { - _ = client.Write(UnknownErrReplyBytes) - } - - // finish reply - client.expectedArgsCount = 0 - client.receivedCount = 0 - client.args = nil - client.waitingReply.Done() - } - } - - } + if h.closing.Get() { + // closing handler refuse new connection + _ = conn.Close() + } + + client := MakeClient(conn) + h.activeConn.Store(client, 1) + + reader := bufio.NewReader(conn) + var fixedLen int64 = 0 + var err error + var msg []byte + for { + if fixedLen == 0 { + msg, err = reader.ReadBytes('\n') + if err != nil { + if err == io.EOF || + err == io.ErrUnexpectedEOF || + strings.Contains(err.Error(), "use of closed network connection") { + logger.Info("connection close") + } else { + logger.Warn(err) + } + + // after client close + h.closeClient(client) + return // io error, disconnect with client + } + if len(msg) == 0 || msg[len(msg)-2] != '\r' { + errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"} + _, _ = client.conn.Write(errReply.ToBytes()) + } + } else { + msg = make([]byte, fixedLen+2) + _, err = io.ReadFull(reader, msg) + if err != nil { + if err == io.EOF || + err == io.ErrUnexpectedEOF || + strings.Contains(err.Error(), "use of closed network connection") { + logger.Info("connection close") + } else { + logger.Warn(err) + } + + // after client close + h.closeClient(client) + return // io error, disconnect with client + } + if len(msg) == 0 || + msg[len(msg)-2] != '\r' || + msg[len(msg)-1] != '\n' { + errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"} + _, _ = client.conn.Write(errReply.ToBytes()) + } + fixedLen = 0 + } + + if !client.uploading.Get() { + // new request + if msg[0] == '*' { + // bulk multi msg + expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32) + if err != nil { + _, _ = client.conn.Write(UnknownErrReplyBytes) + continue + } + client.waitingReply.Add(1) + client.uploading.Set(true) + client.expectedArgsCount = uint32(expectedLine) + client.receivedCount = 0 + client.args = make([][]byte, expectedLine) + } else { + // text protocol + // remove \r or \n or \r\n in the end of line + str := strings.TrimSuffix(string(msg), "\n") + str = strings.TrimSuffix(str, "\r") + strs := strings.Split(str, " ") + args := make([][]byte, len(strs)) + for i, s := range strs { + args[i] = []byte(s) + } + + // send reply + result := h.db.Exec(client, args) + if result != nil { + _ = client.Write(result.ToBytes()) + } else { + _ = client.Write(UnknownErrReplyBytes) + } + } + } else { + // receive following part of a request + line := msg[0 : len(msg)-2] + if line[0] == '$' { + fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64) + if err != nil { + errReply := &reply.ProtocolErrReply{Msg: err.Error()} + _, _ = client.conn.Write(errReply.ToBytes()) + } + if fixedLen <= 0 { + errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"} + _, _ = client.conn.Write(errReply.ToBytes()) + } + } else { + client.args[client.receivedCount] = line + client.receivedCount++ + } + + // if sending finished + if client.receivedCount == client.expectedArgsCount { + client.uploading.Set(false) // finish sending progress + + // send reply + result := h.db.Exec(client, client.args) + if result != nil { + _ = client.Write(result.ToBytes()) + } else { + _ = client.Write(UnknownErrReplyBytes) + } + + // finish reply + client.expectedArgsCount = 0 + client.receivedCount = 0 + client.args = nil + client.waitingReply.Done() + } + } + + } } func (h *Handler) Close() error { - logger.Info("handler shuting down...") - h.closing.Set(true) - // TODO: concurrent wait - h.activeConn.Range(func(key interface{}, val interface{}) bool { - client := key.(*Client) - _ = client.Close() - return true - }) - h.db.Close() - return nil + logger.Info("handler shuting down...") + h.closing.Set(true) + // TODO: concurrent wait + h.activeConn.Range(func(key interface{}, val interface{}) bool { + client := key.(*Client) + _ = client.Close() + return true + }) + h.db.Close() + return nil } diff --git a/src/tcp/echo.go b/src/tcp/echo.go index 4b581e90..773ca6f5 100644 --- a/src/tcp/echo.go +++ b/src/tcp/echo.go @@ -5,79 +5,78 @@ package tcp */ import ( - "net" - "context" - "bufio" - "github.com/HDT3213/godis/src/lib/logger" - "sync" - "io" - "github.com/HDT3213/godis/src/lib/sync/atomic" - "time" - "github.com/HDT3213/godis/src/lib/sync/wait" + "bufio" + "context" + "github.com/HDT3213/godis/src/lib/logger" + "github.com/HDT3213/godis/src/lib/sync/atomic" + "github.com/HDT3213/godis/src/lib/sync/wait" + "io" + "net" + "sync" + "time" ) type EchoHandler struct { - activeConn sync.Map - closing atomic.AtomicBool + activeConn sync.Map + closing atomic.AtomicBool } -func MakeEchoHandler()(*EchoHandler) { - return &EchoHandler{ - } +func MakeEchoHandler() *EchoHandler { + return &EchoHandler{} } type Client struct { - Conn net.Conn - Waiting wait.Wait + Conn net.Conn + Waiting wait.Wait } -func (c *Client)Close()error { - c.Waiting.WaitWithTimeout(10 * time.Second) - c.Conn.Close() - return nil +func (c *Client) Close() error { + c.Waiting.WaitWithTimeout(10 * time.Second) + c.Conn.Close() + return nil } -func (h *EchoHandler)Handle(ctx context.Context, conn net.Conn) { - if h.closing.Get() { - // closing handler refuse new connection - conn.Close() - } +func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) { + if h.closing.Get() { + // closing handler refuse new connection + conn.Close() + } - client := &Client { - Conn: conn, - } - h.activeConn.Store(client, 1) + client := &Client{ + Conn: conn, + } + h.activeConn.Store(client, 1) - reader := bufio.NewReader(conn) - for { - // may occurs: client EOF, client timeout, server early close - msg, err := reader.ReadString('\n') - if err != nil { - if err == io.EOF { - logger.Info("connection close") - h.activeConn.Delete(client) - } else { - logger.Warn(err) - } - return - } - client.Waiting.Add(1) - //logger.Info("sleeping") - //time.Sleep(10 * time.Second) - b := []byte(msg) - conn.Write(b) - client.Waiting.Done() - } + reader := bufio.NewReader(conn) + for { + // may occurs: client EOF, client timeout, server early close + msg, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + logger.Info("connection close") + h.activeConn.Delete(client) + } else { + logger.Warn(err) + } + return + } + client.Waiting.Add(1) + //logger.Info("sleeping") + //time.Sleep(10 * time.Second) + b := []byte(msg) + conn.Write(b) + client.Waiting.Done() + } } -func (h *EchoHandler)Close()error { - logger.Info("handler shuting down...") - h.closing.Set(true) - // TODO: concurrent wait - h.activeConn.Range(func(key interface{}, val interface{})bool { - client := key.(*Client) - client.Close() - return true - }) - return nil -} \ No newline at end of file +func (h *EchoHandler) Close() error { + logger.Info("handler shuting down...") + h.closing.Set(true) + // TODO: concurrent wait + h.activeConn.Range(func(key interface{}, val interface{}) bool { + client := key.(*Client) + client.Close() + return true + }) + return nil +} diff --git a/src/tcp/server.go b/src/tcp/server.go index 0f3dab69..0b4d1370 100644 --- a/src/tcp/server.go +++ b/src/tcp/server.go @@ -5,75 +5,74 @@ package tcp */ import ( - "context" - "fmt" - "github.com/HDT3213/godis/src/interface/tcp" - "github.com/HDT3213/godis/src/lib/logger" - "github.com/HDT3213/godis/src/lib/sync/atomic" - "net" - "os" - "os/signal" - "sync" - "syscall" - "time" + "context" + "fmt" + "github.com/HDT3213/godis/src/interface/tcp" + "github.com/HDT3213/godis/src/lib/logger" + "github.com/HDT3213/godis/src/lib/sync/atomic" + "net" + "os" + "os/signal" + "sync" + "syscall" + "time" ) type Config struct { - Address string `yaml:"address"` - MaxConnect uint32 `yaml:"max-connect"` - Timeout time.Duration `yaml:"timeout"` + Address string `yaml:"address"` + MaxConnect uint32 `yaml:"max-connect"` + Timeout time.Duration `yaml:"timeout"` } func ListenAndServe(cfg *Config, handler tcp.Handler) { - listener, err := net.Listen("tcp", cfg.Address) - if err != nil { - logger.Fatal(fmt.Sprintf("listen err: %v", err)) - } + listener, err := net.Listen("tcp", cfg.Address) + if err != nil { + logger.Fatal(fmt.Sprintf("listen err: %v", err)) + } - // listen signal - var closing atomic.AtomicBool - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT) - go func() { - sig := <-sigCh - switch sig { - case syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT: - logger.Info("shuting down...") - closing.Set(true) - _ = listener.Close() // listener.Accept() will return err immediately - _ = handler.Close() // close connections - } - }() + // listen signal + var closing atomic.AtomicBool + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT) + go func() { + sig := <-sigCh + switch sig { + case syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT: + logger.Info("shuting down...") + closing.Set(true) + _ = listener.Close() // listener.Accept() will return err immediately + _ = handler.Close() // close connections + } + }() - - // listen port - logger.Info(fmt.Sprintf("bind: %s, start listening...", cfg.Address)) - defer func() { - // close during unexpected error - _ = listener.Close() - _ = handler.Close() - }() - ctx, _ := context.WithCancel(context.Background()) - var waitDone sync.WaitGroup - for { - conn, err := listener.Accept() - if err != nil { - if closing.Get() { - logger.Info("waiting disconnect...") - waitDone.Wait() - return // handler will be closed by defer - } - logger.Error(fmt.Sprintf("accept err: %v", err)) - continue - } - // handle - logger.Info("accept link") - waitDone.Add(1) - go func() { - defer func() { - waitDone.Done() - }() - handler.Handle(ctx, conn) - }() - } + // listen port + logger.Info(fmt.Sprintf("bind: %s, start listening...", cfg.Address)) + defer func() { + // close during unexpected error + _ = listener.Close() + _ = handler.Close() + }() + ctx, _ := context.WithCancel(context.Background()) + var waitDone sync.WaitGroup + for { + conn, err := listener.Accept() + if err != nil { + if closing.Get() { + logger.Info("waiting disconnect...") + waitDone.Wait() + return // handler will be closed by defer + } + logger.Error(fmt.Sprintf("accept err: %v", err)) + continue + } + // handle + logger.Info("accept link") + waitDone.Add(1) + go func() { + defer func() { + waitDone.Done() + }() + handler.Handle(ctx, conn) + }() + } }