Skip to content

Commit

Permalink
Support API Pagination On Model Version API (#40)
Browse files Browse the repository at this point in the history
* Support pagination for model versions API

* Update golang client and python client

* Fix code that use merlin api client

* Fixing type check

* Address PR Review:
1. Generalized pagination implementation
2. Remove unnecessary code

* Support pagination on version list page

* Fix warning on ui

* Reformat generated enum named from swagger

* Fix yarn.lock

* Remove unncessary code

* Add comment to exported method OkWithHeaders

* Support search based on environment name

* Load current state after user deploy or serve model
  • Loading branch information
tiopramayudi authored Jan 8, 2021
1 parent 98f71fe commit fb58351
Show file tree
Hide file tree
Showing 54 changed files with 2,450 additions and 1,621 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ BIN_NAME=merlin
UI_PATH := ui
UI_BUILD_PATH := ${UI_PATH}/build
API_PATH=api
API_ALL_PACKAGES := $(shell cd ${API_PATH} && go list ./... | grep -v github.com/gojek/mlp/api/client | grep -v mocks)
API_ALL_PACKAGES := $(shell cd ${API_PATH} && go list ./... | grep -v github.com/gojek/mlp/api/client | grep -v -e mocks -e client)

all: setup init-dep lint test clean build run

Expand Down
25 changes: 23 additions & 2 deletions api/api/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ package api
import (
"encoding/json"
"net/http"
"strings"
)

// Response handles responses of APIs.
type Response struct {
code int
data interface{}
code int
data interface{}
headers map[string]string
}

// Error represents the structure of an error response.
Expand All @@ -33,6 +35,16 @@ type Error struct {
// WriteTo writes the response header and body.
func (r *Response) WriteTo(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json; charset=UTF-8")

exposeHeaders := make([]string, 0, len(r.headers))
for key, value := range r.headers {
exposeHeaders = append(exposeHeaders, key)
w.Header().Set(key, value)
}

allowHeaders := strings.Join(exposeHeaders, ",")
w.Header().Set("Access-Control-Expose-Headers", allowHeaders)

w.WriteHeader(r.code)

if r.data != nil {
Expand All @@ -49,6 +61,15 @@ func Ok(data interface{}) *Response {
}
}

// OkWithHeaders represents the response of status code 200 with custom headers
func OkWithHeaders(data interface{}, headers map[string]string) *Response {
return &Response{
code: http.StatusOK,
data: data,
headers: headers,
}
}

// Created represents the response of status code 201.
func Created(data interface{}) *Response {
return &Response{
Expand Down
16 changes: 14 additions & 2 deletions api/api/versions_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/gojek/merlin/log"
"github.com/gojek/merlin/mlflow"
"github.com/gojek/merlin/models"
"github.com/gojek/merlin/service"
)

type VersionsController struct {
Expand Down Expand Up @@ -79,13 +80,24 @@ func (c *VersionsController) PatchVersion(r *http.Request, vars map[string]strin
func (c *VersionsController) ListVersions(r *http.Request, vars map[string]string, _ interface{}) *Response {
ctx := r.Context()

var query service.VersionQuery
if err := decoder.Decode(&query, r.URL.Query()); err != nil {
log.Errorf("Error while parsing query string %v", err)
return BadRequest(fmt.Sprintf("Unable to parse query string: %s", err))
}

modelID, _ := models.ParseID(vars["model_id"])
versions, err := c.VersionsService.ListVersions(ctx, modelID, c.MonitoringConfig)
versions, nextCursor, err := c.VersionsService.ListVersions(ctx, modelID, c.MonitoringConfig, query)
if err != nil {
return InternalServerError(err.Error())
}

return Ok(versions)
responseHeaders := make(map[string]string)
if nextCursor != "" {
responseHeaders["Next-Cursor"] = nextCursor
}

return OkWithHeaders(versions, responseHeaders)
}

func (c *VersionsController) CreateVersion(r *http.Request, vars map[string]string, _ interface{}) *Response {
Expand Down
60 changes: 56 additions & 4 deletions api/api/versions_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package api
import (
"fmt"
"net/http"
"net/url"
"testing"

"github.com/gojek/merlin/config"
Expand Down Expand Up @@ -135,6 +136,7 @@ func TestListVersion(t *testing.T) {
desc string
vars map[string]string
versionService func() *mocks.VersionsService
queryParameter string
expected *Response
}{
{
Expand All @@ -144,7 +146,7 @@ func TestListVersion(t *testing.T) {
},
versionService: func() *mocks.VersionsService {
svc := &mocks.VersionsService{}
svc.On("ListVersions", mock.Anything, models.ID(1), mock.Anything).Return([]*models.Version{
svc.On("ListVersions", mock.Anything, models.ID(1), mock.Anything, mock.Anything).Return([]*models.Version{
{
ID: models.ID(1),
ModelID: models.ID(1),
Expand All @@ -159,9 +161,56 @@ func TestListVersion(t *testing.T) {
},
MlflowURL: "http://mlflow.com",
},
}, nil)
}, "", nil)
return svc
},
expected: &Response{
code: http.StatusOK,
data: []*models.Version{
{
ID: models.ID(1),
ModelID: models.ID(1),
Model: &models.Model{
ID: models.ID(1),
Name: "model-1",
ProjectID: models.ID(1),
Project: mlp.Project{},
ExperimentID: 1,
Type: "pyfunc",
MlflowURL: "http://mlflow.com",
},
MlflowURL: "http://mlflow.com",
},
},
headers: map[string]string{},
},
},
{
desc: "Should success get version with pagination",
vars: map[string]string{
"model_id": "1",
},
versionService: func() *mocks.VersionsService {
svc := &mocks.VersionsService{}
svc.On("ListVersions", mock.Anything, models.ID(1), mock.Anything, mock.Anything).Return([]*models.Version{
{
ID: models.ID(1),
ModelID: models.ID(1),
Model: &models.Model{
ID: models.ID(1),
Name: "model-1",
ProjectID: models.ID(1),
Project: mlp.Project{},
ExperimentID: 1,
Type: "pyfunc",
MlflowURL: "http://mlflow.com",
},
MlflowURL: "http://mlflow.com",
},
}, "NDdfMzQ=", nil)
return svc
},
queryParameter: "limit=30",
expected: &Response{
code: http.StatusOK,
data: []*models.Version{
Expand All @@ -180,6 +229,9 @@ func TestListVersion(t *testing.T) {
MlflowURL: "http://mlflow.com",
},
},
headers: map[string]string{
"Next-Cursor": "NDdfMzQ=",
},
},
},
{
Expand All @@ -189,7 +241,7 @@ func TestListVersion(t *testing.T) {
},
versionService: func() *mocks.VersionsService {
svc := &mocks.VersionsService{}
svc.On("ListVersions", mock.Anything, models.ID(1), mock.Anything).Return(nil, fmt.Errorf("DB is down"))
svc.On("ListVersions", mock.Anything, models.ID(1), mock.Anything, mock.Anything).Return(nil, "", fmt.Errorf("DB is down"))
return svc
},
expected: &Response{
Expand All @@ -212,7 +264,7 @@ func TestListVersion(t *testing.T) {
AlertEnabled: true,
},
}
resp := ctl.ListVersions(&http.Request{}, tC.vars, nil)
resp := ctl.ListVersions(&http.Request{URL: &url.URL{RawQuery: tC.queryParameter}}, tC.vars, nil)
assert.Equal(t, tC.expected, resp)
})
}
Expand Down
Loading

0 comments on commit fb58351

Please sign in to comment.