diff --git a/api/service/version_endpoint_service.go b/api/service/version_endpoint_service.go index ca5d34a47..ea34264a0 100644 --- a/api/service/version_endpoint_service.go +++ b/api/service/version_endpoint_service.go @@ -231,8 +231,8 @@ func (k *endpointService) override(left *models.VersionEndpoint, right *models.V // override env vars // Configure environment variables for Pyfunc model - if len(right.EnvVars) > 0 { - left.EnvVars = models.MergeEnvVars(left.EnvVars, right.EnvVars) + if right.EnvVars != nil { + left.EnvVars = right.EnvVars } // override protocol diff --git a/api/service/version_endpoint_service_test.go b/api/service/version_endpoint_service_test.go index 2ee382640..3853eef5e 100644 --- a/api/service/version_endpoint_service_test.go +++ b/api/service/version_endpoint_service_test.go @@ -669,6 +669,12 @@ func TestDeployEndpoint(t *testing.T) { CPURequest: resource.MustParse("1"), MemoryRequest: resource.MustParse("1Gi"), }, + EnvVars: models.EnvVars{ + { + Name: "TF_MODEL_NAME", + Value: "saved_model.pb", + }, + }, Logger: &models.Logger{ Model: &models.LoggerConfig{ Enabled: true, @@ -697,6 +703,12 @@ func TestDeployEndpoint(t *testing.T) { CPURequest: resource.MustParse("1"), MemoryRequest: resource.MustParse("1Gi"), }, + EnvVars: models.EnvVars{ + { + Name: "NUM_OF_ITERATION", + Value: "1", + }, + }, Logger: &models.Logger{ Model: &models.LoggerConfig{ Enabled: true, @@ -723,6 +735,12 @@ func TestDeployEndpoint(t *testing.T) { CPURequest: resource.MustParse("1"), MemoryRequest: resource.MustParse("1Gi"), }, + EnvVars: models.EnvVars{ + { + Name: "NUM_OF_ITERATION", + Value: "1", + }, + }, Logger: &models.Logger{ DestinationURL: loggerDestinationURL, Model: &models.LoggerConfig{