diff --git a/pkg/users/users.go b/pkg/users/users.go index 9c69c78..878b58f 100644 --- a/pkg/users/users.go +++ b/pkg/users/users.go @@ -2,6 +2,8 @@ package users import ( + "errors" + "fmt" "strconv" ) @@ -68,6 +70,7 @@ type UserList interface { // GetAll returns all users in the list GetAll() ([]User, error) GenerateUID() int + GenerateUIDInRange(int, int) (int, error) LastUID() int SetPath(path string) Load() error @@ -99,3 +102,31 @@ func (list CommonUserList) GenerateUID() int { } return list.lastUID + 1 } + +// Finds the lowest available uid in the specified range. +// Returns an error if there is no available uid in that range. +func (list CommonUserList) GenerateUIDInRange(minimum, maximum int) (int, error) { + userSet := make(map[int]struct{}) + for _, user := range list.users { + uid, err := user.UID() + if err != nil { + return -1, fmt.Errorf("getting user's uid: %w", err) + } + userSet[uid] = struct{}{} + } + + result := -1 + for i := minimum; i <= maximum; i++ { + if _, found := userSet[i]; found { + continue // uid in use, skip it + } + result = i // found a free one, stop here + break + } + + if result == -1 { + return result, errors.New("no available uid in range") + } + + return result, nil +} diff --git a/pkg/users/users_linux_test.go b/pkg/users/users_linux_test.go index 703f900..fefb463 100644 --- a/pkg/users/users_linux_test.go +++ b/pkg/users/users_linux_test.go @@ -9,6 +9,60 @@ import ( ) var _ = Describe("LinuxUserList", func() { + Describe("GenerateUIDInRange", func() { + var file *os.File + var err error + var list LinuxUserList + + BeforeEach(func() { + file, err = os.CreateTemp("", "passwd") + Expect(err).ToNot(HaveOccurred()) + DeferCleanup(func() { + defer os.Remove(file.Name()) + }) + + _, err = file.WriteString("root:x:0:0:root:/root:/bin/bash\n") + Expect(err).ToNot(HaveOccurred()) + _, err = file.WriteString("foo:x:1000:1000:foo:/home/foo:/bin/bash\n") + Expect(err).ToNot(HaveOccurred()) + _, err = file.WriteString("foo:x:1001:1000:foo:/home/foo:/bin/bash\n") + _, err = file.WriteString("foo:x:1001:1000:foo:/home/foo:/bin/bash\n") + Expect(err).ToNot(HaveOccurred()) + + list = LinuxUserList{} + list.SetPath(file.Name()) + Expect(list.Load()).ToNot(HaveOccurred()) + }) + + When("a uid is available in the range", func() { + var minimum, maximum int + BeforeEach(func() { + minimum = 1000 + maximum = 2000 + }) + + It("returns the minimum available uid", func() { + r, err := list.GenerateUIDInRange(minimum, maximum) + Expect(err).ToNot(HaveOccurred()) + Expect(r).To(Equal(1002)) + }) + }) + + When("there is no available uid", func() { + var minimum, maximum int + BeforeEach(func() { + minimum = 1000 + maximum = 1001 + }) + + It("returns an error", func() { + _, err := list.GenerateUIDInRange(minimum, maximum) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("no available uid in range")) + }) + }) + }) + It("parses a record", func() { rootRecord := `root:x:0:0:root:/root:/bin/bash`