Skip to content

Commit 383357b

Browse files
authored
Use atomic pointer for map access synchronization (#12)
* use atomic pointer for map access synchronization Signed-off-by: Vladislav Yarmak <[email protected]> * restore coverage Signed-off-by: Vladislav Yarmak <[email protected]> --------- Signed-off-by: Vladislav Yarmak <[email protected]>
1 parent 27f6cbb commit 383357b

File tree

3 files changed

+26
-22
lines changed

3 files changed

+26
-22
lines changed

htgroup.go

+5-13
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
"io"
1717
"os"
1818
"strings"
19-
"sync"
19+
"sync/atomic"
2020
)
2121

2222
// Data structure for users and theirs groups (map).
@@ -26,8 +26,7 @@ type userGroupMap map[string][]string
2626
// A HTGroup encompasses an Apache-style group file.
2727
type HTGroup struct {
2828
filePath string
29-
mutex sync.RWMutex
30-
userGroups userGroupMap
29+
userGroups atomic.Pointer[userGroupMap]
3130
}
3231

3332
// NewGroups creates a HTGroup from an Apache-style group file.
@@ -56,10 +55,7 @@ func NewGroupsFromReader(r io.Reader, bad BadLineHandler) (*HTGroup, error) {
5655

5756
// ReloadGroups rereads the group file.
5857
func (htGroup *HTGroup) ReloadGroups(bad BadLineHandler) error {
59-
htGroup.mutex.Lock()
60-
filename := htGroup.filePath
61-
htGroup.mutex.Unlock()
62-
file, err := os.Open(filename)
58+
file, err := os.Open(htGroup.filePath)
6359
if err != nil {
6460
return err
6561
}
@@ -83,9 +79,7 @@ func (htGroup *HTGroup) ReloadGroupsFromReader(r io.Reader, bad BadLineHandler)
8379
return fmt.Errorf("Error scanning group file: %s", scannerErr.Error())
8480
}
8581

86-
htGroup.mutex.Lock()
87-
htGroup.userGroups = userGroups
88-
htGroup.mutex.Unlock()
82+
htGroup.userGroups.Store(&userGroups)
8983

9084
return nil
9185
}
@@ -123,9 +117,7 @@ func (htGroup *HTGroup) IsUserInGroup(user string, group string) bool {
123117
// GetUserGroups reads all groups of a user.
124118
// Returns all groups as a string array or an empty array.
125119
func (htGroup *HTGroup) GetUserGroups(user string) []string {
126-
htGroup.mutex.RLock()
127-
groups := htGroup.userGroups[user]
128-
htGroup.mutex.RUnlock()
120+
groups := (*htGroup.userGroups.Load())[user]
129121

130122
if groups == nil {
131123
return []string{}

htgroup_test.go

+17
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package htpasswd
22

33
import (
44
"os"
5+
"strings"
56
"testing"
67

78
"github.com/stretchr/testify/assert"
@@ -66,4 +67,20 @@ func TestGroups(t *testing.T) {
6667
assert.Len(t, htGroup.GetUserGroups("user2"), 2)
6768
assert.Len(t, htGroup.GetUserGroups("user3"), 1)
6869
assert.Len(t, htGroup.GetUserGroups("unknownuser"), 0)
70+
71+
// Test load from reader as well
72+
r := strings.NewReader(contents2)
73+
htGroup, err = NewGroupsFromReader(r, nil)
74+
assert.NoError(t, err)
75+
assert.True(t, htGroup.IsUserInGroup("user1", "users"))
76+
assert.True(t, htGroup.IsUserInGroup("user1", "admins"))
77+
assert.True(t, htGroup.IsUserInGroup("user2", "users"))
78+
assert.True(t, htGroup.IsUserInGroup("user2", "admins"))
79+
assert.False(t, htGroup.IsUserInGroup("unknownuser", "users"))
80+
assert.False(t, htGroup.IsUserInGroup("user1", "unknowngroup"))
81+
assert.False(t, htGroup.IsUserInGroup("unknownuser", "unknowngroup"))
82+
assert.Len(t, htGroup.GetUserGroups("user1"), 2)
83+
assert.Len(t, htGroup.GetUserGroups("user2"), 2)
84+
assert.Len(t, htGroup.GetUserGroups("user3"), 1)
85+
assert.Len(t, htGroup.GetUserGroups("unknownuser"), 0)
6986
}

htpasswd.go

+4-9
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919
"io"
2020
"os"
2121
"strings"
22-
"sync"
22+
"sync/atomic"
2323
)
2424

2525
// An EncodedPasswd is created from the encoded password in a password file by a PasswdParser.
@@ -53,8 +53,7 @@ type BadLineHandler func(err error)
5353
// An File encompasses an Apache-style htpasswd file for HTTP Basic authentication
5454
type File struct {
5555
filePath string
56-
mutex sync.RWMutex
57-
passwds passwdTable
56+
passwds atomic.Pointer[passwdTable]
5857
parsers []PasswdParser
5958
}
6059

@@ -104,9 +103,7 @@ func NewFromReader(r io.Reader, parsers []PasswdParser, bad BadLineHandler) (*Fi
104103
// Match checks the username and password combination to see if it represents
105104
// a valid account from the htpassword file.
106105
func (bf *File) Match(username, password string) bool {
107-
bf.mutex.RLock()
108-
matcher, ok := bf.passwds[username]
109-
bf.mutex.RUnlock()
106+
matcher, ok := (*bf.passwds.Load())[username]
110107

111108
if ok && matcher.MatchesPassword(password) {
112109
// we are good
@@ -154,9 +151,7 @@ func (bf *File) ReloadFromReader(r io.Reader, bad BadLineHandler) error {
154151
}
155152

156153
// .. finally, safely swap in the new map
157-
bf.mutex.Lock()
158-
bf.passwds = newPasswdMap
159-
bf.mutex.Unlock()
154+
bf.passwds.Store(&newPasswdMap)
160155

161156
return nil
162157
}

0 commit comments

Comments
 (0)