From 6c89102d687e36d3dea905399f5775c4cc19a647 Mon Sep 17 00:00:00 2001 From: schristou88 Date: Mon, 26 Feb 2024 05:06:45 -0800 Subject: [PATCH 1/6] Initial implementation of assistants api --- api/openai/assistant.go | 470 ++++++++++++++++++++++++++++++++++++++++ core/http/api.go | 20 ++ 2 files changed, 490 insertions(+) create mode 100644 api/openai/assistant.go diff --git a/api/openai/assistant.go b/api/openai/assistant.go new file mode 100644 index 00000000000..58bb52f17df --- /dev/null +++ b/api/openai/assistant.go @@ -0,0 +1,470 @@ +package openai + +import ( + "fmt" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "github.com/rs/zerolog/log" + "net/http" + "sort" + "strconv" + "strings" + "time" +) + +// ToolType defines a type for tool options +type ToolType string + +const ( + CodeInterpreter ToolType = "code_interpreter" + Retrieval ToolType = "retrieval" + Function ToolType = "function" + + MaxCharacterInstructions = 32768 + MaxCharacterDescription = 512 + MaxCharacterName = 256 + MaxToolsSize = 128 + MaxFileIdSize = 20 + MaxCharacterMetadataKey = 64 + MaxCharacterMetadataValue = 512 + + MaxLengthRandomID = 0 +) + +type Tool struct { + Type ToolType `json:"type"` +} + +// Assistant represents the structure of an assistant object from the OpenAI API. +type Assistant struct { + ID string `json:"id"` // The unique identifier of the assistant. + Object string `json:"object"` // Object type, which is "assistant". + Created int64 `json:"created"` // The time at which the assistant was created. + Model string `json:"model"` // The model ID used by the assistant. + Name string `json:"name,omitempty"` // The name of the assistant. + Description string `json:"description,omitempty"` // The description of the assistant. + Instructions string `json:"instructions,omitempty"` // The system instructions that the assistant uses. + Tools []Tool `json:"tools,omitempty"` // A list of tools enabled on the assistant. + FileIDs []string `json:"file_ids,omitempty"` // A list of file IDs attached to this assistant. + Metadata map[string]string `json:"metadata,omitempty"` // Set of key-value pairs attached to the assistant. +} + +var ( + assistants = []Assistant{} // better to return empty array instead of "null" +) + +type AssistantRequest struct { + Model string `json:"model"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Instructions string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +func CreateAssistantEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + request := new(AssistantRequest) + if err := c.BodyParser(request); err != nil { + log.Warn().AnErr("Unable to parse AssistantRequest", err) + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"}) + } + + if !modelExists(o, request.Model) { + log.Warn().Msgf("Model: %s was not found in list of models.", request.Model) + return c.Status(fiber.StatusBadRequest).SendString("Model " + request.Model + " not found") + } + + assistant := Assistant{ + ID: "asst_" + generateRandomID(MaxLengthRandomID), + Object: "assistant", + Created: time.Now().Unix(), + Model: request.Model, + Name: request.Name, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + FileIDs: request.FileIDs, + Metadata: request.Metadata, + } + + assistants = append(assistants, assistant) + + return c.Status(fiber.StatusOK).JSON(assistant) + } +} + +func generateRandomID(maxLength int) string { + newUUID, err := uuid.NewUUID() + if err != nil { + log.Error().Msgf("Failed to generate UUID: %v", err) + return "" + } + + uuidStr := newUUID.String() + if maxLength > 0 && len(uuidStr) > maxLength { + return uuidStr[:maxLength] + } + return uuidStr +} + +func ListAssistantsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + // Parse query parameters + limitQuery := c.Query("limit", "20") + orderQuery := c.Query("order", "desc") + afterQuery := c.Query("after") + beforeQuery := c.Query("before") + + // Convert string limit to integer + limit, err := strconv.Atoi(limitQuery) + if err != nil { + return c.Status(http.StatusBadRequest).SendString(err.Error()) + } + + // Sort assistants + sort.SliceStable(assistants, func(i, j int) bool { + if orderQuery == "asc" { + return assistants[i].Created < assistants[j].Created + } + return assistants[i].Created > assistants[j].Created + }) + + // After and before cursors + if afterQuery != "" { + assistants = filterAssistantsAfterID(assistants, afterQuery) + } + if beforeQuery != "" { + assistants = filterAssistantsBeforeID(assistants, beforeQuery) + } + + // Apply limit + if limit < len(assistants) { + assistants = assistants[:limit] + } + + return c.JSON(assistants) + } +} + +// FilterAssistantsBeforeID filters out those assistants whose ID comes before the given ID +// We assume that the assistants are already sorted +func filterAssistantsBeforeID(assistants []Assistant, id string) []Assistant { + for i, assistant := range assistants { + if strings.Compare(assistant.ID, id) == 0 { + if i != 0 { + return assistants[:i] + } + return []Assistant{} + } + } + return assistants +} + +// FilterAssistantsAfterID filters out those assistants whose ID comes after the given ID +// We assume that the assistants are already sorted +func filterAssistantsAfterID(assistants []Assistant, id string) []Assistant { + for i, assistant := range assistants { + if strings.Compare(assistant.ID, id) == 0 { + if i != len(assistants)-1 { + return assistants[i+1:] + } + return []Assistant{} + } + } + return assistants +} + +func modelExists(o *options.Option, modelName string) (found bool) { + found = false + models, err := o.Loader.ListModels() + if err != nil { + return + } + + for _, model := range models { + if model == modelName { + found = true + return + } + } + return +} + +func DeleteAssistantEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + type DeleteAssistantResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + } + + return func(c *fiber.Ctx) error { + assistantID := c.Params("assistant_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") + } + + for i, assistant := range assistants { + if assistant.ID == assistantID { + assistants = append(assistants[:i], assistants[i+1:]...) + return c.Status(fiber.StatusOK).JSON(DeleteAssistantResponse{ + ID: assistantID, + Object: "assistant.deleted", + Deleted: true, + }) + } + } + + log.Warn().Msgf("Unable to find assistant %s for deletion", assistantID) + return c.Status(fiber.StatusNotFound).JSON(DeleteAssistantResponse{ + ID: assistantID, + Object: "assistant.deleted", + Deleted: false, + }) + } +} + +func GetAssistantEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + assistantID := c.Params("assistant_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") + } + + for _, assistant := range assistants { + if assistant.ID == assistantID { + return c.Status(fiber.StatusOK).JSON(assistant) + } + } + + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant with id: %s", assistantID)) + } +} + +type AssistantFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + AssistantID string `json:"assistant_id"` +} + +var assistantFiles []AssistantFile + +func CreateAssistantFileEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + type AssistantFileRequest struct { + FileID string `json:"file_id"` + } + + return func(c *fiber.Ctx) error { + request := new(AssistantFileRequest) + if err := c.BodyParser(request); err != nil { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"}) + } + + assistantID := c.Query("assistant_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") + } + + for _, assistant := range assistants { + if assistant.ID == assistantID { + if len(assistant.FileIDs) > MaxFileIdSize { + return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Max files %d for assistant %s reached.", MaxFileIdSize, assistant.Name)) + } + + for _, file := range uploadedFiles { + if file.ID == request.FileID { + assistant.FileIDs = append(assistant.FileIDs, request.FileID) + assistantFile := AssistantFile{ + ID: file.ID, + Object: "assistant.file", + CreatedAt: time.Now().Unix(), + AssistantID: assistant.ID, + } + assistantFiles = append(assistantFiles, assistantFile) + return c.Status(fiber.StatusOK).JSON(assistantFile) + } + } + + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find file_id: %s", request.FileID)) + } + } + + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find ")) + } +} + +func ListAssistantFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + type ListAssistantFiles struct { + Data []File + Object string + } + + return func(c *fiber.Ctx) error { + assistantID := c.Params("assistant_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") + } + + limitQuery := c.Query("limit", "20") + order := c.Query("order", "desc") + limit, err := strconv.Atoi(limitQuery) + if err != nil || limit < 1 || limit > 100 { + limit = 20 // Default to 20 if there's an error or the limit is out of bounds + } + + // Sort files by CreatedAt depending on the order query parameter + if order == "asc" { + sort.Slice(assistantFiles, func(i, j int) bool { + return assistantFiles[i].CreatedAt < assistantFiles[j].CreatedAt + }) + } else { // default to "desc" + sort.Slice(assistantFiles, func(i, j int) bool { + return assistantFiles[i].CreatedAt > assistantFiles[j].CreatedAt + }) + } + + // Limit the number of files returned + var limitedFiles []AssistantFile + hasMore := false + if len(assistantFiles) > limit { + hasMore = true + limitedFiles = assistantFiles[:limit] + } else { + limitedFiles = assistantFiles + } + + response := map[string]interface{}{ + "object": "list", + "data": limitedFiles, + "first_id": func() string { + if len(limitedFiles) > 0 { + return limitedFiles[0].ID + } + return "" + }(), + "last_id": func() string { + if len(limitedFiles) > 0 { + return limitedFiles[len(limitedFiles)-1].ID + } + return "" + }(), + "has_more": hasMore, + } + + return c.Status(fiber.StatusOK).JSON(response) + } +} + +func ModifyAssistantEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + request := new(AssistantRequest) + if err := c.BodyParser(request); err != nil { + log.Warn().AnErr("Unable to parse AssistantRequest", err) + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"}) + } + + assistantID := c.Params("assistant_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") + } + + for i, assistant := range assistants { + if assistant.ID == assistantID { + newAssistant := Assistant{ + ID: assistantID, + Object: assistant.Object, + Created: assistant.Created, + Model: request.Model, + Name: request.Name, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + FileIDs: request.FileIDs, // todo: should probably verify fileids exist + Metadata: request.Metadata, + } + + // Remove old one and replace with new one + assistants = append(assistants[:i], assistants[i+1:]...) + assistants = append(assistants, newAssistant) + return c.Status(fiber.StatusOK).JSON(newAssistant) + } + } + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant with id: %s", assistantID)) + } +} + +func DeleteAssistantFileEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + type DeleteAssistantFileResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + } + return func(c *fiber.Ctx) error { + assistantID := c.Params("assistant_id") + fileId := c.Params("file_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required") + } + // First remove file from assistant + for i, assistant := range assistants { + if assistant.ID == assistantID { + for j, fileId := range assistant.FileIDs { + if fileId == fileId { + assistants[i].FileIDs = append(assistants[i].FileIDs[:j], assistants[i].FileIDs[j+1:]...) + + // Check if the file exists in the assistantFiles slice + for i, assistantFile := range assistantFiles { + if assistantFile.ID == fileId { + // Remove the file from the assistantFiles slice + assistantFiles = append(assistantFiles[:i], assistantFiles[i+1:]...) + return c.Status(fiber.StatusOK).JSON(DeleteAssistantFileResponse{ + ID: fileId, + Object: "assistant.file.deleted", + Deleted: true, + }) + } + } + } + } + + log.Warn().Msgf("Unable to locate file_id: %s in assistants: %s", fileId, assistantID) + return c.Status(fiber.StatusNotFound).JSON(DeleteAssistantFileResponse{ + ID: fileId, + Object: "assistant.file.deleted", + Deleted: false, + }) + } + } + log.Warn().Msgf("Unable to find assistant: %s", assistantID) + + return c.Status(fiber.StatusNotFound).JSON(DeleteAssistantFileResponse{ + ID: fileId, + Object: "assistant.file.deleted", + Deleted: false, + }) + } +} + +func GetAssistantFileEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + assistantID := c.Params("assistant_id") + fileId := c.Params("file_id") + if assistantID == "" { + return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required") + } + + for _, assistantFile := range assistantFiles { + if assistantFile.AssistantID == assistantID { + if assistantFile.ID == fileId { + return c.Status(fiber.StatusOK).JSON(assistantFile) + } + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant file with file_id: %s", fileId)) + } + } + return c.Status(fiber.StatusNotFound).SendString(fmt.Sprintf("Unable to find assistant file with assistant_id: %s", assistantID)) + } +} diff --git a/core/http/api.go b/core/http/api.go index 7d228152409..40248ae982d 100644 --- a/core/http/api.go +++ b/core/http/api.go @@ -248,6 +248,26 @@ func App(opts ...options.AppOption) (*fiber.App, error) { app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options)) app.Post("/edits", auth, openai.EditEndpoint(cl, options)) + // assistant + app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, options)) + app.Get("/assistants", openai.ListAssistantsEndpoint(cl, options)) + app.Post("/v1/assistants", openai.CreateAssistantEndpoint(cl, options)) + app.Post("/assistants", openai.CreateAssistantEndpoint(cl, options)) + app.Delete("/v1/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, options)) + app.Delete("/assistants/:assistant_id", openai.DeleteAssistantEndpoint(cl, options)) + app.Get("/v1/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, options)) + app.Get("/assistants/:assistant_id", openai.GetAssistantEndpoint(cl, options)) + app.Post("/v1/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, options)) + app.Post("/assistants/:assistant_id", openai.ModifyAssistantEndpoint(cl, options)) + app.Get("/v1/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, options)) + app.Get("/assistants/:assistant_id/files", openai.ListAssistantFilesEndpoint(cl, options)) + app.Post("/v1/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, options)) + app.Post("/assistants/:assistant_id/files", openai.CreateAssistantFileEndpoint(cl, options)) + app.Delete("/v1/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, options)) + app.Delete("/assistants/:assistant_id/files/:file_id", openai.DeleteAssistantFileEndpoint(cl, options)) + app.Get("/v1/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, options)) + app.Get("/assistants/:assistant_id/files/:file_id", openai.GetAssistantFileEndpoint(cl, options)) + // files app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, options)) app.Post("/files", auth, openai.UploadFilesEndpoint(cl, options)) From 00437ae115a839389434ace821c41bb095108337 Mon Sep 17 00:00:00 2001 From: schristou88 Date: Sat, 2 Mar 2024 02:42:35 -0800 Subject: [PATCH 2/6] Move load/save configs to utils --- api/openai/assistant.go | 2 +- api/openai/files.go | 54 ++++++++--------------------------------- core/http/api.go | 3 ++- pkg/utils/config.go | 40 ++++++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 46 deletions(-) create mode 100644 pkg/utils/config.go diff --git a/api/openai/assistant.go b/api/openai/assistant.go index 58bb52f17df..3382d9f5344 100644 --- a/api/openai/assistant.go +++ b/api/openai/assistant.go @@ -275,7 +275,7 @@ func CreateAssistantFileEndpoint(cm *config.ConfigLoader, o *options.Option) fun return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Max files %d for assistant %s reached.", MaxFileIdSize, assistant.Name)) } - for _, file := range uploadedFiles { + for _, file := range UploadedFiles { if file.ID == request.FileID { assistant.FileIDs = append(assistant.FileIDs, request.FileID) assistantFile := AssistantFile{ diff --git a/api/openai/files.go b/api/openai/files.go index 140b4151940..0dfd1bf204a 100644 --- a/api/openai/files.go +++ b/api/openai/files.go @@ -1,7 +1,6 @@ package openai import ( - "encoding/json" "errors" "fmt" "os" @@ -12,12 +11,11 @@ import ( "github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" ) -var uploadedFiles []File +var UploadedFiles []File -const uploadedFilesFile = "uploadedFiles.json" +const UploadedFilesFile = "uploadedFiles.json" // File represents the structure of a file object from the OpenAI API. type File struct { @@ -29,38 +27,6 @@ type File struct { Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.) } -func saveUploadConfig(uploadDir string) { - file, err := json.MarshalIndent(uploadedFiles, "", " ") - if err != nil { - log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err) - } - - err = os.WriteFile(filepath.Join(uploadDir, uploadedFilesFile), file, 0644) - if err != nil { - log.Error().Msgf("Failed to save uploadedFiles to file: %s", err) - } -} - -func LoadUploadConfig(uploadPath string) { - uploadFilePath := filepath.Join(uploadPath, uploadedFilesFile) - - _, err := os.Stat(uploadFilePath) - if os.IsNotExist(err) { - log.Debug().Msgf("No uploadedFiles file found at %s", uploadFilePath) - return - } - - file, err := os.ReadFile(uploadFilePath) - if err != nil { - log.Error().Msgf("Failed to read file: %s", err) - } else { - err = json.Unmarshal(file, &uploadedFiles) - if err != nil { - log.Error().Msgf("Failed to JSON unmarshal the file into uploadedFiles: %s", err) - } - } -} - // UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { @@ -103,8 +69,8 @@ func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib Purpose: purpose, } - uploadedFiles = append(uploadedFiles, f) - saveUploadConfig(o.UploadDir) + UploadedFiles = append(UploadedFiles, f) + utils.SaveConfig(o.UploadDir, UploadedFilesFile, UploadedFiles) return c.Status(fiber.StatusOK).JSON(f) } } @@ -121,9 +87,9 @@ func ListFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber purpose := c.Query("purpose") if purpose == "" { - listFiles.Data = uploadedFiles + listFiles.Data = UploadedFiles } else { - for _, f := range uploadedFiles { + for _, f := range UploadedFiles { if purpose == f.Purpose { listFiles.Data = append(listFiles.Data, f) } @@ -140,7 +106,7 @@ func getFileFromRequest(c *fiber.Ctx) (*File, error) { return nil, fmt.Errorf("file_id parameter is required") } - for _, f := range uploadedFiles { + for _, f := range UploadedFiles { if id == f.ID { return &f, nil } @@ -184,14 +150,14 @@ func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fib } // Remove upload from list - for i, f := range uploadedFiles { + for i, f := range UploadedFiles { if f.ID == file.ID { - uploadedFiles = append(uploadedFiles[:i], uploadedFiles[i+1:]...) + UploadedFiles = append(UploadedFiles[:i], UploadedFiles[i+1:]...) break } } - saveUploadConfig(o.UploadDir) + utils.SaveConfig(o.UploadDir, UploadedFilesFile, UploadedFiles) return c.JSON(DeleteStatus{ Id: file.ID, Object: "file", diff --git a/core/http/api.go b/core/http/api.go index 40248ae982d..880910b3a1b 100644 --- a/core/http/api.go +++ b/core/http/api.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/go-skynet/LocalAI/pkg/utils" "os" "strings" @@ -227,7 +228,7 @@ func App(opts ...options.AppOption) (*fiber.App, error) { os.MkdirAll(options.Loader.ModelPath, 0755) // Load upload json - openai.LoadUploadConfig(options.UploadDir) + utils.LoadConfig(options.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles) modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService) app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint()) diff --git a/pkg/utils/config.go b/pkg/utils/config.go new file mode 100644 index 00000000000..b8f83a23fb3 --- /dev/null +++ b/pkg/utils/config.go @@ -0,0 +1,40 @@ +package utils + +import ( + "encoding/json" + "github.com/rs/zerolog/log" + "os" + "path/filepath" +) + +func SaveConfig(uploadDir, fileName string, obj any) { + file, err := json.MarshalIndent(obj, "", " ") + if err != nil { + log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err) + } + + err = os.WriteFile(filepath.Join(uploadDir, fileName), file, 0644) + if err != nil { + log.Error().Msgf("Failed to save uploadedFiles to file: %s", err) + } +} + +func LoadConfig(filePath, fileName string, obj any) { + uploadFilePath := filepath.Join(filePath, fileName) + + _, err := os.Stat(uploadFilePath) + if os.IsNotExist(err) { + log.Debug().Msgf("No uploadedFiles file found at %s", uploadFilePath) + return + } + + file, err := os.ReadFile(uploadFilePath) + if err != nil { + log.Error().Msgf("Failed to read file: %s", err) + } else { + err = json.Unmarshal(file, &obj) + if err != nil { + log.Error().Msgf("Failed to JSON unmarshal the file into uploadedFiles: %s", err) + } + } +} From 0b3c84a3226e6be697489a47b9965db585eae4f1 Mon Sep 17 00:00:00 2001 From: schristou88 Date: Sat, 2 Mar 2024 03:23:02 -0800 Subject: [PATCH 3/6] Save assistant and assistantfiles config to disk. --- api/openai/assistant.go | 73 +++++++++++++++++++++++------------------ core/http/api.go | 5 ++- core/options/options.go | 7 ++++ main.go | 7 ++++ 4 files changed, 59 insertions(+), 33 deletions(-) diff --git a/api/openai/assistant.go b/api/openai/assistant.go index 3382d9f5344..741abaa6e2e 100644 --- a/api/openai/assistant.go +++ b/api/openai/assistant.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/pkg/utils" "github.com/gofiber/fiber/v2" "github.com/google/uuid" "github.com/rs/zerolog/log" @@ -52,7 +53,8 @@ type Assistant struct { } var ( - assistants = []Assistant{} // better to return empty array instead of "null" + Assistants = []Assistant{} // better to return empty array instead of "null" + AssistantsConfigFile = "assistants.json" ) type AssistantRequest struct { @@ -91,8 +93,8 @@ func CreateAssistantEndpoint(cm *config.ConfigLoader, o *options.Option) func(c Metadata: request.Metadata, } - assistants = append(assistants, assistant) - + Assistants = append(Assistants, assistant) + utils.SaveConfig(o.ConfigsDir, AssistantsConfigFile, Assistants) return c.Status(fiber.StatusOK).JSON(assistant) } } @@ -126,27 +128,27 @@ func ListAssistantsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c * } // Sort assistants - sort.SliceStable(assistants, func(i, j int) bool { + sort.SliceStable(Assistants, func(i, j int) bool { if orderQuery == "asc" { - return assistants[i].Created < assistants[j].Created + return Assistants[i].Created < Assistants[j].Created } - return assistants[i].Created > assistants[j].Created + return Assistants[i].Created > Assistants[j].Created }) // After and before cursors if afterQuery != "" { - assistants = filterAssistantsAfterID(assistants, afterQuery) + Assistants = filterAssistantsAfterID(Assistants, afterQuery) } if beforeQuery != "" { - assistants = filterAssistantsBeforeID(assistants, beforeQuery) + Assistants = filterAssistantsBeforeID(Assistants, beforeQuery) } // Apply limit - if limit < len(assistants) { - assistants = assistants[:limit] + if limit < len(Assistants) { + Assistants = Assistants[:limit] } - return c.JSON(assistants) + return c.JSON(Assistants) } } @@ -207,9 +209,10 @@ func DeleteAssistantEndpoint(cm *config.ConfigLoader, o *options.Option) func(c return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") } - for i, assistant := range assistants { + for i, assistant := range Assistants { if assistant.ID == assistantID { - assistants = append(assistants[:i], assistants[i+1:]...) + Assistants = append(Assistants[:i], Assistants[i+1:]...) + utils.SaveConfig(o.ConfigsDir, AssistantsConfigFile, Assistants) return c.Status(fiber.StatusOK).JSON(DeleteAssistantResponse{ ID: assistantID, Object: "assistant.deleted", @@ -234,7 +237,7 @@ func GetAssistantEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fi return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") } - for _, assistant := range assistants { + for _, assistant := range Assistants { if assistant.ID == assistantID { return c.Status(fiber.StatusOK).JSON(assistant) } @@ -251,7 +254,10 @@ type AssistantFile struct { AssistantID string `json:"assistant_id"` } -var assistantFiles []AssistantFile +var ( + AssistantFiles []AssistantFile + AssistantsFileConfigFile = "assistantsFile.json" +) func CreateAssistantFileEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { type AssistantFileRequest struct { @@ -269,7 +275,7 @@ func CreateAssistantFileEndpoint(cm *config.ConfigLoader, o *options.Option) fun return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") } - for _, assistant := range assistants { + for _, assistant := range Assistants { if assistant.ID == assistantID { if len(assistant.FileIDs) > MaxFileIdSize { return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Max files %d for assistant %s reached.", MaxFileIdSize, assistant.Name)) @@ -284,7 +290,8 @@ func CreateAssistantFileEndpoint(cm *config.ConfigLoader, o *options.Option) fun CreatedAt: time.Now().Unix(), AssistantID: assistant.ID, } - assistantFiles = append(assistantFiles, assistantFile) + AssistantFiles = append(AssistantFiles, assistantFile) + utils.SaveConfig(o.ConfigsDir, AssistantsFileConfigFile, AssistantFiles) return c.Status(fiber.StatusOK).JSON(assistantFile) } } @@ -318,23 +325,23 @@ func ListAssistantFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func // Sort files by CreatedAt depending on the order query parameter if order == "asc" { - sort.Slice(assistantFiles, func(i, j int) bool { - return assistantFiles[i].CreatedAt < assistantFiles[j].CreatedAt + sort.Slice(AssistantFiles, func(i, j int) bool { + return AssistantFiles[i].CreatedAt < AssistantFiles[j].CreatedAt }) } else { // default to "desc" - sort.Slice(assistantFiles, func(i, j int) bool { - return assistantFiles[i].CreatedAt > assistantFiles[j].CreatedAt + sort.Slice(AssistantFiles, func(i, j int) bool { + return AssistantFiles[i].CreatedAt > AssistantFiles[j].CreatedAt }) } // Limit the number of files returned var limitedFiles []AssistantFile hasMore := false - if len(assistantFiles) > limit { + if len(AssistantFiles) > limit { hasMore = true - limitedFiles = assistantFiles[:limit] + limitedFiles = AssistantFiles[:limit] } else { - limitedFiles = assistantFiles + limitedFiles = AssistantFiles } response := map[string]interface{}{ @@ -372,7 +379,7 @@ func ModifyAssistantEndpoint(cm *config.ConfigLoader, o *options.Option) func(c return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") } - for i, assistant := range assistants { + for i, assistant := range Assistants { if assistant.ID == assistantID { newAssistant := Assistant{ ID: assistantID, @@ -388,8 +395,9 @@ func ModifyAssistantEndpoint(cm *config.ConfigLoader, o *options.Option) func(c } // Remove old one and replace with new one - assistants = append(assistants[:i], assistants[i+1:]...) - assistants = append(assistants, newAssistant) + Assistants = append(Assistants[:i], Assistants[i+1:]...) + Assistants = append(Assistants, newAssistant) + utils.SaveConfig(o.ConfigsDir, AssistantsConfigFile, Assistants) return c.Status(fiber.StatusOK).JSON(newAssistant) } } @@ -410,17 +418,18 @@ func DeleteAssistantFileEndpoint(cm *config.ConfigLoader, o *options.Option) fun return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required") } // First remove file from assistant - for i, assistant := range assistants { + for i, assistant := range Assistants { if assistant.ID == assistantID { for j, fileId := range assistant.FileIDs { if fileId == fileId { - assistants[i].FileIDs = append(assistants[i].FileIDs[:j], assistants[i].FileIDs[j+1:]...) + Assistants[i].FileIDs = append(Assistants[i].FileIDs[:j], Assistants[i].FileIDs[j+1:]...) // Check if the file exists in the assistantFiles slice - for i, assistantFile := range assistantFiles { + for i, assistantFile := range AssistantFiles { if assistantFile.ID == fileId { // Remove the file from the assistantFiles slice - assistantFiles = append(assistantFiles[:i], assistantFiles[i+1:]...) + AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...) + utils.SaveConfig(o.ConfigsDir, AssistantsFileConfigFile, AssistantFiles) return c.Status(fiber.StatusOK).JSON(DeleteAssistantFileResponse{ ID: fileId, Object: "assistant.file.deleted", @@ -457,7 +466,7 @@ func GetAssistantFileEndpoint(cm *config.ConfigLoader, o *options.Option) func(c return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id and file_id are required") } - for _, assistantFile := range assistantFiles { + for _, assistantFile := range AssistantFiles { if assistantFile.AssistantID == assistantID { if assistantFile.ID == fileId { return c.Status(fiber.StatusOK).JSON(assistantFile) diff --git a/core/http/api.go b/core/http/api.go index 880910b3a1b..36bd3d09cb0 100644 --- a/core/http/api.go +++ b/core/http/api.go @@ -225,10 +225,13 @@ func App(opts ...options.AppOption) (*fiber.App, error) { os.MkdirAll(options.ImageDir, 0755) os.MkdirAll(options.AudioDir, 0755) os.MkdirAll(options.UploadDir, 0755) + os.MkdirAll(options.ConfigsDir, 0755) os.MkdirAll(options.Loader.ModelPath, 0755) - // Load upload json + // Load config jsons utils.LoadConfig(options.UploadDir, openai.UploadedFilesFile, &openai.UploadedFiles) + utils.LoadConfig(options.ConfigsDir, openai.AssistantsConfigFile, &openai.Assistants) + utils.LoadConfig(options.ConfigsDir, openai.AssistantsFileConfigFile, &openai.AssistantFiles) modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService) app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint()) diff --git a/core/options/options.go b/core/options/options.go index 72aea1a3293..c92cdcb2c44 100644 --- a/core/options/options.go +++ b/core/options/options.go @@ -22,6 +22,7 @@ type Option struct { ImageDir string AudioDir string UploadDir string + ConfigsDir string CORS bool PreloadJSONModels string PreloadModelsFromPath string @@ -256,6 +257,12 @@ func WithUploadDir(uploadDir string) AppOption { } } +func WithConfigsDir(configsDir string) AppOption { + return func(o *Option) { + o.ConfigsDir = configsDir + } +} + func WithApiKeys(apiKeys []string) AppOption { return func(o *Option) { o.ApiKeys = apiKeys diff --git a/main.go b/main.go index 7e4262ee57b..a7832bafa31 100644 --- a/main.go +++ b/main.go @@ -148,6 +148,12 @@ func main() { EnvVars: []string{"UPLOAD_PATH"}, Value: "/tmp/localai/upload", }, + &cli.StringFlag{ + Name: "config-path", + Usage: "Path to store uploads from files api", + EnvVars: []string{"CONFIG_PATH"}, + Value: "/tmp/localai/config", + }, &cli.StringFlag{ Name: "backend-assets-path", Usage: "Path used to extract libraries that are required by some of the backends in runtime.", @@ -234,6 +240,7 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit options.WithImageDir(ctx.String("image-path")), options.WithAudioDir(ctx.String("audio-path")), options.WithUploadDir(ctx.String("upload-path")), + options.WithConfigsDir(ctx.String("config-path")), options.WithF16(ctx.Bool("f16")), options.WithStringGalleries(ctx.String("galleries")), options.WithModelLibraryURL(ctx.String("remote-library")), From d3f9c0988b46751d4898af9d2d2b84b1b1fffb41 Mon Sep 17 00:00:00 2001 From: schristou88 Date: Tue, 5 Mar 2024 20:48:40 -0800 Subject: [PATCH 4/6] Add tsets for assistant api --- api/openai/assistant.go | 144 +++++++----- api/openai/assistant_test.go | 436 +++++++++++++++++++++++++++++++++++ api/openai/files_test.go | 4 +- pkg/utils/config.go | 5 +- 4 files changed, 531 insertions(+), 58 deletions(-) create mode 100644 api/openai/assistant_test.go diff --git a/api/openai/assistant.go b/api/openai/assistant.go index 741abaa6e2e..fcc59c79b66 100644 --- a/api/openai/assistant.go +++ b/api/openai/assistant.go @@ -6,12 +6,12 @@ import ( "github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/gofiber/fiber/v2" - "github.com/google/uuid" "github.com/rs/zerolog/log" "net/http" "sort" "strconv" "strings" + "sync/atomic" "time" ) @@ -30,8 +30,6 @@ const ( MaxFileIdSize = 20 MaxCharacterMetadataKey = 64 MaxCharacterMetadataValue = 512 - - MaxLengthRandomID = 0 ) type Tool struct { @@ -80,8 +78,22 @@ func CreateAssistantEndpoint(cm *config.ConfigLoader, o *options.Option) func(c return c.Status(fiber.StatusBadRequest).SendString("Model " + request.Model + " not found") } + if request.Tools == nil { + request.Tools = []Tool{} + } + + if request.FileIDs == nil { + request.FileIDs = []string{} + } + + if request.Metadata == nil { + request.Metadata = make(map[string]string) + } + + id := "asst_" + strconv.FormatInt(generateRandomID(), 10) + assistant := Assistant{ - ID: "asst_" + generateRandomID(MaxLengthRandomID), + ID: id, Object: "assistant", Created: time.Now().Unix(), Model: request.Model, @@ -99,22 +111,17 @@ func CreateAssistantEndpoint(cm *config.ConfigLoader, o *options.Option) func(c } } -func generateRandomID(maxLength int) string { - newUUID, err := uuid.NewUUID() - if err != nil { - log.Error().Msgf("Failed to generate UUID: %v", err) - return "" - } +var currentId int64 = 0 - uuidStr := newUUID.String() - if maxLength > 0 && len(uuidStr) > maxLength { - return uuidStr[:maxLength] - } - return uuidStr +func generateRandomID() int64 { + atomic.AddInt64(¤tId, 1) + return currentId } func ListAssistantsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { + // Because we're altering the existing assistants list we should just duplicate it for now. + returnAssistants := Assistants // Parse query parameters limitQuery := c.Query("limit", "20") orderQuery := c.Query("order", "desc") @@ -124,60 +131,80 @@ func ListAssistantsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c * // Convert string limit to integer limit, err := strconv.Atoi(limitQuery) if err != nil { - return c.Status(http.StatusBadRequest).SendString(err.Error()) + return c.Status(http.StatusBadRequest).SendString(fmt.Sprintf("Invalid limit query value: %s", limitQuery)) } // Sort assistants - sort.SliceStable(Assistants, func(i, j int) bool { + sort.SliceStable(returnAssistants, func(i, j int) bool { if orderQuery == "asc" { - return Assistants[i].Created < Assistants[j].Created + return returnAssistants[i].Created < returnAssistants[j].Created } - return Assistants[i].Created > Assistants[j].Created + return returnAssistants[i].Created > returnAssistants[j].Created }) // After and before cursors if afterQuery != "" { - Assistants = filterAssistantsAfterID(Assistants, afterQuery) + returnAssistants = filterAssistantsAfterID(returnAssistants, afterQuery) } if beforeQuery != "" { - Assistants = filterAssistantsBeforeID(Assistants, beforeQuery) + returnAssistants = filterAssistantsBeforeID(returnAssistants, beforeQuery) } // Apply limit - if limit < len(Assistants) { - Assistants = Assistants[:limit] + if limit < len(returnAssistants) { + returnAssistants = returnAssistants[:limit] } - return c.JSON(Assistants) + return c.JSON(returnAssistants) } } // FilterAssistantsBeforeID filters out those assistants whose ID comes before the given ID // We assume that the assistants are already sorted func filterAssistantsBeforeID(assistants []Assistant, id string) []Assistant { - for i, assistant := range assistants { - if strings.Compare(assistant.ID, id) == 0 { - if i != 0 { - return assistants[:i] - } - return []Assistant{} + idInt, err := strconv.Atoi(id) + if err != nil { + return assistants // Return original slice if invalid id format is provided + } + + var filteredAssistants []Assistant + + for _, assistant := range assistants { + aid, err := strconv.Atoi(strings.TrimPrefix(assistant.ID, "asst_")) + if err != nil { + continue // Skip if invalid id in assistant + } + + if aid < idInt { + filteredAssistants = append(filteredAssistants, assistant) } } - return assistants + + return filteredAssistants } // FilterAssistantsAfterID filters out those assistants whose ID comes after the given ID // We assume that the assistants are already sorted func filterAssistantsAfterID(assistants []Assistant, id string) []Assistant { - for i, assistant := range assistants { - if strings.Compare(assistant.ID, id) == 0 { - if i != len(assistants)-1 { - return assistants[i+1:] - } - return []Assistant{} + idInt, err := strconv.Atoi(id) + if err != nil { + return assistants // Return original slice if invalid id format is provided + } + + var filteredAssistants []Assistant + + for _, assistant := range assistants { + aid, err := strconv.Atoi(strings.TrimPrefix(assistant.ID, "asst_")) + if err != nil { + continue // Skip if invalid id in assistant + } + + if aid > idInt { + filteredAssistants = append(filteredAssistants, assistant) } } - return assistants + + return filteredAssistants } func modelExists(o *options.Option, modelName string) (found bool) { @@ -259,18 +286,24 @@ var ( AssistantsFileConfigFile = "assistantsFile.json" ) -func CreateAssistantFileEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { - type AssistantFileRequest struct { - FileID string `json:"file_id"` - } +type AssistantFileRequest struct { + FileID string `json:"file_id"` +} +type DeleteAssistantFileResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` +} + +func CreateAssistantFileEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { request := new(AssistantFileRequest) if err := c.BodyParser(request); err != nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"}) } - assistantID := c.Query("assistant_id") + assistantID := c.Params("assistant_id") if assistantID == "" { return c.Status(fiber.StatusBadRequest).SendString("parameter assistant_id is required") } @@ -406,11 +439,6 @@ func ModifyAssistantEndpoint(cm *config.ConfigLoader, o *options.Option) func(c } func DeleteAssistantFileEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { - type DeleteAssistantFileResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Deleted bool `json:"deleted"` - } return func(c *fiber.Ctx) error { assistantID := c.Params("assistant_id") fileId := c.Params("file_id") @@ -440,12 +468,20 @@ func DeleteAssistantFileEndpoint(cm *config.ConfigLoader, o *options.Option) fun } } - log.Warn().Msgf("Unable to locate file_id: %s in assistants: %s", fileId, assistantID) - return c.Status(fiber.StatusNotFound).JSON(DeleteAssistantFileResponse{ - ID: fileId, - Object: "assistant.file.deleted", - Deleted: false, - }) + log.Warn().Msgf("Unable to locate file_id: %s in assistants: %s. Continuing to delete assistant file.", fileId, assistantID) + for i, assistantFile := range AssistantFiles { + if assistantFile.AssistantID == assistantID { + + AssistantFiles = append(AssistantFiles[:i], AssistantFiles[i+1:]...) + utils.SaveConfig(o.ConfigsDir, AssistantsFileConfigFile, AssistantFiles) + + return c.Status(fiber.StatusNotFound).JSON(DeleteAssistantFileResponse{ + ID: fileId, + Object: "assistant.file.deleted", + Deleted: true, + }) + } + } } } log.Warn().Msgf("Unable to find assistant: %s", assistantID) diff --git a/api/openai/assistant_test.go b/api/openai/assistant_test.go new file mode 100644 index 00000000000..ab433b34c0e --- /dev/null +++ b/api/openai/assistant_test.go @@ -0,0 +1,436 @@ +package openai + +import ( + "encoding/json" + "fmt" + "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/assert" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" +) + +type MockLoader struct { + models []string +} + +func TestAssistantEndpoints(t *testing.T) { + // Preparing the mocked objects + loader := &config.ConfigLoader{} + //configsDir := "/tmp/localai/configs" + configsDir := "" + option := &options.Option{ + ConfigsDir: configsDir, + UploadLimitMB: 10, + UploadDir: "test_dir", + Loader: &model.ModelLoader{ + ModelPath: "/tmp/localai/models", + }, + } + + _ = os.RemoveAll(option.ConfigsDir) + + app := fiber.New(fiber.Config{ + BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB. + }) + + // Create a Test Server + app.Get("/assistants", ListAssistantsEndpoint(loader, option)) + app.Post("/assistants", CreateAssistantEndpoint(loader, option)) + app.Delete("/assistants/:assistant_id", DeleteAssistantEndpoint(loader, option)) + app.Get("/assistants/:assistant_id", GetAssistantEndpoint(loader, option)) + app.Post("/assistants/:assistant_id", ModifyAssistantEndpoint(loader, option)) + + app.Post("/files", UploadFilesEndpoint(loader, option)) + app.Get("/assistants/:assistant_id/files", ListAssistantFilesEndpoint(loader, option)) + app.Post("/assistants/:assistant_id/files", CreateAssistantFileEndpoint(loader, option)) + app.Delete("/assistants/:assistant_id/files/:file_id", DeleteAssistantFileEndpoint(loader, option)) + app.Get("/assistants/:assistant_id/files/:file_id", GetAssistantFileEndpoint(loader, option)) + + t.Run("CreateAssistantEndpoint", func(t *testing.T) { + ar := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: "3.5-turbo", + Description: "Test Assistant", + Instructions: "You are computer science teacher answering student questions", + Tools: []Tool{{Type: Function}}, + FileIDs: nil, + Metadata: nil, + } + + resultAssistant, resp, err := createAssistant(app, *ar) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusOK, resp.StatusCode) + + assert.Equal(t, 1, len(Assistants)) + t.Cleanup(cleanupAllAssistants(t, app, []string{resultAssistant.ID})) + + assert.Equal(t, ar.Name, resultAssistant.Name) + assert.Equal(t, ar.Model, resultAssistant.Model) + assert.Equal(t, ar.Tools, resultAssistant.Tools) + assert.Equal(t, ar.Description, resultAssistant.Description) + assert.Equal(t, ar.Instructions, resultAssistant.Instructions) + assert.Equal(t, ar.FileIDs, resultAssistant.FileIDs) + assert.Equal(t, ar.Metadata, resultAssistant.Metadata) + }) + + t.Run("ListAssistantsEndpoint", func(t *testing.T) { + var ids []string + var resultAssistant []Assistant + for i := 0; i < 4; i++ { + ar := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: fmt.Sprintf("3.5-turbo-%d", i), + Description: fmt.Sprintf("Test Assistant - %d", i), + Instructions: fmt.Sprintf("You are computer science teacher answering student questions - %d", i), + Tools: []Tool{{Type: Function}}, + FileIDs: []string{"fid-1234"}, + Metadata: map[string]string{"meta": "data"}, + } + + //var err error + ra, _, err := createAssistant(app, *ar) + // Because we create the assistants so fast all end up with the same created time. + time.Sleep(time.Second) + resultAssistant = append(resultAssistant, ra) + assert.NoError(t, err) + ids = append(ids, resultAssistant[i].ID) + } + + t.Cleanup(cleanupAllAssistants(t, app, ids)) + + tests := []struct { + name string + reqURL string + expectedStatus int + expectedResult []Assistant + expectedStringResult string + }{ + { + name: "Valid Usage - limit only", + reqURL: "/assistants?limit=2", + expectedStatus: http.StatusOK, + expectedResult: Assistants[:2], // Expecting the first two assistants + }, + { + name: "Valid Usage - order asc", + reqURL: "/assistants?order=asc", + expectedStatus: http.StatusOK, + expectedResult: Assistants, // Expecting all assistants in ascending order + }, + { + name: "Valid Usage - order desc", + reqURL: "/assistants?order=desc", + expectedStatus: http.StatusOK, + expectedResult: []Assistant{Assistants[3], Assistants[2], Assistants[1], Assistants[0]}, // Expecting all assistants in descending order + }, + { + name: "Valid Usage - after specific ID", + reqURL: "/assistants?after=2", + expectedStatus: http.StatusOK, + // Note this is correct because it's put in descending order already + expectedResult: Assistants[:3], // Expecting assistants after (excluding) ID 2 + }, + { + name: "Valid Usage - before specific ID", + reqURL: "/assistants?before=4", + expectedStatus: http.StatusOK, + expectedResult: Assistants[2:], // Expecting assistants before (excluding) ID 3. + }, + { + name: "Invalid Usage - non-integer limit", + reqURL: "/assistants?limit=two", + expectedStatus: http.StatusBadRequest, + expectedStringResult: "Invalid limit query value: two", + }, + { + name: "Invalid Usage - non-existing id in after", + reqURL: "/assistants?after=100", + expectedStatus: http.StatusOK, + expectedResult: []Assistant(nil), // Expecting empty list as there are no IDs above 100 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request := httptest.NewRequest(http.MethodGet, tt.reqURL, nil) + response, err := app.Test(request) + assert.NoError(t, err) + assert.Equal(t, tt.expectedStatus, response.StatusCode) + if tt.expectedStatus != fiber.StatusOK { + all, _ := ioutil.ReadAll(response.Body) + assert.Equal(t, tt.expectedStringResult, string(all)) + } else { + var result []Assistant + err = json.NewDecoder(response.Body).Decode(&result) + assert.NoError(t, err) + + assert.Equal(t, tt.expectedResult, result) + } + }) + } + }) + + t.Run("DeleteAssistantEndpoint", func(t *testing.T) { + ar := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: "3.5-turbo", + Description: "Test Assistant", + Instructions: "You are computer science teacher answering student questions", + Tools: []Tool{{Type: Function}}, + FileIDs: nil, + Metadata: nil, + } + + resultAssistant, _, err := createAssistant(app, *ar) + assert.NoError(t, err) + + target := fmt.Sprintf("/assistants/%s", resultAssistant.ID) + deleteReq := httptest.NewRequest(http.MethodDelete, target, nil) + _, err = app.Test(deleteReq) + assert.NoError(t, err) + assert.Equal(t, 0, len(Assistants)) + }) + + t.Run("GetAssistantEndpoint", func(t *testing.T) { + ar := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: "3.5-turbo", + Description: "Test Assistant", + Instructions: "You are computer science teacher answering student questions", + Tools: []Tool{{Type: Function}}, + FileIDs: nil, + Metadata: nil, + } + + resultAssistant, _, err := createAssistant(app, *ar) + assert.NoError(t, err) + t.Cleanup(cleanupAllAssistants(t, app, []string{resultAssistant.ID})) + + target := fmt.Sprintf("/assistants/%s", resultAssistant.ID) + request := httptest.NewRequest(http.MethodGet, target, nil) + response, err := app.Test(request) + assert.NoError(t, err) + + var getAssistant Assistant + err = json.NewDecoder(response.Body).Decode(&getAssistant) + assert.NoError(t, err) + + assert.Equal(t, resultAssistant.ID, getAssistant.ID) + }) + + t.Run("ModifyAssistantEndpoint", func(t *testing.T) { + ar := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: "3.5-turbo", + Description: "Test Assistant", + Instructions: "You are computer science teacher answering student questions", + Tools: []Tool{{Type: Function}}, + FileIDs: nil, + Metadata: nil, + } + + resultAssistant, _, err := createAssistant(app, *ar) + assert.NoError(t, err) + + modifiedAr := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: "4.0-turbo", + Description: "Modified Test Assistant", + Instructions: "You are math teacher answering student questions", + Tools: []Tool{{Type: CodeInterpreter}}, + FileIDs: nil, + Metadata: nil, + } + + modifiedArJson, err := json.Marshal(modifiedAr) + assert.NoError(t, err) + + target := fmt.Sprintf("/assistants/%s", resultAssistant.ID) + request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(modifiedArJson))) + request.Header.Set(fiber.HeaderContentType, "application/json") + + modifyResponse, err := app.Test(request) + assert.NoError(t, err) + var getAssistant Assistant + err = json.NewDecoder(modifyResponse.Body).Decode(&getAssistant) + + t.Cleanup(cleanupAllAssistants(t, app, []string{getAssistant.ID})) + + assert.Equal(t, resultAssistant.ID, getAssistant.ID) // IDs should match even if contents change + assert.Equal(t, modifiedAr.Tools, getAssistant.Tools) + assert.Equal(t, modifiedAr.Name, getAssistant.Name) + assert.Equal(t, modifiedAr.Instructions, getAssistant.Instructions) + assert.Equal(t, modifiedAr.Description, getAssistant.Description) + }) + + t.Run("CreateAssistantFileEndpoint", func(t *testing.T) { + file, assistant, err := createFileAndAssistant(t, app, option) + assert.NoError(t, err) + + afr := AssistantFileRequest{FileID: file.ID} + af, _, err := createAssistantFile(app, afr, assistant.ID) + + assert.NoError(t, err) + t.Cleanup(cleanupAssistantFile(t, app, file.ID, af.AssistantID)) + assert.Equal(t, assistant.ID, af.AssistantID) + }) + t.Run("ListAssistantFilesEndpoint", func(t *testing.T) { + file, assistant, err := createFileAndAssistant(t, app, option) + assert.NoError(t, err) + + afr := AssistantFileRequest{FileID: file.ID} + af, _, err := createAssistantFile(app, afr, assistant.ID) + + assert.NoError(t, err) + t.Cleanup(cleanupAssistantFile(t, app, file.ID, af.AssistantID)) + + assert.Equal(t, assistant.ID, af.AssistantID) + }) + t.Run("GetAssistantFileEndpoint", func(t *testing.T) { + + file, assistant, err := createFileAndAssistant(t, app, option) + assert.NoError(t, err) + + afr := AssistantFileRequest{FileID: file.ID} + af, _, err := createAssistantFile(app, afr, assistant.ID) + assert.NoError(t, err) + t.Cleanup(cleanupAssistantFile(t, app, af.ID, af.AssistantID)) + + target := fmt.Sprintf("/assistants/%s/files/%s", assistant.ID, file.ID) + request := httptest.NewRequest(http.MethodGet, target, nil) + response, err := app.Test(request) + assert.NoError(t, err) + + var assistantFile AssistantFile + err = json.NewDecoder(response.Body).Decode(&assistantFile) + assert.NoError(t, err) + + assert.Equal(t, af.ID, assistantFile.ID) + assert.Equal(t, af.AssistantID, assistantFile.AssistantID) + }) + t.Run("DeleteAssistantFileEndpoint", func(t *testing.T) { + file, assistant, err := createFileAndAssistant(t, app, option) + assert.NoError(t, err) + + afr := AssistantFileRequest{FileID: file.ID} + af, _, err := createAssistantFile(app, afr, assistant.ID) + assert.NoError(t, err) + + cleanupAssistantFile(t, app, af.ID, af.AssistantID)() + + assert.Empty(t, AssistantFiles) + }) + +} + +func createFileAndAssistant(t *testing.T, app *fiber.App, o *options.Option) (File, Assistant, error) { + ar := &AssistantRequest{ + Model: "ggml-gpt4all-j", + Name: "3.5-turbo", + Description: "Test Assistant", + Instructions: "You are computer science teacher answering student questions", + Tools: []Tool{{Type: Function}}, + FileIDs: nil, + Metadata: nil, + } + + assistant, _, err := createAssistant(app, *ar) + if err != nil { + return File{}, Assistant{}, err + } + t.Cleanup(cleanupAllAssistants(t, app, []string{assistant.ID})) + + file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, o) + return file, assistant, nil +} + +func createAssistantFile(app *fiber.App, afr AssistantFileRequest, assistantId string) (AssistantFile, *http.Response, error) { + afrJson, err := json.Marshal(afr) + if err != nil { + return AssistantFile{}, nil, err + } + + target := fmt.Sprintf("/assistants/%s/files", assistantId) + request := httptest.NewRequest(http.MethodPost, target, strings.NewReader(string(afrJson))) + request.Header.Set(fiber.HeaderContentType, "application/json") + request.Header.Set("OpenAi-Beta", "assistants=v1") + + resp, err := app.Test(request) + if err != nil { + return AssistantFile{}, resp, err + } + + var assistantFile AssistantFile + all, err := ioutil.ReadAll(resp.Body) + err = json.NewDecoder(strings.NewReader(string(all))).Decode(&assistantFile) + if err != nil { + return AssistantFile{}, resp, err + } + + return assistantFile, resp, nil +} + +func createAssistant(app *fiber.App, ar AssistantRequest) (Assistant, *http.Response, error) { + assistant, err := json.Marshal(ar) + if err != nil { + return Assistant{}, nil, err + } + + request := httptest.NewRequest(http.MethodPost, "/assistants", strings.NewReader(string(assistant))) + request.Header.Set(fiber.HeaderContentType, "application/json") + request.Header.Set("OpenAi-Beta", "assistants=v1") + + resp, err := app.Test(request) + if err != nil { + return Assistant{}, resp, err + } + + bodyString, err := io.ReadAll(resp.Body) + if err != nil { + return Assistant{}, resp, err + } + + var resultAssistant Assistant + err = json.NewDecoder(strings.NewReader(string(bodyString))).Decode(&resultAssistant) + + return resultAssistant, resp, nil +} + +func cleanupAllAssistants(t *testing.T, app *fiber.App, ids []string) func() { + return func() { + for _, assistant := range ids { + target := fmt.Sprintf("/assistants/%s", assistant) + deleteReq := httptest.NewRequest(http.MethodDelete, target, nil) + _, err := app.Test(deleteReq) + if err != nil { + t.Fatalf("Failed to delete assistant %s: %v", assistant, err) + } + } + } +} + +func cleanupAssistantFile(t *testing.T, app *fiber.App, fileId, assistantId string) func() { + return func() { + target := fmt.Sprintf("/assistants/%s/files/%s", assistantId, fileId) + request := httptest.NewRequest(http.MethodDelete, target, nil) + request.Header.Set(fiber.HeaderContentType, "application/json") + request.Header.Set("OpenAi-Beta", "assistants=v1") + + resp, err := app.Test(request) + assert.NoError(t, err) + + var dafr DeleteAssistantFileResponse + err = json.NewDecoder(resp.Body).Decode(&dafr) + assert.NoError(t, err) + assert.True(t, dafr.Deleted) + } +} diff --git a/api/openai/files_test.go b/api/openai/files_test.go index 535cde8ba56..d07ec4a801e 100644 --- a/api/openai/files_test.go +++ b/api/openai/files_test.go @@ -115,8 +115,8 @@ func TestUploadFileExceedSizeLimit(t *testing.T) { assert.Equal(t, 200, resp.StatusCode) listFiles := responseToListFile(t, resp) - if len(listFiles.Data) != len(uploadedFiles) { - t.Errorf("Expected %v files, got %v files", len(uploadedFiles), len(listFiles.Data)) + if len(listFiles.Data) != len(UploadedFiles) { + t.Errorf("Expected %v files, got %v files", len(UploadedFiles), len(listFiles.Data)) } }) t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) { diff --git a/pkg/utils/config.go b/pkg/utils/config.go index b8f83a23fb3..f3497b55b24 100644 --- a/pkg/utils/config.go +++ b/pkg/utils/config.go @@ -7,13 +7,14 @@ import ( "path/filepath" ) -func SaveConfig(uploadDir, fileName string, obj any) { +func SaveConfig(filePath, fileName string, obj any) { file, err := json.MarshalIndent(obj, "", " ") if err != nil { log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err) } - err = os.WriteFile(filepath.Join(uploadDir, fileName), file, 0644) + absolutePath := filepath.Join(filePath, fileName) + err = os.WriteFile(absolutePath, file, 0644) if err != nil { log.Error().Msgf("Failed to save uploadedFiles to file: %s", err) } From 0f1959f978433c61349b78df80ebe8005b5c49bc Mon Sep 17 00:00:00 2001 From: schristou88 Date: Thu, 14 Mar 2024 03:11:06 -0700 Subject: [PATCH 5/6] Fix models path spelling mistake. --- core/http/endpoints/openai/assistant_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/http/endpoints/openai/assistant_test.go b/core/http/endpoints/openai/assistant_test.go index 72cc29846d0..bdc41ddaf98 100644 --- a/core/http/endpoints/openai/assistant_test.go +++ b/core/http/endpoints/openai/assistant_test.go @@ -38,9 +38,9 @@ func TestAssistantEndpoints(t *testing.T) { // Preparing the mocked objects cl := &config.BackendConfigLoader{} //configsDir := "/tmp/localai/configs" - var ml = model.NewModelLoader("/tmp/localai/models") - modelPath := "/tmp/localai/model" + var ml = model.NewModelLoader(modelPath) + appConfig := &config.ApplicationConfig{ ConfigsDir: configsDir, UploadLimitMB: 10, From 8811c89ec030ef5e76df25be576a6c41974fd9c2 Mon Sep 17 00:00:00 2001 From: schristou88 Date: Thu, 14 Mar 2024 09:15:58 -0700 Subject: [PATCH 6/6] Remove personal go.mod information --- go.mod | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/go.mod b/go.mod index ed616e8183d..bbb9083858c 100644 --- a/go.mod +++ b/go.mod @@ -106,19 +106,3 @@ require ( golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.12.0 // indirect ) - -replace github.com/nomic-ai/gpt4all/gpt4all-bindings/golang => /Users/schristou/Personal/LocalAI/sources/gpt4all/gpt4all-bindings/golang - -replace github.com/donomii/go-rwkv.cpp => /Users/schristou/Personal/LocalAI/sources/go-rwkv - -replace github.com/ggerganov/whisper.cpp => /Users/schristou/Personal/LocalAI/sources/whisper.cpp - -replace github.com/ggerganov/whisper.cpp/bindings/go => /Users/schristou/Personal/LocalAI/sources/whisper.cpp/bindings/go - -replace github.com/go-skynet/go-bert.cpp => /Users/schristou/Personal/LocalAI/sources/go-bert - -replace github.com/mudler/go-stable-diffusion => /Users/schristou/Personal/LocalAI/sources/go-stable-diffusion - -replace github.com/M0Rf30/go-tiny-dream => /Users/schristou/Personal/LocalAI/sources/go-tiny-dream - -replace github.com/mudler/go-piper => /Users/schristou/Personal/LocalAI/sources/go-piper