diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index e2b95603379348..4cac66ac4ab54a 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -221,13 +221,12 @@ def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Op :param credentials: model credentials :return: model schema """ - # get predefined models (predefined_models) - models = self.predefined_models() - - model_map = {model.model: model for model in models} - if model in model_map: - return model_map[model] + # Try to get model schema from predefined models + for predefined_model in self.predefined_models(): + if model == predefined_model.model: + return predefined_model + # Try to get model schema from credentials if credentials: model_schema = self.get_customizable_model_schema_from_credentials(model, credentials) if model_schema: diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index f230157a34ec3f..a39eb56f718337 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -677,16 +677,17 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode :return: model schema """ - # get model schema - models = self.predefined_models() - model_map = {model.model: model for model in models} - mode = credentials.get("mode") + base_model_schema = None + for predefined_model in self.predefined_models(): + if ( + mode == "chat" and predefined_model.model == "command-light-chat" + ) or predefined_model.model == "command-light": + base_model_schema = predefined_model + break - if mode == "chat": - base_model_schema = model_map["command-light-chat"] - else: - base_model_schema = model_map["command-light"] + if not base_model_schema: + raise ValueError("Model not found") base_model_schema = cast(AIModelEntity, base_model_schema) diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 634dbc55352cf9..05872020583664 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -341,9 +341,6 @@ def remote_models(self, credentials: dict) -> list[AIModelEntity]: :param credentials: provider credentials :return: """ - # get predefined models - predefined_models = self.predefined_models() - predefined_models_map = {model.model: model for model in predefined_models} # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) @@ -359,9 +356,10 @@ def remote_models(self, credentials: dict) -> list[AIModelEntity]: base_model = model.id.split(":")[1] base_model_schema = None - for predefined_model_name, predefined_model in predefined_models_map.items(): - if predefined_model_name in base_model: + for predefined_model in self.predefined_models(): + if predefined_model.model in base_model: base_model_schema = predefined_model + break if not base_model_schema: continue @@ -1186,12 +1184,14 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode base_model = model.split(":")[1] # get model schema - models = self.predefined_models() - model_map = {model.model: model for model in models} - if base_model not in model_map: - raise ValueError(f"Base model {base_model} not found") + base_model_schema = None + for predefined_model in self.predefined_models(): + if base_model == predefined_model.model: + base_model_schema = predefined_model + break - base_model_schema = model_map[base_model] + if not base_model_schema: + raise ValueError(f"Base model {base_model} not found") base_model_schema_features = base_model_schema.features or [] base_model_schema_model_properties = base_model_schema.model_properties