Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: preliminary tests and merge fix for authv2 #3584

Merged
merged 22 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .devcontainer-scripts/utils.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Param 2: email
#
config_user() {
echo "Configuring git for $1 <$2>"
local gcn=$(git config --global user.name)
if [ -z "${gcn}" ]; then
echo "Setting up git user / remote"
Expand All @@ -24,6 +25,7 @@ config_user() {
# Param 2: remote url
#
config_remote() {
echo "Adding git remote and fetching $2 as $1"
local gr=$(git remote -v | grep $1)
if [ -z "${gr}" ]; then
git remote add $1 $2
Expand Down
5 changes: 2 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,8 @@ RUN if [ "${FFMPEG}" = "true" ]; then \

RUN apt-get update && \
apt-get install -y --no-install-recommends \
ssh less && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
ssh less wget
# For the devcontainer, leave apt functional in case additional devtools are needed at runtime.

RUN go install github.com/go-delve/delve/cmd/dlv@latest

Expand Down
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@ clean-tests:
rm -rf test-dir
rm -rf core/http/backend-assets

clean-dc: clean
cp -r /build/backend-assets /workspace/backend-assets

## Build:
build: prepare backend-assets grpcs ## Build the project
$(info ${GREEN}I local-ai build info:${RESET})
Expand Down
4 changes: 2 additions & 2 deletions core/gallery/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) ([]*Gal
func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {
var refFile string
uri := downloader.URI(url)
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
refFile = string(d)
if len(refFile) == 0 {
return fmt.Errorf("invalid reference file at url %s: %s", url, d)
Expand All @@ -156,7 +156,7 @@ func getGalleryModels(gallery config.Gallery, basePath string) ([]*GalleryModel,
}
uri := downloader.URI(gallery.URL)

err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &models)
})
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion core/gallery/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ type PromptTemplate struct {
func GetGalleryConfigFromURL(url string, basePath string) (Config, error) {
var config Config
uri := downloader.URI(url)
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
Expand Down
18 changes: 0 additions & 18 deletions core/http/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,6 @@ import (
"github.com/rs/zerolog/log"
)

func readAuthHeader(c *fiber.Ctx) string {
authHeader := c.Get("Authorization")

// elevenlabs
xApiKey := c.Get("xi-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}

// anthropic
xApiKey = c.Get("x-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}

return authHeader
}

// Embed a directory
//
//go:embed static/*
Expand Down
69 changes: 62 additions & 7 deletions core/http/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ import (
"github.com/sashabaranov/go-openai/jsonschema"
)

const apiKey = "joshua"
const bearerKey = "Bearer " + apiKey

const testPrompt = `### System:
You are an AI assistant that follows instruction extremely well. Help as much as you can.

Expand All @@ -50,11 +53,19 @@ type modelApplyRequest struct {

func getModelStatus(url string) (response map[string]interface{}) {
// Create the HTTP request
resp, err := http.Get(url)
req, err := http.NewRequest("GET", url, nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
if err != nil {
fmt.Println("Error creating request:", err)
return
}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
fmt.Println("Error sending request:", err)
return
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
Expand All @@ -72,14 +83,15 @@ func getModelStatus(url string) (response map[string]interface{}) {
return
}

func getModels(url string) (response []gallery.GalleryModel) {
func getModels(url string) ([]gallery.GalleryModel, error) {
response := []gallery.GalleryModel{}
uri := downloader.URI(url)
// TODO: No tests currently seem to exercise file:// urls. Fix?
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
return
return response, err
}

func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
Expand All @@ -101,6 +113,7 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)

// Make the request
client := &http.Client{}
Expand Down Expand Up @@ -140,6 +153,7 @@ func postRequestJSON[B any](url string, bodyJson *B) error {
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)

client := &http.Client{}
resp, err := client.Do(req)
Expand Down Expand Up @@ -175,6 +189,7 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)

client := &http.Client{}
resp, err := client.Do(req)
Expand All @@ -195,6 +210,35 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *
return json.Unmarshal(body, respJson)
}

func postInvalidRequest(url string) (error, int) {

req, err := http.NewRequest("POST", url, bytes.NewBufferString("invalid request"))
if err != nil {
return err, -1
}

req.Header.Set("Content-Type", "application/json")

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err, -1
}

defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return err, -1
}

if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)), resp.StatusCode
}

return nil, resp.StatusCode
}

//go:embed backend-assets/*
var backendAssets embed.FS

Expand Down Expand Up @@ -260,6 +304,7 @@ var _ = Describe("API test", func() {
config.WithContext(c),
config.WithGalleries(galleries),
config.WithModelPath(modelDir),
config.WithApiKeys([]string{apiKey}),
config.WithBackendAssets(backendAssets),
config.WithBackendAssetsOutput(backendAssetsDir))...)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -269,7 +314,7 @@ var _ = Describe("API test", func() {

go app.Listen("127.0.0.1:9090")

defaultConfig := openai.DefaultConfig("")
defaultConfig := openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"

client2 = openaigo.NewClient("")
Expand All @@ -295,10 +340,19 @@ var _ = Describe("API test", func() {
Expect(err).To(HaveOccurred())
})

Context("Auth Tests", func() {
It("Should fail if the api key is missing", func() {
err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available")
Expect(err).ToNot(BeNil())
Expect(sc).To(Equal(403))
})
})

Context("Applying models", func() {

It("applies models from a gallery", func() {
models := getModels("http://127.0.0.1:9090/models/available")
models, err := getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
Expand Down Expand Up @@ -331,7 +385,8 @@ var _ = Describe("API test", func() {
Expect(content["backend"]).To(Equal("bert-embeddings"))
Expect(content["foo"]).To(Equal("bar"))

models = getModels("http://127.0.0.1:9090/models/available")
models, err = getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2")))
Expand Down
3 changes: 2 additions & 1 deletion core/http/middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(403)
}
return ctx.Status(403).SendString(err.Error())
}
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(500)
Expand Down Expand Up @@ -90,4 +91,4 @@ func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig
}
}
return func(c *fiber.Ctx) bool { return false }
}
}
2 changes: 1 addition & 1 deletion embedded/embedded.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func init() {
func GetRemoteLibraryShorteners(url string, basePath string) (map[string]string, error) {
remoteLibrary := map[string]string{}
uri := downloader.URI(url)
err := uri.DownloadAndUnmarshal(basePath, func(_ string, i []byte) error {
err := uri.DownloadWithCallback(basePath, func(_ string, i []byte) error {
return yaml.Unmarshal(i, &remoteLibrary)
})
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module github.com/mudler/LocalAI

go 1.22.0
go 1.23

toolchain go1.22.4
toolchain go1.23.1

require (
dario.cat/mergo v1.0.0
Expand Down
18 changes: 15 additions & 3 deletions pkg/downloader/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ const (

type URI string

func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte) error) error {
func (uri URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error {
return uri.DownloadWithAuthorizationAndCallback(basePath, "", f)
}

func (uri URI) DownloadWithAuthorizationAndCallback(basePath string, authorization string, f func(url string, i []byte) error) error {
url := uri.ResolveURL()

if strings.HasPrefix(url, LocalPrefix) {
Expand All @@ -41,7 +45,6 @@ func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte
if err != nil {
return err
}
// ???
resolvedBasePath, err := filepath.EvalSymlinks(basePath)
if err != nil {
return err
Expand All @@ -63,7 +66,16 @@ func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte
}

// Send a GET request to the URL
response, err := http.Get(url)

req, err := http.NewRequest("GET", url, nil)
if err != nil {
return err
}
if authorization != "" {
req.Header.Add("Authorization", authorization)
}

response, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/downloader/uri_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ var _ = Describe("Gallery API tests", func() {
It("parses github with a branch", func() {
uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml")
Expect(
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
uri.DownloadWithCallback("", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
Expand All @@ -21,7 +21,7 @@ var _ = Describe("Gallery API tests", func() {
uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml@main")

Expect(
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
uri.DownloadWithCallback("", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
Expand All @@ -30,7 +30,7 @@ var _ = Describe("Gallery API tests", func() {
It("parses github with urls", func() {
uri := URI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")
Expect(
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
uri.DownloadWithCallback("", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
Expand Down