From a4f6746ce2210b2ec7a861429a3785d6bbb86e26 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Thu, 5 Sep 2024 20:43:03 +0800 Subject: [PATCH] Change ModelClaims API Signed-off-by: kerthcet --- README.md | 2 +- api/core/v1alpha1/model_types.go | 41 ++++++----- api/core/v1alpha1/zz_generated.deepcopy.go | 72 ++++++++++++------- api/inference/v1alpha1/playground_types.go | 15 ++-- api/inference/v1alpha1/service_types.go | 5 +- .../v1alpha1/zz_generated.deepcopy.go | 8 +-- .../core/v1alpha1/modelclaims.go | 58 +++++++++++++++ .../core/v1alpha1/modelrepresentative.go | 51 +++++++++++++ .../core/v1alpha1/multimodelsclaim.go | 64 ----------------- .../inference/v1alpha1/playgroundspec.go | 16 ++--- .../inference/v1alpha1/servicespec.go | 14 ++-- client-go/applyconfiguration/utils.go | 6 +- .../bases/inference.llmaz.io_playgrounds.yaml | 50 +++++++------ .../bases/inference.llmaz.io_services.yaml | 42 ++++++----- .../llamacpp/playground.yaml | 13 ++-- .../speculative-decoding/vllm/playground.yaml | 17 +++-- .../inference/playground_controller.go | 59 +++++++++------ .../inference/service_controller.go | 6 +- pkg/controller_helper/backend/backend.go | 10 ++- pkg/controller_helper/backend/llamacpp.go | 36 ++++------ .../backend/llamacpp_test.go | 32 ++++----- pkg/controller_helper/backend/sglang.go | 20 +++--- pkg/controller_helper/backend/sglang_test.go | 30 ++++---- pkg/controller_helper/backend/vllm.go | 36 ++++------ pkg/controller_helper/backend/vllm_test.go | 28 ++++---- .../model_source/modelsource.go | 2 - pkg/webhook/playground_webhook.go | 43 +++++++---- pkg/webhook/service_webhook.go | 33 ++++++++- .../controller/inference/playground_test.go | 6 +- .../controller/inference/service_test.go | 4 +- test/integration/webhook/playground_test.go | 27 ++++--- test/integration/webhook/service_test.go | 33 ++++++++- test/util/mock.go | 2 +- test/util/validation/validate_playground.go | 21 +++--- test/util/validation/validate_service.go | 5 +- test/util/wrapper/playground.go | 17 +++-- test/util/wrapper/service.go | 25 ++++--- 37 files changed, 549 insertions(+), 400 deletions(-) create mode 100644 client-go/applyconfiguration/core/v1alpha1/modelclaims.go create mode 100644 client-go/applyconfiguration/core/v1alpha1/modelrepresentative.go delete mode 100644 client-go/applyconfiguration/core/v1alpha1/multimodelsclaim.go diff --git a/README.md b/README.md index 6365cc0..b3950bf 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@

- llmaz + llmaz

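For orientation, here is a minimal sketch of how the reworked claim API reads on the user side; the model names are placeholders borrowed from the examples and test fixtures in this patch, not canonical values:

  apiVersion: inference.llmaz.io/v1alpha1
  kind: Playground
  metadata:
    name: llama3-8b
  spec:
    replicas: 1
    modelClaims:
      models:
        - name: llama3-8b   # main (target) model, role defaults to main
        - name: llama3-2b   # draft model for speculative decoding
          role: draft

The single-model shortcut spec.modelClaim is kept as-is; spec.modelClaims replaces the former spec.multiModelsClaim and drops the inferenceMode field in favor of per-model roles.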
diff --git a/api/core/v1alpha1/model_types.go b/api/core/v1alpha1/model_types.go
index e0ca3d6..ec16682 100644
--- a/api/core/v1alpha1/model_types.go
+++ b/api/core/v1alpha1/model_types.go
@@ -120,28 +120,35 @@ type ModelClaim struct {
 	InferenceFlavors []FlavorName `json:"inferenceFlavors,omitempty"`
 }
 
-type InferenceMode string
+type ModelRole string
 
 const (
-	Standard            InferenceMode = "Standard"
-	SpeculativeDecoding InferenceMode = "SpeculativeDecoding"
+	// Main represents the main model, if only one model is required,
+	// it must be the main model. Only one main model is allowed.
+	MainRole ModelRole = "main"
+	// Draft represents the draft model in speculative decoding,
+	// the main model is the target model then.
+	DraftRole ModelRole = "draft"
 )
 
-// MultiModelsClaim represents claiming for multiple models with different claimModes,
-// like standard or speculative-decoding to support different inference scenarios.
-type MultiModelsClaim struct {
-	// ModelNames represents a list of models, there maybe multiple models here
-	// to support state-of-the-art technologies like speculative decoding.
-	// If the composedMode is SpeculativeDecoding, the first model is the target model,
-	// and the second model is the draft model.
-	// +kubebuilder:validation:MinItems=1
-	ModelNames []ModelName `json:"modelNames,omitempty"`
-	// Mode represents the paradigm to serve the model, whether via a standard way
-	// or via an advanced technique like SpeculativeDecoding.
-	// +kubebuilder:default=Standard
-	// +kubebuilder:validation:Enum={Standard,SpeculativeDecoding}
+type ModelRepresentative struct {
+	// Name represents the model name.
+	Name ModelName `json:"name"`
+	// Role represents the model role once more than one model is required.
+	// +kubebuilder:validation:Enum={main,draft}
+	// +kubebuilder:default=main
 	// +optional
-	InferenceMode InferenceMode `json:"inferenceMode,omitempty"`
+	Role *ModelRole `json:"role,omitempty"`
+}
+
+// ModelClaims represents multiple claims for different models.
+type ModelClaims struct {
+	// Models represents a list of models with roles specified, there maybe
+	// multiple models here to support state-of-the-art technologies like
+	// speculative decoding, then one model is main(target) model, another one
+	// is draft model.
+	// +kubebuilder:validation:MinItems=1
+	Models []ModelRepresentative `json:"models,omitempty"`
 	// InferenceFlavors represents a list of flavors with fungibility supported
 	// to serve the model.
 	// - If not set, always apply with the 0-index model by default.
diff --git a/api/core/v1alpha1/zz_generated.deepcopy.go b/api/core/v1alpha1/zz_generated.deepcopy.go
index 8ad44d3..241c4c5 100644
--- a/api/core/v1alpha1/zz_generated.deepcopy.go
+++ b/api/core/v1alpha1/zz_generated.deepcopy.go
@@ -82,6 +82,33 @@ func (in *ModelClaim) DeepCopy() *ModelClaim {
 	return out
 }
 
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *ModelClaims) DeepCopyInto(out *ModelClaims) {
+	*out = *in
+	if in.Models != nil {
+		in, out := &in.Models, &out.Models
+		*out = make([]ModelRepresentative, len(*in))
+		for i := range *in {
+			(*in)[i].DeepCopyInto(&(*out)[i])
+		}
+	}
+	if in.InferenceFlavors != nil {
+		in, out := &in.InferenceFlavors, &out.InferenceFlavors
+		*out = make([]FlavorName, len(*in))
+		copy(*out, *in)
+	}
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelClaims.
+func (in *ModelClaims) DeepCopy() *ModelClaims { + if in == nil { + return nil + } + out := new(ModelClaims) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *ModelHub) DeepCopyInto(out *ModelHub) { *out = *in @@ -112,6 +139,26 @@ func (in *ModelHub) DeepCopy() *ModelHub { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ModelRepresentative) DeepCopyInto(out *ModelRepresentative) { + *out = *in + if in.Role != nil { + in, out := &in.Role, &out.Role + *out = new(ModelRole) + **out = **in + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelRepresentative. +func (in *ModelRepresentative) DeepCopy() *ModelRepresentative { + if in == nil { + return nil + } + out := new(ModelRepresentative) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *ModelSource) DeepCopyInto(out *ModelSource) { *out = *in @@ -182,31 +229,6 @@ func (in *ModelStatus) DeepCopy() *ModelStatus { return out } -// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. -func (in *MultiModelsClaim) DeepCopyInto(out *MultiModelsClaim) { - *out = *in - if in.ModelNames != nil { - in, out := &in.ModelNames, &out.ModelNames - *out = make([]ModelName, len(*in)) - copy(*out, *in) - } - if in.InferenceFlavors != nil { - in, out := &in.InferenceFlavors, &out.InferenceFlavors - *out = make([]FlavorName, len(*in)) - copy(*out, *in) - } -} - -// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MultiModelsClaim. -func (in *MultiModelsClaim) DeepCopy() *MultiModelsClaim { - if in == nil { - return nil - } - out := new(MultiModelsClaim) - in.DeepCopyInto(out) - return out -} - // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *OpenModel) DeepCopyInto(out *OpenModel) { *out = *in diff --git a/api/inference/v1alpha1/playground_types.go b/api/inference/v1alpha1/playground_types.go index 792fe42..2d0dd43 100644 --- a/api/inference/v1alpha1/playground_types.go +++ b/api/inference/v1alpha1/playground_types.go @@ -28,17 +28,16 @@ type PlaygroundSpec struct { // +kubebuilder:default=1 // +optional Replicas *int32 `json:"replicas,omitempty"` - // ModelClaim represents claiming for one model, it's the standard claimMode - // of multiModelsClaim compared to other modes like SpeculativeDecoding. - // Most of the time, modelClaim is enough. - // ModelClaim and multiModelsClaim are exclusive configured. + // ModelClaim represents claiming for one model, it's a simplified use case + // of modelClaims. Most of the time, modelClaim is enough. + // ModelClaim and modelClaims are exclusive configured. // +optional ModelClaim *coreapi.ModelClaim `json:"modelClaim,omitempty"` - // MultiModelsClaim represents claiming for multiple models with different claimModes, - // like standard or speculative-decoding to support different inference scenarios. - // ModelClaim and multiModelsClaim are exclusive configured. + // ModelClaims represents claiming for multiple models for more complicated + // use cases like speculative-decoding. + // ModelClaims and modelClaim are exclusive configured. 
// +optional - MultiModelsClaim *coreapi.MultiModelsClaim `json:"multiModelsClaim,omitempty"` + ModelClaims *coreapi.ModelClaims `json:"modelClaims,omitempty"` // BackendConfig represents the inference backend configuration // under the hood, e.g. vLLM, which is the default backend. // +optional diff --git a/api/inference/v1alpha1/service_types.go b/api/inference/v1alpha1/service_types.go index 9ab675b..7de6087 100644 --- a/api/inference/v1alpha1/service_types.go +++ b/api/inference/v1alpha1/service_types.go @@ -27,9 +27,8 @@ import ( // Service controller will maintain multi-flavor of workloads with // different accelerators for cost or performance considerations. type ServiceSpec struct { - // MultiModelsClaim represents claiming for multiple models with different claimModes, - // like standard or speculative-decoding to support different inference scenarios. - MultiModelsClaim coreapi.MultiModelsClaim `json:"multiModelsClaim,omitempty"` + // ModelClaims represents multiple claims for different models. + ModelClaims coreapi.ModelClaims `json:"modelClaims,omitempty"` // WorkloadTemplate defines the underlying workload layout and configuration. // Note: the LWS spec might be twisted with various LWS instances to support // accelerator fungibility or other cutting-edge researches. diff --git a/api/inference/v1alpha1/zz_generated.deepcopy.go b/api/inference/v1alpha1/zz_generated.deepcopy.go index cfdad84..dd373e4 100644 --- a/api/inference/v1alpha1/zz_generated.deepcopy.go +++ b/api/inference/v1alpha1/zz_generated.deepcopy.go @@ -166,9 +166,9 @@ func (in *PlaygroundSpec) DeepCopyInto(out *PlaygroundSpec) { *out = new(corev1alpha1.ModelClaim) (*in).DeepCopyInto(*out) } - if in.MultiModelsClaim != nil { - in, out := &in.MultiModelsClaim, &out.MultiModelsClaim - *out = new(corev1alpha1.MultiModelsClaim) + if in.ModelClaims != nil { + in, out := &in.ModelClaims, &out.ModelClaims + *out = new(corev1alpha1.ModelClaims) (*in).DeepCopyInto(*out) } if in.BackendConfig != nil { @@ -301,7 +301,7 @@ func (in *ServiceList) DeepCopyObject() runtime.Object { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *ServiceSpec) DeepCopyInto(out *ServiceSpec) { *out = *in - in.MultiModelsClaim.DeepCopyInto(&out.MultiModelsClaim) + in.ModelClaims.DeepCopyInto(&out.ModelClaims) in.WorkloadTemplate.DeepCopyInto(&out.WorkloadTemplate) if in.ElasticConfig != nil { in, out := &in.ElasticConfig, &out.ElasticConfig diff --git a/client-go/applyconfiguration/core/v1alpha1/modelclaims.go b/client-go/applyconfiguration/core/v1alpha1/modelclaims.go new file mode 100644 index 0000000..52760ef --- /dev/null +++ b/client-go/applyconfiguration/core/v1alpha1/modelclaims.go @@ -0,0 +1,58 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +// Code generated by applyconfiguration-gen. DO NOT EDIT. 
+ +package v1alpha1 + +import ( + corev1alpha1 "github.com/inftyai/llmaz/api/core/v1alpha1" +) + +// ModelClaimsApplyConfiguration represents an declarative configuration of the ModelClaims type for use +// with apply. +type ModelClaimsApplyConfiguration struct { + Models []ModelRepresentativeApplyConfiguration `json:"models,omitempty"` + InferenceFlavors []corev1alpha1.FlavorName `json:"inferenceFlavors,omitempty"` +} + +// ModelClaimsApplyConfiguration constructs an declarative configuration of the ModelClaims type for use with +// apply. +func ModelClaims() *ModelClaimsApplyConfiguration { + return &ModelClaimsApplyConfiguration{} +} + +// WithModels adds the given value to the Models field in the declarative configuration +// and returns the receiver, so that objects can be build by chaining "With" function invocations. +// If called multiple times, values provided by each call will be appended to the Models field. +func (b *ModelClaimsApplyConfiguration) WithModels(values ...*ModelRepresentativeApplyConfiguration) *ModelClaimsApplyConfiguration { + for i := range values { + if values[i] == nil { + panic("nil value passed to WithModels") + } + b.Models = append(b.Models, *values[i]) + } + return b +} + +// WithInferenceFlavors adds the given value to the InferenceFlavors field in the declarative configuration +// and returns the receiver, so that objects can be build by chaining "With" function invocations. +// If called multiple times, values provided by each call will be appended to the InferenceFlavors field. +func (b *ModelClaimsApplyConfiguration) WithInferenceFlavors(values ...corev1alpha1.FlavorName) *ModelClaimsApplyConfiguration { + for i := range values { + b.InferenceFlavors = append(b.InferenceFlavors, values[i]) + } + return b +} diff --git a/client-go/applyconfiguration/core/v1alpha1/modelrepresentative.go b/client-go/applyconfiguration/core/v1alpha1/modelrepresentative.go new file mode 100644 index 0000000..83477b2 --- /dev/null +++ b/client-go/applyconfiguration/core/v1alpha1/modelrepresentative.go @@ -0,0 +1,51 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +// Code generated by applyconfiguration-gen. DO NOT EDIT. + +package v1alpha1 + +import ( + v1alpha1 "github.com/inftyai/llmaz/api/core/v1alpha1" +) + +// ModelRepresentativeApplyConfiguration represents an declarative configuration of the ModelRepresentative type for use +// with apply. +type ModelRepresentativeApplyConfiguration struct { + Name *v1alpha1.ModelName `json:"name,omitempty"` + Role *v1alpha1.ModelRole `json:"role,omitempty"` +} + +// ModelRepresentativeApplyConfiguration constructs an declarative configuration of the ModelRepresentative type for use with +// apply. +func ModelRepresentative() *ModelRepresentativeApplyConfiguration { + return &ModelRepresentativeApplyConfiguration{} +} + +// WithName sets the Name field in the declarative configuration to the given value +// and returns the receiver, so that objects can be built by chaining "With" function invocations. 
+// If called multiple times, the Name field is set to the value of the last call. +func (b *ModelRepresentativeApplyConfiguration) WithName(value v1alpha1.ModelName) *ModelRepresentativeApplyConfiguration { + b.Name = &value + return b +} + +// WithRole sets the Role field in the declarative configuration to the given value +// and returns the receiver, so that objects can be built by chaining "With" function invocations. +// If called multiple times, the Role field is set to the value of the last call. +func (b *ModelRepresentativeApplyConfiguration) WithRole(value v1alpha1.ModelRole) *ModelRepresentativeApplyConfiguration { + b.Role = &value + return b +} diff --git a/client-go/applyconfiguration/core/v1alpha1/multimodelsclaim.go b/client-go/applyconfiguration/core/v1alpha1/multimodelsclaim.go deleted file mode 100644 index 3c6a8bc..0000000 --- a/client-go/applyconfiguration/core/v1alpha1/multimodelsclaim.go +++ /dev/null @@ -1,64 +0,0 @@ -/* -Copyright 2024. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -// Code generated by applyconfiguration-gen. DO NOT EDIT. - -package v1alpha1 - -import ( - v1alpha1 "github.com/inftyai/llmaz/api/core/v1alpha1" -) - -// MultiModelsClaimApplyConfiguration represents an declarative configuration of the MultiModelsClaim type for use -// with apply. -type MultiModelsClaimApplyConfiguration struct { - ModelNames []v1alpha1.ModelName `json:"modelNames,omitempty"` - InferenceMode *v1alpha1.InferenceMode `json:"inferenceMode,omitempty"` - InferenceFlavors []v1alpha1.FlavorName `json:"inferenceFlavors,omitempty"` -} - -// MultiModelsClaimApplyConfiguration constructs an declarative configuration of the MultiModelsClaim type for use with -// apply. -func MultiModelsClaim() *MultiModelsClaimApplyConfiguration { - return &MultiModelsClaimApplyConfiguration{} -} - -// WithModelNames adds the given value to the ModelNames field in the declarative configuration -// and returns the receiver, so that objects can be build by chaining "With" function invocations. -// If called multiple times, values provided by each call will be appended to the ModelNames field. -func (b *MultiModelsClaimApplyConfiguration) WithModelNames(values ...v1alpha1.ModelName) *MultiModelsClaimApplyConfiguration { - for i := range values { - b.ModelNames = append(b.ModelNames, values[i]) - } - return b -} - -// WithInferenceMode sets the InferenceMode field in the declarative configuration to the given value -// and returns the receiver, so that objects can be built by chaining "With" function invocations. -// If called multiple times, the InferenceMode field is set to the value of the last call. -func (b *MultiModelsClaimApplyConfiguration) WithInferenceMode(value v1alpha1.InferenceMode) *MultiModelsClaimApplyConfiguration { - b.InferenceMode = &value - return b -} - -// WithInferenceFlavors adds the given value to the InferenceFlavors field in the declarative configuration -// and returns the receiver, so that objects can be build by chaining "With" function invocations. 
-// If called multiple times, values provided by each call will be appended to the InferenceFlavors field. -func (b *MultiModelsClaimApplyConfiguration) WithInferenceFlavors(values ...v1alpha1.FlavorName) *MultiModelsClaimApplyConfiguration { - for i := range values { - b.InferenceFlavors = append(b.InferenceFlavors, values[i]) - } - return b -} diff --git a/client-go/applyconfiguration/inference/v1alpha1/playgroundspec.go b/client-go/applyconfiguration/inference/v1alpha1/playgroundspec.go index 6c39c92..b9692a3 100644 --- a/client-go/applyconfiguration/inference/v1alpha1/playgroundspec.go +++ b/client-go/applyconfiguration/inference/v1alpha1/playgroundspec.go @@ -24,10 +24,10 @@ import ( // PlaygroundSpecApplyConfiguration represents an declarative configuration of the PlaygroundSpec type for use // with apply. type PlaygroundSpecApplyConfiguration struct { - Replicas *int32 `json:"replicas,omitempty"` - ModelClaim *v1alpha1.ModelClaimApplyConfiguration `json:"modelClaim,omitempty"` - MultiModelsClaim *v1alpha1.MultiModelsClaimApplyConfiguration `json:"multiModelsClaim,omitempty"` - BackendConfig *BackendConfigApplyConfiguration `json:"backendConfig,omitempty"` + Replicas *int32 `json:"replicas,omitempty"` + ModelClaim *v1alpha1.ModelClaimApplyConfiguration `json:"modelClaim,omitempty"` + ModelClaims *v1alpha1.ModelClaimsApplyConfiguration `json:"modelClaims,omitempty"` + BackendConfig *BackendConfigApplyConfiguration `json:"backendConfig,omitempty"` } // PlaygroundSpecApplyConfiguration constructs an declarative configuration of the PlaygroundSpec type for use with @@ -52,11 +52,11 @@ func (b *PlaygroundSpecApplyConfiguration) WithModelClaim(value *v1alpha1.ModelC return b } -// WithMultiModelsClaim sets the MultiModelsClaim field in the declarative configuration to the given value +// WithModelClaims sets the ModelClaims field in the declarative configuration to the given value // and returns the receiver, so that objects can be built by chaining "With" function invocations. -// If called multiple times, the MultiModelsClaim field is set to the value of the last call. -func (b *PlaygroundSpecApplyConfiguration) WithMultiModelsClaim(value *v1alpha1.MultiModelsClaimApplyConfiguration) *PlaygroundSpecApplyConfiguration { - b.MultiModelsClaim = value +// If called multiple times, the ModelClaims field is set to the value of the last call. +func (b *PlaygroundSpecApplyConfiguration) WithModelClaims(value *v1alpha1.ModelClaimsApplyConfiguration) *PlaygroundSpecApplyConfiguration { + b.ModelClaims = value return b } diff --git a/client-go/applyconfiguration/inference/v1alpha1/servicespec.go b/client-go/applyconfiguration/inference/v1alpha1/servicespec.go index f31e425..1ba4aa2 100644 --- a/client-go/applyconfiguration/inference/v1alpha1/servicespec.go +++ b/client-go/applyconfiguration/inference/v1alpha1/servicespec.go @@ -25,9 +25,9 @@ import ( // ServiceSpecApplyConfiguration represents an declarative configuration of the ServiceSpec type for use // with apply. 
type ServiceSpecApplyConfiguration struct { - MultiModelsClaim *v1alpha1.MultiModelsClaimApplyConfiguration `json:"multiModelsClaim,omitempty"` - WorkloadTemplate *v1.LeaderWorkerSetSpec `json:"workloadTemplate,omitempty"` - ElasticConfig *ElasticConfigApplyConfiguration `json:"elasticConfig,omitempty"` + ModelClaims *v1alpha1.ModelClaimsApplyConfiguration `json:"modelClaims,omitempty"` + WorkloadTemplate *v1.LeaderWorkerSetSpec `json:"workloadTemplate,omitempty"` + ElasticConfig *ElasticConfigApplyConfiguration `json:"elasticConfig,omitempty"` } // ServiceSpecApplyConfiguration constructs an declarative configuration of the ServiceSpec type for use with @@ -36,11 +36,11 @@ func ServiceSpec() *ServiceSpecApplyConfiguration { return &ServiceSpecApplyConfiguration{} } -// WithMultiModelsClaim sets the MultiModelsClaim field in the declarative configuration to the given value +// WithModelClaims sets the ModelClaims field in the declarative configuration to the given value // and returns the receiver, so that objects can be built by chaining "With" function invocations. -// If called multiple times, the MultiModelsClaim field is set to the value of the last call. -func (b *ServiceSpecApplyConfiguration) WithMultiModelsClaim(value *v1alpha1.MultiModelsClaimApplyConfiguration) *ServiceSpecApplyConfiguration { - b.MultiModelsClaim = value +// If called multiple times, the ModelClaims field is set to the value of the last call. +func (b *ServiceSpecApplyConfiguration) WithModelClaims(value *v1alpha1.ModelClaimsApplyConfiguration) *ServiceSpecApplyConfiguration { + b.ModelClaims = value return b } diff --git a/client-go/applyconfiguration/utils.go b/client-go/applyconfiguration/utils.go index 0bb10ec..1ede179 100644 --- a/client-go/applyconfiguration/utils.go +++ b/client-go/applyconfiguration/utils.go @@ -54,16 +54,18 @@ func ForKind(kind schema.GroupVersionKind) interface{} { return &applyconfigurationcorev1alpha1.FlavorApplyConfiguration{} case corev1alpha1.SchemeGroupVersion.WithKind("ModelClaim"): return &applyconfigurationcorev1alpha1.ModelClaimApplyConfiguration{} + case corev1alpha1.SchemeGroupVersion.WithKind("ModelClaims"): + return &applyconfigurationcorev1alpha1.ModelClaimsApplyConfiguration{} case corev1alpha1.SchemeGroupVersion.WithKind("ModelHub"): return &applyconfigurationcorev1alpha1.ModelHubApplyConfiguration{} + case corev1alpha1.SchemeGroupVersion.WithKind("ModelRepresentative"): + return &applyconfigurationcorev1alpha1.ModelRepresentativeApplyConfiguration{} case corev1alpha1.SchemeGroupVersion.WithKind("ModelSource"): return &applyconfigurationcorev1alpha1.ModelSourceApplyConfiguration{} case corev1alpha1.SchemeGroupVersion.WithKind("ModelSpec"): return &applyconfigurationcorev1alpha1.ModelSpecApplyConfiguration{} case corev1alpha1.SchemeGroupVersion.WithKind("ModelStatus"): return &applyconfigurationcorev1alpha1.ModelStatusApplyConfiguration{} - case corev1alpha1.SchemeGroupVersion.WithKind("MultiModelsClaim"): - return &applyconfigurationcorev1alpha1.MultiModelsClaimApplyConfiguration{} case corev1alpha1.SchemeGroupVersion.WithKind("OpenModel"): return &applyconfigurationcorev1alpha1.OpenModelApplyConfiguration{} diff --git a/config/crd/bases/inference.llmaz.io_playgrounds.yaml b/config/crd/bases/inference.llmaz.io_playgrounds.yaml index 766444d..ae35cf7 100644 --- a/config/crd/bases/inference.llmaz.io_playgrounds.yaml +++ b/config/crd/bases/inference.llmaz.io_playgrounds.yaml @@ -224,10 +224,9 @@ spec: type: object modelClaim: description: |- - ModelClaim represents 
claiming for one model, it's the standard claimMode - of multiModelsClaim compared to other modes like SpeculativeDecoding. - Most of the time, modelClaim is enough. - ModelClaim and multiModelsClaim are exclusive configured. + ModelClaim represents claiming for one model, it's a simplified use case + of modelClaims. Most of the time, modelClaim is enough. + ModelClaim and modelClaims are exclusive configured. properties: inferenceFlavors: description: |- @@ -242,11 +241,11 @@ spec: description: ModelName represents the name of the Model. type: string type: object - multiModelsClaim: + modelClaims: description: |- - MultiModelsClaim represents claiming for multiple models with different claimModes, - like standard or speculative-decoding to support different inference scenarios. - ModelClaim and multiModelsClaim are exclusive configured. + ModelClaims represents claiming for multiple models for more complicated + use cases like speculative-decoding. + ModelClaims and modelClaim are exclusive configured. properties: inferenceFlavors: description: |- @@ -257,23 +256,28 @@ spec: items: type: string type: array - inferenceMode: - default: Standard + models: description: |- - Mode represents the paradigm to serve the model, whether via a standard way - or via an advanced technique like SpeculativeDecoding. - enum: - - Standard - - SpeculativeDecoding - type: string - modelNames: - description: |- - ModelNames represents a list of models, there maybe multiple models here - to support state-of-the-art technologies like speculative decoding. - If the composedMode is SpeculativeDecoding, the first model is the target model, - and the second model is the draft model. + Models represents a list of models with roles specified, there maybe + multiple models here to support state-of-the-art technologies like + speculative decoding, then one model is main(target) model, another one + is draft model. items: - type: string + properties: + name: + description: Name represents the model name. + type: string + role: + default: main + description: Role represents the model role once more than + one model is required. + enum: + - main + - draft + type: string + required: + - name + type: object minItems: 1 type: array type: object diff --git a/config/crd/bases/inference.llmaz.io_services.yaml b/config/crd/bases/inference.llmaz.io_services.yaml index e6bc503..f00ce46 100644 --- a/config/crd/bases/inference.llmaz.io_services.yaml +++ b/config/crd/bases/inference.llmaz.io_services.yaml @@ -65,10 +65,9 @@ spec: format: int32 type: integer type: object - multiModelsClaim: - description: |- - MultiModelsClaim represents claiming for multiple models with different claimModes, - like standard or speculative-decoding to support different inference scenarios. + modelClaims: + description: ModelClaims represents multiple claims for different + models. properties: inferenceFlavors: description: |- @@ -79,23 +78,28 @@ spec: items: type: string type: array - inferenceMode: - default: Standard + models: description: |- - Mode represents the paradigm to serve the model, whether via a standard way - or via an advanced technique like SpeculativeDecoding. - enum: - - Standard - - SpeculativeDecoding - type: string - modelNames: - description: |- - ModelNames represents a list of models, there maybe multiple models here - to support state-of-the-art technologies like speculative decoding. - If the composedMode is SpeculativeDecoding, the first model is the target model, - and the second model is the draft model. 
+ Models represents a list of models with roles specified, there maybe + multiple models here to support state-of-the-art technologies like + speculative decoding, then one model is main(target) model, another one + is draft model. items: - type: string + properties: + name: + description: Name represents the model name. + type: string + role: + default: main + description: Role represents the model role once more than + one model is required. + enum: + - main + - draft + type: string + required: + - name + type: object minItems: 1 type: array type: object diff --git a/docs/examples/speculative-decoding/llamacpp/playground.yaml b/docs/examples/speculative-decoding/llamacpp/playground.yaml index 5ab223e..e237503 100644 --- a/docs/examples/speculative-decoding/llamacpp/playground.yaml +++ b/docs/examples/speculative-decoding/llamacpp/playground.yaml @@ -1,4 +1,4 @@ -# This is just an example, because it doesn't make any sense +# This is just an toy example, because it doesn't make any sense # in real world, drafting tokens for the model with similar size. apiVersion: inference.llmaz.io/v1alpha1 @@ -7,11 +7,12 @@ metadata: name: llamacpp-speculator spec: replicas: 1 - multiModelsClaim: - inferenceMode: SpeculativeDecoding - modelNames: - - llama2-7b-q8-gguf # the target model, should be the first one - - llama2-7b-q2-k-gguf # the draft model + modelClaims: + models: + - name: llama2-7b-q8-gguf # the target model + role: main + - name: llama2-7b-q2-k-gguf # the draft model + role: draft backendConfig: name: llamacpp args: diff --git a/docs/examples/speculative-decoding/vllm/playground.yaml b/docs/examples/speculative-decoding/vllm/playground.yaml index 40f1c43..152f08d 100644 --- a/docs/examples/speculative-decoding/vllm/playground.yaml +++ b/docs/examples/speculative-decoding/vllm/playground.yaml @@ -4,15 +4,14 @@ metadata: name: vllm-speculator spec: replicas: 1 - multiModelsClaim: - inferenceMode: SpeculativeDecoding - modelNames: - - opt-6--7b # the target model, should be the first one - - opt-125m # the draft model + modelClaims: + models: + - name: opt-6--7b # the target model + role: main + - name: opt-125m # the draft model + role: draft backendConfig: args: - --use-v2-block-manager - - -tp - - 1 - - --num_speculative_tokens - - 5 + - --num_speculative_tokens 5 + - -tp 1 diff --git a/pkg/controller/inference/playground_controller.go b/pkg/controller/inference/playground_controller.go index 48f0de4..2fbadbb 100644 --- a/pkg/controller/inference/playground_controller.go +++ b/pkg/controller/inference/playground_controller.go @@ -106,10 +106,10 @@ func (r *PlaygroundReconciler) Reconcile(ctx context.Context, req ctrl.Request) return ctrl.Result{}, err } models = append(models, model) - } else if playground.Spec.MultiModelsClaim != nil { - for _, modelName := range playground.Spec.MultiModelsClaim.ModelNames { + } else if playground.Spec.ModelClaims != nil { + for _, mr := range playground.Spec.ModelClaims.Models { model := &coreapi.OpenModel{} - if err := r.Get(ctx, types.NamespacedName{Name: string(modelName)}, model); err != nil { + if err := r.Get(ctx, types.NamespacedName{Name: string(mr.Name)}, model); err != nil { if apierrors.IsNotFound(err) && handleUnexpectedCondition(playground, false, false) { return ctrl.Result{}, r.Client.Status().Update(ctx, playground) } @@ -192,20 +192,28 @@ func buildServiceApplyConfiguration(models []*coreapi.OpenModel, playground *inf // Build spec. 
spec := inferenceclientgo.ServiceSpec() - claim := &coreclientgo.MultiModelsClaimApplyConfiguration{} + claim := &coreclientgo.ModelClaimsApplyConfiguration{} if playground.Spec.ModelClaim != nil { - claim = coreclientgo.MultiModelsClaim(). - WithModelNames(playground.Spec.ModelClaim.ModelName). - WithInferenceFlavors(playground.Spec.ModelClaim.InferenceFlavors...). - WithInferenceMode(coreapi.Standard) - } else if playground.Spec.MultiModelsClaim != nil { - claim = coreclientgo.MultiModelsClaim(). - WithModelNames(playground.Spec.MultiModelsClaim.ModelNames...). - WithInferenceFlavors(playground.Spec.MultiModelsClaim.InferenceFlavors...). - WithInferenceMode(playground.Spec.MultiModelsClaim.InferenceMode) + claim = coreclientgo.ModelClaims(). + WithModels(coreclientgo.ModelRepresentative().WithName(playground.Spec.ModelClaim.ModelName).WithRole(coreapi.MainRole)). + WithInferenceFlavors(playground.Spec.ModelClaim.InferenceFlavors...) + } else if playground.Spec.ModelClaims != nil { + mrs := []*coreclientgo.ModelRepresentativeApplyConfiguration{} + for _, model := range playground.Spec.ModelClaims.Models { + role := coreapi.MainRole + if model.Role != nil { + role = *model.Role + } + mr := coreclientgo.ModelRepresentative().WithName(model.Name).WithRole(role) + mrs = append(mrs, mr) + } + + claim = coreclientgo.ModelClaims(). + WithModels(mrs...). + WithInferenceFlavors(playground.Spec.ModelClaims.InferenceFlavors...) } - spec.WithMultiModelsClaim(claim) + spec.WithModelClaims(claim) spec.WithWorkloadTemplate(buildWorkloadTemplate(models, playground)) serviceApplyConfiguration.WithSpec(spec) @@ -237,6 +245,20 @@ func buildWorkloadTemplate(models []*coreapi.OpenModel, playground *inferenceapi return workload } +func involveRole(playground *inferenceapi.Playground) coreapi.ModelRole { + if playground.Spec.ModelClaim != nil { + return coreapi.MainRole + } else if playground.Spec.ModelClaims != nil { + for _, mr := range playground.Spec.ModelClaims.Models { + if *mr.Role != coreapi.MainRole { + return *mr.Role + } + } + } + + return coreapi.MainRole +} + func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.Playground) corev1.PodTemplateSpec { backendName := inferenceapi.DefaultBackend if playground.Spec.BackendConfig != nil && playground.Spec.BackendConfig.Name != nil { @@ -249,12 +271,7 @@ func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.P version = *playground.Spec.BackendConfig.Version } - mode := coreapi.Standard - if playground.Spec.MultiModelsClaim != nil { - mode = playground.Spec.MultiModelsClaim.InferenceMode - } - - args := bkd.Args(models, mode) + args := bkd.Args(models, involveRole(playground)) var envs []corev1.EnvVar if playground.Spec.BackendConfig != nil { @@ -285,7 +302,7 @@ func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.P Name: modelSource.MODEL_RUNNER_CONTAINER_NAME, Image: bkd.Image(version), Resources: resources, - Command: bkd.Command(), + Command: bkd.DefaultCommand(), Args: args, Env: envs, Ports: []corev1.ContainerPort{ diff --git a/pkg/controller/inference/service_controller.go b/pkg/controller/inference/service_controller.go index 0f1f998..fb95ec9 100644 --- a/pkg/controller/inference/service_controller.go +++ b/pkg/controller/inference/service_controller.go @@ -81,9 +81,9 @@ func (r *ServiceReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct logger.V(10).Info("reconcile Service", "Playground", klog.KObj(service)) models := []*coreapi.OpenModel{} - for _, modelName 
:= range service.Spec.MultiModelsClaim.ModelNames { + for _, mr := range service.Spec.ModelClaims.Models { model := &coreapi.OpenModel{} - if err := r.Get(ctx, types.NamespacedName{Name: string(modelName)}, model); err != nil { + if err := r.Get(ctx, types.NamespacedName{Name: string(mr.Name)}, model); err != nil { return ctrl.Result{}, err } models = append(models, model) @@ -153,7 +153,7 @@ func injectModelProperties(template *applyconfigurationv1.LeaderWorkerTemplateAp } // We treat the 0-index model as the main model, we only consider the main model's requirements, - // like label, flavor. + // like label, flavor. Note: this may change in the future, let's see. template.WorkerTemplate.Labels = util.MergeKVs(template.WorkerTemplate.Labels, modelLabels(models[0])) injectModelFlavor(template, models[0]) } diff --git a/pkg/controller_helper/backend/backend.go b/pkg/controller_helper/backend/backend.go index d5b842b..249ec9c 100644 --- a/pkg/controller_helper/backend/backend.go +++ b/pkg/controller_helper/backend/backend.go @@ -36,13 +36,11 @@ type Backend interface { DefaultVersion() string // DefaultResources returns the default resources set for the container. DefaultResources() inferenceapi.ResourceRequirements - // Command returns the command to start the inference backend. - Command() []string + // DefaultCommand returns the command to start the inference backend. + DefaultCommand() []string // Args returns the bootstrap arguments to start the backend. - Args([]*coreapi.OpenModel, coreapi.InferenceMode) []string - - // defaultArgs returns the bootstrap arguments when inferenceMode is standard. - defaultArgs(*coreapi.OpenModel) []string + // The second parameter represents which particular modelRole involved, like draft. + Args([]*coreapi.OpenModel, coreapi.ModelRole) []string } // SpeculativeBackend represents backend supports speculativeDecoding inferenceMode. diff --git a/pkg/controller_helper/backend/llamacpp.go b/pkg/controller_helper/backend/llamacpp.go index e4404aa..cc2de38 100644 --- a/pkg/controller_helper/backend/llamacpp.go +++ b/pkg/controller_helper/backend/llamacpp.go @@ -28,7 +28,6 @@ import ( ) var _ Backend = (*LLAMACPP)(nil) -var _ SpeculativeBackend = (*LLAMACPP)(nil) type LLAMACPP struct{} @@ -61,37 +60,26 @@ func (l *LLAMACPP) DefaultResources() inferenceapi.ResourceRequirements { } } -func (l *LLAMACPP) Command() []string { +func (l *LLAMACPP) DefaultCommand() []string { return []string{"./llama-server"} } -func (l *LLAMACPP) Args(models []*coreapi.OpenModel, mode coreapi.InferenceMode) []string { - if mode == coreapi.Standard { - return l.defaultArgs(models[0]) - } - if mode == coreapi.SpeculativeDecoding { - return l.speculativeArgs(models) - } - // We should not reach here. 
- return nil -} +func (l *LLAMACPP) Args(models []*coreapi.OpenModel, involvedRole coreapi.ModelRole) []string { + targetModelSource := modelSource.NewModelSourceProvider(models[0]) -func (l *LLAMACPP) defaultArgs(model *coreapi.OpenModel) []string { - source := modelSource.NewModelSourceProvider(model) - return []string{ - "-m", source.ModelPath(), - "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), - "--host", "0.0.0.0", + if involvedRole == coreapi.DraftRole { + draftModelSource := modelSource.NewModelSourceProvider(models[1]) + return []string{ + "-m", targetModelSource.ModelPath(), + "-md", draftModelSource.ModelPath(), + "--host", "0.0.0.0", + "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), + } } -} -func (l *LLAMACPP) speculativeArgs(models []*coreapi.OpenModel) []string { - targetModelSource := modelSource.NewModelSourceProvider(models[0]) - draftModelSource := modelSource.NewModelSourceProvider(models[1]) return []string{ "-m", targetModelSource.ModelPath(), - "-md", draftModelSource.ModelPath(), - "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), "--host", "0.0.0.0", + "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), } } diff --git a/pkg/controller_helper/backend/llamacpp_test.go b/pkg/controller_helper/backend/llamacpp_test.go index f338265..b740259 100644 --- a/pkg/controller_helper/backend/llamacpp_test.go +++ b/pkg/controller_helper/backend/llamacpp_test.go @@ -57,41 +57,41 @@ func Test_llamacpp(t *testing.T) { } testCases := []struct { - name string - mode coreapi.InferenceMode - wantCommand []string - wantArgs []string + name string + involvedRole coreapi.ModelRole + wantCommand []string + wantArgs []string }{ { - name: "standard mode", - mode: coreapi.Standard, - wantCommand: []string{"./llama-server"}, + name: "one main model", + involvedRole: coreapi.MainRole, + wantCommand: []string{"./llama-server"}, wantArgs: []string{ "-m", "/workspace/models/models--hub--model-1", - "--port", "8080", "--host", "0.0.0.0", + "--port", "8080", }, }, { - name: "speculative decoding", - mode: coreapi.SpeculativeDecoding, - wantCommand: []string{"./llama-server"}, + name: "speculative decoding", + involvedRole: coreapi.DraftRole, + wantCommand: []string{"./llama-server"}, wantArgs: []string{ "-m", "/workspace/models/models--hub--model-1", "-md", "/workspace/models/models--hub--model-2", - "--port", "8080", "--host", "0.0.0.0", + "--port", "8080", }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - if diff := cmp.Diff(backend.Command(), tc.wantCommand); diff != "" { - t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.Command()) + if diff := cmp.Diff(backend.DefaultCommand(), tc.wantCommand); diff != "" { + t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.DefaultCommand()) } - if diff := cmp.Diff(backend.Args(models, tc.mode), tc.wantArgs); diff != "" { - t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.mode)) + if diff := cmp.Diff(backend.Args(models, tc.involvedRole), tc.wantArgs); diff != "" { + t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.involvedRole)) } }) } diff --git a/pkg/controller_helper/backend/sglang.go b/pkg/controller_helper/backend/sglang.go index f9463f1..12a7307 100644 --- a/pkg/controller_helper/backend/sglang.go +++ b/pkg/controller_helper/backend/sglang.go @@ -60,23 +60,21 @@ func (s *SGLANG) DefaultResources() inferenceapi.ResourceRequirements { } } -func (s *SGLANG) Command() []string { +func (s *SGLANG) DefaultCommand() []string { return 
[]string{"python3", "-m", "sglang.launch_server"} } -func (s *SGLANG) Args(models []*coreapi.OpenModel, mode coreapi.InferenceMode) []string { - if mode == coreapi.Standard { - return s.defaultArgs(models[0]) +func (s *SGLANG) Args(models []*coreapi.OpenModel, involvedRole coreapi.ModelRole) []string { + targetModelSource := modelSource.NewModelSourceProvider(models[0]) + + if involvedRole == coreapi.DraftRole { + // TODO: support speculative decoding + return nil } - // We should not reach here. - return nil -} -func (s *SGLANG) defaultArgs(model *coreapi.OpenModel) []string { - source := modelSource.NewModelSourceProvider(model) return []string{ - "--model-path", source.ModelPath(), - "--served-model-name", source.ModelName(), + "--model-path", targetModelSource.ModelPath(), + "--served-model-name", targetModelSource.ModelName(), "--host", "0.0.0.0", "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), } diff --git a/pkg/controller_helper/backend/sglang_test.go b/pkg/controller_helper/backend/sglang_test.go index ecbf153..ce1bb50 100644 --- a/pkg/controller_helper/backend/sglang_test.go +++ b/pkg/controller_helper/backend/sglang_test.go @@ -57,15 +57,15 @@ func Test_SGLANG(t *testing.T) { } testCases := []struct { - name string - mode coreapi.InferenceMode - wantCommand []string - wantArgs []string + name string + involvedRole coreapi.ModelRole + wantCommand []string + wantArgs []string }{ { - name: "standard mode", - mode: coreapi.Standard, - wantCommand: []string{"python3", "-m", "sglang.launch_server"}, + name: "one main model", + involvedRole: coreapi.MainRole, + wantCommand: []string{"python3", "-m", "sglang.launch_server"}, wantArgs: []string{ "--model-path", "/workspace/models/models--hub--model-1", "--served-model-name", "model-1", @@ -74,20 +74,20 @@ func Test_SGLANG(t *testing.T) { }, }, { - name: "speculative decoding", - mode: coreapi.SpeculativeDecoding, - wantCommand: []string{"python3", "-m", "sglang.launch_server"}, - wantArgs: nil, + name: "speculative decoding", + involvedRole: coreapi.DraftRole, + wantCommand: []string{"python3", "-m", "sglang.launch_server"}, + wantArgs: nil, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - if diff := cmp.Diff(backend.Command(), tc.wantCommand); diff != "" { - t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.Command()) + if diff := cmp.Diff(backend.DefaultCommand(), tc.wantCommand); diff != "" { + t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.DefaultCommand()) } - if diff := cmp.Diff(backend.Args(models, tc.mode), tc.wantArgs); diff != "" { - t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.mode)) + if diff := cmp.Diff(backend.Args(models, tc.involvedRole), tc.wantArgs); diff != "" { + t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.involvedRole)) } }) } diff --git a/pkg/controller_helper/backend/vllm.go b/pkg/controller_helper/backend/vllm.go index 8334af7..467bfbc 100644 --- a/pkg/controller_helper/backend/vllm.go +++ b/pkg/controller_helper/backend/vllm.go @@ -28,7 +28,6 @@ import ( ) var _ Backend = (*VLLM)(nil) -var _ SpeculativeBackend = (*VLLM)(nil) type VLLM struct{} @@ -61,37 +60,26 @@ func (v *VLLM) DefaultResources() inferenceapi.ResourceRequirements { } } -func (v *VLLM) Command() []string { +func (v *VLLM) DefaultCommand() []string { return []string{"python3", "-m", "vllm.entrypoints.openai.api_server"} } -func (v *VLLM) Args(models []*coreapi.OpenModel, mode coreapi.InferenceMode) []string 
{ - if mode == coreapi.Standard { - return v.defaultArgs(models[0]) - } - if mode == coreapi.SpeculativeDecoding { - return v.speculativeArgs(models) - } - // We should not reach here. - return nil -} +func (v *VLLM) Args(models []*coreapi.OpenModel, involvedRole coreapi.ModelRole) []string { + targetModelSource := modelSource.NewModelSourceProvider(models[0]) -func (v *VLLM) defaultArgs(model *coreapi.OpenModel) []string { - source := modelSource.NewModelSourceProvider(model) - return []string{ - "--model", source.ModelPath(), - "--served-model-name", source.ModelName(), - "--host", "0.0.0.0", - "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), + if involvedRole == coreapi.DraftRole { + draftModelSource := modelSource.NewModelSourceProvider(models[1]) + return []string{ + "--model", targetModelSource.ModelPath(), + "--speculative_model", draftModelSource.ModelPath(), + "--served-model-name", targetModelSource.ModelName(), + "--host", "0.0.0.0", + "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), + } } -} -func (v *VLLM) speculativeArgs(models []*coreapi.OpenModel) []string { - targetModelSource := modelSource.NewModelSourceProvider(models[0]) - draftModelSource := modelSource.NewModelSourceProvider(models[1]) return []string{ "--model", targetModelSource.ModelPath(), - "--speculative_model", draftModelSource.ModelPath(), "--served-model-name", targetModelSource.ModelName(), "--host", "0.0.0.0", "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), diff --git a/pkg/controller_helper/backend/vllm_test.go b/pkg/controller_helper/backend/vllm_test.go index 7b8d062..d75fe4e 100644 --- a/pkg/controller_helper/backend/vllm_test.go +++ b/pkg/controller_helper/backend/vllm_test.go @@ -57,15 +57,15 @@ func Test_vllm(t *testing.T) { } testCases := []struct { - name string - mode coreapi.InferenceMode - wantCommand []string - wantArgs []string + name string + involvedRole coreapi.ModelRole + wantCommand []string + wantArgs []string }{ { - name: "standard mode", - mode: coreapi.Standard, - wantCommand: []string{"python3", "-m", "vllm.entrypoints.openai.api_server"}, + name: "one main model", + involvedRole: coreapi.MainRole, + wantCommand: []string{"python3", "-m", "vllm.entrypoints.openai.api_server"}, wantArgs: []string{ "--model", "/workspace/models/models--hub--model-1", "--served-model-name", "model-1", @@ -74,9 +74,9 @@ func Test_vllm(t *testing.T) { }, }, { - name: "speculative decoding", - mode: coreapi.SpeculativeDecoding, - wantCommand: []string{"python3", "-m", "vllm.entrypoints.openai.api_server"}, + name: "speculative decoding", + involvedRole: coreapi.DraftRole, + wantCommand: []string{"python3", "-m", "vllm.entrypoints.openai.api_server"}, wantArgs: []string{ "--model", "/workspace/models/models--hub--model-1", "--speculative_model", "/workspace/models/models--hub--model-2", @@ -89,11 +89,11 @@ func Test_vllm(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - if diff := cmp.Diff(backend.Command(), tc.wantCommand); diff != "" { - t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.Command()) + if diff := cmp.Diff(backend.DefaultCommand(), tc.wantCommand); diff != "" { + t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.DefaultCommand()) } - if diff := cmp.Diff(backend.Args(models, tc.mode), tc.wantArgs); diff != "" { - t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.mode)) + if diff := cmp.Diff(backend.Args(models, tc.involvedRole), tc.wantArgs); diff != "" { + t.Fatalf("unexpected args, want 
%v, got %v", tc.wantArgs, backend.Args(models, tc.involvedRole)) } }) } diff --git a/pkg/controller_helper/model_source/modelsource.go b/pkg/controller_helper/model_source/modelsource.go index a32573a..75676e6 100644 --- a/pkg/controller_helper/model_source/modelsource.go +++ b/pkg/controller_helper/model_source/modelsource.go @@ -53,8 +53,6 @@ type ModelSourceProvider interface { ModelName() string ModelPath() string // InjectModelLoader will inject the model loader to the spec, - // initContainerOnly means whether to inject specs other than initContainers, - // just in case of rewriting the specs, // index refers to the suffix of the initContainer name, like model-loader, model-loader-1. InjectModelLoader(spec *corev1.PodTemplateSpec, index int) } diff --git a/pkg/webhook/playground_webhook.go b/pkg/webhook/playground_webhook.go index 2fa8240..4f42be5 100644 --- a/pkg/webhook/playground_webhook.go +++ b/pkg/webhook/playground_webhook.go @@ -52,9 +52,12 @@ func (w *PlaygroundWebhook) Default(ctx context.Context, obj runtime.Object) err var modelName string if playground.Spec.ModelClaim != nil { modelName = string(playground.Spec.ModelClaim.ModelName) - } else if playground.Spec.MultiModelsClaim != nil { - // We choose the first model as the main model. - modelName = string(playground.Spec.MultiModelsClaim.ModelNames[0]) + } else if playground.Spec.ModelClaims != nil { + for _, model := range playground.Spec.ModelClaims.Models { + if model.Role == nil || *model.Role == coreapi.MainRole { + modelName = string(model.Name) + } + } } if playground.Labels == nil { @@ -95,22 +98,34 @@ func (w *PlaygroundWebhook) generateValidate(obj runtime.Object) field.ErrorList specPath := field.NewPath("spec") var allErrs field.ErrorList - if playground.Spec.ModelClaim == nil && playground.Spec.MultiModelsClaim == nil { - allErrs = append(allErrs, field.Forbidden(specPath, "modelClaim and multiModelsClaim couldn't be both nil")) + if playground.Spec.ModelClaim == nil && playground.Spec.ModelClaims == nil { + allErrs = append(allErrs, field.Forbidden(specPath, "modelClaim and modelClaims couldn't be both nil")) } - if playground.Spec.MultiModelsClaim != nil { - if playground.Spec.MultiModelsClaim.InferenceMode == coreapi.SpeculativeDecoding { - // if playground.Spec.BackendConfig != nil && !(*playground.Spec.BackendConfig.Name == inferenceapi.VLLM || *playground.Spec.BackendConfig.Name == inferenceapi.LLAMACPP) { - // allErrs = append(allErrs, field.Forbidden(specPath.Child("multiModelsClaim", "inferenceMode"), "only vLLM and llama.cpp supports speculativeDecoding mode")) - // } - if playground.Spec.BackendConfig != nil && *playground.Spec.BackendConfig.Name != inferenceapi.VLLM { - allErrs = append(allErrs, field.Forbidden(specPath.Child("multiModelsClaim", "inferenceMode"), "only vLLM supports speculativeDecoding mode")) + if playground.Spec.ModelClaims != nil { + mainModelCount := 0 + var speculativeDecoding bool + + for _, model := range playground.Spec.ModelClaims.Models { + if model.Name == coreapi.ModelName(coreapi.MainRole) { + mainModelCount += 1 + } + if *model.Role == coreapi.DraftRole { + speculativeDecoding = true } - if len(playground.Spec.MultiModelsClaim.ModelNames) != 2 { - allErrs = append(allErrs, field.Forbidden(specPath.Child("multiModelsClaim", "modelNames"), "only two models are allowed in speculativeDecoding mode")) + } + + if speculativeDecoding { + if len(playground.Spec.ModelClaims.Models) != 2 { + allErrs = append(allErrs, field.Forbidden(specPath.Child("modelClaims", "models"), 
"only two models are allowed in speculativeDecoding mode")) + } + if playground.Spec.BackendConfig != nil && *playground.Spec.BackendConfig.Name != inferenceapi.VLLM { + allErrs = append(allErrs, field.Forbidden(specPath.Child("backendConfig", "name"), "only vLLM supports speculativeDecoding mode")) } } + if mainModelCount > 1 { + allErrs = append(allErrs, field.Forbidden(specPath.Child("modelClaims", "models"), "only one main model is allowed")) + } } return allErrs } diff --git a/pkg/webhook/service_webhook.go b/pkg/webhook/service_webhook.go index 3756c96..fd21b2e 100644 --- a/pkg/webhook/service_webhook.go +++ b/pkg/webhook/service_webhook.go @@ -26,6 +26,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" modelSource "github.com/inftyai/llmaz/pkg/controller_helper/model_source" ) @@ -56,7 +57,7 @@ var _ webhook.CustomValidator = &ServiceWebhook{} // ValidateCreate implements webhook.Validator so a webhook will be registered for the type func (w *ServiceWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) { - allErrs := field.ErrorList{} + allErrs := w.generateValidate(obj) service := obj.(*inferenceapi.Service) for _, err := range validation.IsDNS1123Label(service.Name) { allErrs = append(allErrs, field.Invalid(field.NewPath("metadata.name"), service.Name, err)) @@ -78,10 +79,38 @@ func (w *ServiceWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) // ValidateUpdate implements webhook.Validator so a webhook will be registered for the type func (w *ServiceWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) { - return nil, nil + allErrs := w.generateValidate(newObj) + return nil, allErrs.ToAggregate() } // ValidateDelete implements webhook.Validator so a webhook will be registered for the type func (w *ServiceWebhook) ValidateDelete(ctx context.Context, obj runtime.Object) (admission.Warnings, error) { return nil, nil } + +func (w *ServiceWebhook) generateValidate(obj runtime.Object) field.ErrorList { + service := obj.(*inferenceapi.Service) + specPath := field.NewPath("spec") + var allErrs field.ErrorList + + mainModelCount := 0 + var speculativeDecoding bool + for _, model := range service.Spec.ModelClaims.Models { + if model.Role == nil || *model.Role == coreapi.MainRole { + mainModelCount += 1 + } + if model.Role != nil && *model.Role == coreapi.DraftRole { + speculativeDecoding = true + } + } + + if speculativeDecoding { + if len(service.Spec.ModelClaims.Models) != 2 { + allErrs = append(allErrs, field.Forbidden(specPath.Child("modelClaims", "models"), "only two models are allowed in speculativeDecoding mode")) + } + if mainModelCount != 1 { + allErrs = append(allErrs, field.Forbidden(specPath.Child("modelClaims", "models"), "main model is required")) + } + } + return allErrs +} diff --git a/test/integration/controller/inference/playground_test.go b/test/integration/controller/inference/playground_test.go index 6591339..db6b566 100644 --- a/test/integration/controller/inference/playground_test.go +++ b/test/integration/controller/inference/playground_test.go @@ -182,7 +182,7 @@ var _ = ginkgo.Describe("playground controller test", func() { }), ginkgo.Entry("Playground with speculativeDecoding", &testValidatingCase{ makePlayground: func() *inferenceapi.Playground { - return 
wrapper.MakePlayground("playground", ns.Name).MultiModelsClaim([]string{model.Name, draftModel.Name}, coreapi.SpeculativeDecoding).Label(coreapi.ModelNameLabelKey, model.Name). + return wrapper.MakePlayground("playground", ns.Name).ModelClaims([]string{model.Name, draftModel.Name}, []string{"main", "draft"}).Label(coreapi.ModelNameLabelKey, model.Name). Obj() }, updates: []*update{ @@ -242,7 +242,7 @@ var _ = ginkgo.Describe("playground controller test", func() { updateFunc: func(playground *inferenceapi.Playground) { // Create a service with the same name as the playground. service := wrapper.MakeService(playground.Name, playground.Namespace). - ModelsClaim([]string{"llama3-8b"}, coreapi.Standard, nil). + ModelClaims([]string{"llama3-8b"}, []string{"main"}). WorkerTemplate(). Obj() gomega.Expect(k8sClient.Create(ctx, service)).To(gomega.Succeed()) @@ -256,7 +256,7 @@ var _ = ginkgo.Describe("playground controller test", func() { // Delete the service, playground should be updated to Pending. updateFunc: func(playground *inferenceapi.Playground) { service := wrapper.MakeService(playground.Name, playground.Namespace). - ModelsClaim([]string{"llama3-8b"}, coreapi.Standard, nil). + ModelClaims([]string{"llama3-8b"}, []string{"main"}). WorkerTemplate(). Obj() gomega.Expect(k8sClient.Delete(ctx, service)).To(gomega.Succeed()) diff --git a/test/integration/controller/inference/service_test.go b/test/integration/controller/inference/service_test.go index 4afaae3..cb18eb7 100644 --- a/test/integration/controller/inference/service_test.go +++ b/test/integration/controller/inference/service_test.go @@ -157,7 +157,7 @@ var _ = ginkgo.Describe("inferenceService controller test", func() { ginkgo.Entry("service created with URI configured Model", &testValidatingCase{ makeService: func() *inferenceapi.Service { return wrapper.MakeService("service-llama3-8b", ns.Name). - ModelsClaim([]string{"model-with-uri"}, coreapi.Standard, nil). + ModelClaims([]string{"model-with-uri"}, []string{"main"}). WorkerTemplate(). Obj() }, @@ -185,7 +185,7 @@ var _ = ginkgo.Describe("inferenceService controller test", func() { ginkgo.Entry("service created with speculativeDecoding mode", &testValidatingCase{ makeService: func() *inferenceapi.Service { return wrapper.MakeService("service-llama3-8b", ns.Name). - ModelsClaim([]string{"llama3-8b", "model-with-uri"}, coreapi.SpeculativeDecoding, nil). + ModelClaims([]string{"llama3-8b", "model-with-uri"}, []string{"main", "draft"}). WorkerTemplate(). 
Obj() }, diff --git a/test/integration/webhook/playground_test.go b/test/integration/webhook/playground_test.go index e61f2f6..0f244be 100644 --- a/test/integration/webhook/playground_test.go +++ b/test/integration/webhook/playground_test.go @@ -89,13 +89,13 @@ var _ = ginkgo.Describe("playground default and validation", func() { }), ginkgo.Entry("speculativeDecoding with SGLang is not allowed", &testValidatingCase{ playground: func() *inferenceapi.Playground { - return wrapper.MakePlayground("playground", ns.Name).Replicas(1).MultiModelsClaim([]string{"llama3-405b", "llama3-8b"}, coreapi.SpeculativeDecoding).Backend(string(inferenceapi.SGLANG)).Obj() + return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaims([]string{"llama3-405b", "llama3-8b"}, []string{"main", "draft"}).Backend(string(inferenceapi.SGLANG)).Obj() }, failed: true, }), - ginkgo.Entry("speculativeDecoding with three models claimed", &testValidatingCase{ + ginkgo.Entry("speculativeDecoding with three models is not allowed", &testValidatingCase{ playground: func() *inferenceapi.Playground { - return wrapper.MakePlayground("playground", ns.Name).Replicas(1).MultiModelsClaim([]string{"llama3-405b", "llama3-8b", "llama3-2b"}, coreapi.SpeculativeDecoding).Obj() + return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaims([]string{"llama3-405b", "llama3-8b", "llama3-2b"}, []string{"main", "draft", "draft"}).Obj() }, failed: true, }), @@ -105,9 +105,9 @@ var _ = ginkgo.Describe("playground default and validation", func() { }, failed: true, }), - ginkgo.Entry("unknown inference mode", &testValidatingCase{ + ginkgo.Entry("no main model", &testValidatingCase{ playground: func() *inferenceapi.Playground { - return wrapper.MakePlayground("playground", ns.Name).Replicas(1).MultiModelsClaim([]string{"llama3-405b", "llama3-8b"}, coreapi.InferenceMode("unknown")).Obj() + return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaims([]string{"llama3-8b"}, []string{"draft"}).Obj() }, failed: true, }), @@ -133,16 +133,25 @@ var _ = ginkgo.Describe("playground default and validation", func() { return wrapper.MakePlayground("playground", ns.Name).ModelClaim("llama3-8b").Replicas(1).Label(coreapi.ModelNameLabelKey, "llama3-8b").Obj() }, }), - ginkgo.Entry("defaulting inferenceMode with multiModelsClaim", &testDefaultingCase{ + ginkgo.Entry("defaulting model role with modelClaims", &testDefaultingCase{ playground: func() *inferenceapi.Playground { playground := wrapper.MakePlayground("playground", ns.Name).Replicas(1).Obj() - playground.Spec.MultiModelsClaim = &coreapi.MultiModelsClaim{ - ModelNames: []coreapi.ModelName{"llama3-405b", "llama3-8b"}, + draftRole := coreapi.DraftRole + playground.Spec.ModelClaims = &coreapi.ModelClaims{ + Models: []coreapi.ModelRepresentative{ + { + Name: "llama3-405b", + }, + { + Name: "llama3-8b", + Role: &draftRole, + }, + }, } return playground }, wantPlayground: func() *inferenceapi.Playground { - return wrapper.MakePlayground("playground", ns.Name).MultiModelsClaim([]string{"llama3-405b", "llama3-8b"}, coreapi.Standard).Replicas(1).Label(coreapi.ModelNameLabelKey, "llama3-405b").Obj() + return wrapper.MakePlayground("playground", ns.Name).ModelClaims([]string{"llama3-405b", "llama3-8b"}, []string{"main", "draft"}).Replicas(1).Label(coreapi.ModelNameLabelKey, "llama3-405b").Obj() }, }), ) diff --git a/test/integration/webhook/service_test.go b/test/integration/webhook/service_test.go index 6ed4c47..abd9ada 100644 --- a/test/integration/webhook/service_test.go 
+++ b/test/integration/webhook/service_test.go @@ -22,7 +22,6 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" "github.com/inftyai/llmaz/test/util" "github.com/inftyai/llmaz/test/util/wrapper" @@ -73,12 +72,42 @@ var _ = ginkgo.Describe("service default and validation", func() { ginkgo.Entry("model-runner container doesn't exist", &testValidatingCase{ service: func() *inferenceapi.Service { return wrapper.MakeService("service-llama3-8b", ns.Name). - ModelsClaim([]string{"llama3-8b"}, coreapi.Standard, nil). + ModelClaims([]string{"llama3-8b"}, []string{"main"}). WorkerTemplate(). ContainerName("model-runner-fake"). Obj() }, failed: true, }), + ginkgo.Entry("speculative-decoding with three models", &testValidatingCase{ + service: func() *inferenceapi.Service { + return wrapper.MakeService("service-llama3-8b", ns.Name). + ModelClaims([]string{"llama3-405b", "llama3-8b", "llama3-2b"}, []string{"main", "draft", "draft"}). + WorkerTemplate(). + Obj() + }, + failed: true, + }), + ginkgo.Entry("modelClaims with nil role", &testValidatingCase{ + service: func() *inferenceapi.Service { + service := wrapper.MakeService("service-llama3-8b", ns.Name). + ModelClaims([]string{"llama3-405b", "llama3-8b"}, []string{"main", "draft"}). + WorkerTemplate(). + Obj() + // Set the role to nil + service.Spec.ModelClaims.Models[0].Role = nil + return service + }, + failed: false, + }), + ginkgo.Entry("no main model", &testValidatingCase{ + service: func() *inferenceapi.Service { + return wrapper.MakeService("service-llama3-8b", ns.Name). + ModelClaims([]string{"llama3-8b"}, []string{"draft"}). + WorkerTemplate(). + Obj() + }, + failed: true, + }), ) }) diff --git a/test/util/mock.go b/test/util/mock.go index 91febb9..7d774b2 100644 --- a/test/util/mock.go +++ b/test/util/mock.go @@ -35,7 +35,7 @@ func MockASamplePlayground(ns string) *inferenceapi.Playground { func MockASampleService(ns string) *inferenceapi.Service { return wrapper.MakeService("service-llama3-8b", ns). - ModelsClaim([]string{"llama3-8b"}, coreapi.Standard, nil). + ModelClaims([]string{"llama3-8b"}, []string{"main"}). WorkerTemplate(). 
Obj() } diff --git a/test/util/validation/validate_playground.go b/test/util/validation/validate_playground.go index f5ec1a4..3a5cea1 100644 --- a/test/util/validation/validate_playground.go +++ b/test/util/validation/validate_playground.go @@ -46,21 +46,18 @@ func validateModelClaim(ctx context.Context, k8sClient client.Client, playground return errors.New("failed to get model") } - if playground.Spec.ModelClaim.ModelName != service.Spec.MultiModelsClaim.ModelNames[0] { - return fmt.Errorf("expected modelName %s, got %s", playground.Spec.ModelClaim.ModelName, service.Spec.MultiModelsClaim.ModelNames[0]) + if playground.Spec.ModelClaim.ModelName != service.Spec.ModelClaims.Models[0].Name { + return fmt.Errorf("expected modelName %s, got %s", playground.Spec.ModelClaim.ModelName, service.Spec.ModelClaims.Models[0].Name) } - if diff := cmp.Diff(playground.Spec.ModelClaim.InferenceFlavors, service.Spec.MultiModelsClaim.InferenceFlavors); diff != "" { - return fmt.Errorf("unexpected flavors, want %v, got %v", playground.Spec.ModelClaim.InferenceFlavors, service.Spec.MultiModelsClaim.InferenceFlavors) + if diff := cmp.Diff(playground.Spec.ModelClaim.InferenceFlavors, service.Spec.ModelClaims.InferenceFlavors); diff != "" { + return fmt.Errorf("unexpected flavors, want %v, got %v", playground.Spec.ModelClaim.InferenceFlavors, service.Spec.ModelClaims.InferenceFlavors) } - } else if playground.Spec.MultiModelsClaim != nil { - if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(playground.Spec.MultiModelsClaim.ModelNames[0]), Namespace: playground.Namespace}, &model); err != nil { + } else if playground.Spec.ModelClaims != nil { + if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(playground.Spec.ModelClaims.Models[0].Name), Namespace: playground.Namespace}, &model); err != nil { return errors.New("failed to get model") } - if diff := cmp.Diff(playground.Spec.MultiModelsClaim.ModelNames, service.Spec.MultiModelsClaim.ModelNames); diff != "" { - return fmt.Errorf("expected modelNames, want %s, got %s", playground.Spec.MultiModelsClaim.ModelNames, service.Spec.MultiModelsClaim.ModelNames) - } - if diff := cmp.Diff(playground.Spec.MultiModelsClaim.InferenceFlavors, service.Spec.MultiModelsClaim.InferenceFlavors); diff != "" { - return fmt.Errorf("unexpected flavors, want %v, got %v", playground.Spec.MultiModelsClaim.InferenceFlavors, service.Spec.MultiModelsClaim.InferenceFlavors) + if diff := cmp.Diff(*playground.Spec.ModelClaims, service.Spec.ModelClaims); diff != "" { + return fmt.Errorf("expected modelClaims, want %v, got %v", *playground.Spec.ModelClaims, service.Spec.ModelClaims) } } @@ -95,7 +92,7 @@ func ValidatePlayground(ctx context.Context, k8sClient client.Client, playground if service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Name != modelSource.MODEL_RUNNER_CONTAINER_NAME { return fmt.Errorf("container name not right, want %s, got %s", modelSource.MODEL_RUNNER_CONTAINER_NAME, service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Name) } - if diff := cmp.Diff(bkd.Command(), service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Command); diff != "" { + if diff := cmp.Diff(bkd.DefaultCommand(), service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Command); diff != "" { return errors.New("command not right") } if playground.Spec.BackendConfig != nil { diff --git a/test/util/validation/validate_service.go 
b/test/util/validation/validate_service.go index 717465a..3a83455 100644 --- a/test/util/validation/validate_service.go +++ b/test/util/validation/validate_service.go @@ -51,10 +51,9 @@ func ValidateService(ctx context.Context, k8sClient client.Client, service *infe // TODO: multi-host models := []*coreapi.OpenModel{} - modelNames := service.Spec.MultiModelsClaim.ModelNames - for _, modelName := range modelNames { + for _, mr := range service.Spec.ModelClaims.Models { model := &coreapi.OpenModel{} - if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(modelName)}, model); err != nil { + if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(mr.Name)}, model); err != nil { return errors.New("failed to get model") } models = append(models, model) diff --git a/test/util/wrapper/playground.go b/test/util/wrapper/playground.go index 5160a0c..4f7e702 100644 --- a/test/util/wrapper/playground.go +++ b/test/util/wrapper/playground.go @@ -71,23 +71,22 @@ func (w *PlaygroundWrapper) ModelClaim(modelName string, flavorNames ...string) return w } -func (w *PlaygroundWrapper) MultiModelsClaim(modelNames []string, mode coreapi.InferenceMode, flavorNames ...string) *PlaygroundWrapper { - mNames := []coreapi.ModelName{} - for _, name := range modelNames { - mNames = append(mNames, coreapi.ModelName(name)) +func (w *PlaygroundWrapper) ModelClaims(modelNames []string, roles []string, flavorNames ...string) *PlaygroundWrapper { + models := []coreapi.ModelRepresentative{} + for i, name := range modelNames { + models = append(models, coreapi.ModelRepresentative{Name: coreapi.ModelName(name), Role: (*coreapi.ModelRole)(&roles[i])}) + } + w.Spec.ModelClaims = &coreapi.ModelClaims{ + Models: models, } fNames := []coreapi.FlavorName{} for _, name := range flavorNames { fNames = append(fNames, coreapi.FlavorName(name)) } - w.Spec.MultiModelsClaim = &coreapi.MultiModelsClaim{ - InferenceMode: mode, - ModelNames: mNames, - } if len(fNames) > 0 { - w.Spec.ModelClaim.InferenceFlavors = fNames + w.Spec.ModelClaims.InferenceFlavors = fNames } return w } diff --git a/test/util/wrapper/service.go b/test/util/wrapper/service.go index 512f074..e3d4dc5 100644 --- a/test/util/wrapper/service.go +++ b/test/util/wrapper/service.go @@ -45,19 +45,22 @@ func (w *ServiceWrapper) Obj() *inferenceapi.Service { return &w.Service } -func (w *ServiceWrapper) ModelsClaim(modelNames []string, mode coreapi.InferenceMode, flavorNames []string) *ServiceWrapper { - names := []coreapi.ModelName{} - for i := range modelNames { - names = append(names, coreapi.ModelName(modelNames[i])) +func (w *ServiceWrapper) ModelClaims(modelNames []string, roles []string, flavorNames ...string) *ServiceWrapper { + models := []coreapi.ModelRepresentative{} + for i, name := range modelNames { + models = append(models, coreapi.ModelRepresentative{Name: coreapi.ModelName(name), Role: (*coreapi.ModelRole)(&roles[i])}) } - flavors := []coreapi.FlavorName{} - for i := range flavorNames { - flavors = append(flavors, coreapi.FlavorName(flavorNames[i])) + w.Spec.ModelClaims = coreapi.ModelClaims{ + Models: models, } - w.Spec.MultiModelsClaim = coreapi.MultiModelsClaim{ - ModelNames: names, - InferenceMode: mode, - InferenceFlavors: flavors, + + fNames := []coreapi.FlavorName{} + for _, name := range flavorNames { + fNames = append(fNames, coreapi.FlavorName(name)) + } + + if len(fNames) > 0 { + w.Spec.ModelClaims.InferenceFlavors = fNames } return w }
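For reference, a minimal sketch (not part of the patch) of how the reworked ModelClaims API is populated from Go for speculative decoding. It only relies on the types changed above (ModelClaims, ModelRepresentative, ModelRole); the model names are placeholders.

package main

import (
	"fmt"

	coreapi "github.com/inftyai/llmaz/api/core/v1alpha1"
)

// speculativeDecodingClaims builds a ModelClaims with one main (target) model
// and one draft model, which is the shape the webhooks above accept: exactly
// two models and a single main model. The model names are illustrative only.
func speculativeDecodingClaims() coreapi.ModelClaims {
	mainRole, draftRole := coreapi.MainRole, coreapi.DraftRole
	return coreapi.ModelClaims{
		Models: []coreapi.ModelRepresentative{
			{Name: coreapi.ModelName("llama3-405b"), Role: &mainRole},
			{Name: coreapi.ModelName("llama3-8b"), Role: &draftRole},
		},
		// InferenceFlavors is optional and omitted here.
	}
}

func main() {
	claims := speculativeDecodingClaims()
	fmt.Printf("claimed %d models, main model: %s\n", len(claims.Models), claims.Models[0].Name)
}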