From 69e4b72bba89e5f032bacf24fdf2480eb49e9c93 Mon Sep 17 00:00:00 2001 From: Maxim Ivanov Date: Sun, 14 Aug 2016 20:15:09 +0100 Subject: [PATCH] Upsert method - inserts new or updates existing element Upsert method uses callback which return value is inserted into a map. Callback is given 3 arguments: - exists - whether given key already exists in map - valueInMap - existing value in map - valueBeingAdded - value which was passed to Upsert method Return value of a callback is used to save new value in map. This approach allows map users to do non-trivial transformations of value stored in map, while still protected by a lock It is somewhat ovelaps with #29, but more generic and allows more use cases --- concurrent_map.go | 17 ++++++++++++++ concurrent_map_test.go | 53 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/concurrent_map.go b/concurrent_map.go index 9f6e31d..2993038 100644 --- a/concurrent_map.go +++ b/concurrent_map.go @@ -50,6 +50,23 @@ func (m *ConcurrentMap) Set(key string, value interface{}) { shard.Unlock() } +// Callback to return new element to be inserted into the map +// It is called while lock is held, therefore it MUST NOT +// try to access other keys in same map, as it can lead to deadlock since +// Go sync.RWLock is not reentrant +type UpsertCb func(exist bool, valueInMap interface{}, newValue interface{}) interface{} + +// Insert or Update - updates existing element or inserts a new one using UpsertCb +func (m *ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res interface{}) { + shard := m.GetShard(key) + shard.Lock() + v, ok := shard.items[key] + res = cb(ok, v, value) + shard.items[key] = res + shard.Unlock() + return res +} + // Sets the given value under the specified key if no value was associated with it. func (m *ConcurrentMap) SetIfAbsent(key string, value interface{}) bool { // Get map shard. diff --git a/concurrent_map_test.go b/concurrent_map_test.go index 603a7e8..3f0411e 100644 --- a/concurrent_map_test.go +++ b/concurrent_map_test.go @@ -327,3 +327,56 @@ func TestFnv32(t *testing.T) { t.Errorf("Bundled fnv32 produced %d, expected result from hash/fnv32 is %d", fnv32(key), hasher.Sum32()) } } + +func TestUpsert(t *testing.T) { + dolphin := Animal{"dolphin"} + whale := Animal{"whale"} + tiger := Animal{"tiger"} + lion := Animal{"lion"} + + cb := func(exists bool, valueInMap interface{}, newValue interface{}) interface{} { + nv := newValue.(Animal) + if !exists { + return []Animal{nv} + } + res := valueInMap.([]Animal) + return append(res, nv) + } + + m := New() + m.Set("marine", []Animal{dolphin}) + m.Upsert("marine", whale, cb) + m.Upsert("predator", tiger, cb) + m.Upsert("predator", lion, cb) + + if m.Count() != 2 { + t.Error("map should contain exactly two elements.") + } + + compare := func(a, b []Animal) bool { + if a == nil || b == nil { + return false + } + + if len(a) != len(b) { + return false + } + + for i, v := range a { + if v != b[i] { + return false + } + } + return true + } + + marineAnimals, ok := m.Get("marine") + if !ok || !compare(marineAnimals.([]Animal), []Animal{dolphin, whale}) { + t.Error("Set, then Upsert failed") + } + + predators, ok := m.Get("predator") + if !ok || !compare(predators.([]Animal), []Animal{tiger, lion}) { + t.Error("Upsert, then Upsert failed") + } +}