From 0cf1858c462e6873403d5e898bb2274d4645525f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=ADn=20Santill=C3=A1n=20Cooper?= Date: Mon, 10 Feb 2025 15:46:32 -0300 Subject: [PATCH] Apply linter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Martín Santillán Cooper --- ...luate_granite_guardian_assistant_message_risks.py | 2 +- examples/evaluate_granite_guardian_custom_risks.py | 2 +- .../evaluate_granite_guardian_user_message_risks.py | 2 +- prepare/metrics/granite_guardian.py | 2 +- src/unitxt/metrics.py | 12 ++++-------- 5 files changed, 8 insertions(+), 12 deletions(-) diff --git a/examples/evaluate_granite_guardian_assistant_message_risks.py b/examples/evaluate_granite_guardian_assistant_message_risks.py index 32dc38e00..4dfe12c60 100644 --- a/examples/evaluate_granite_guardian_assistant_message_risks.py +++ b/examples/evaluate_granite_guardian_assistant_message_risks.py @@ -3,7 +3,7 @@ from unitxt import evaluate from unitxt.api import create_dataset from unitxt.blocks import Task -from unitxt.metrics import GraniteGuardianAssistantRisk, RiskType +from unitxt.metrics import GraniteGuardianAssistantRisk from unitxt.templates import NullTemplate print("Assistant response risks") diff --git a/examples/evaluate_granite_guardian_custom_risks.py b/examples/evaluate_granite_guardian_custom_risks.py index dea381b0c..e4c317c02 100644 --- a/examples/evaluate_granite_guardian_custom_risks.py +++ b/examples/evaluate_granite_guardian_custom_risks.py @@ -1,7 +1,7 @@ from unitxt import evaluate from unitxt.api import create_dataset from unitxt.blocks import Task -from unitxt.metrics import GraniteGuardianCustomRisk, RiskType +from unitxt.metrics import GraniteGuardianCustomRisk from unitxt.templates import NullTemplate print("Bring your own risk") diff --git a/examples/evaluate_granite_guardian_user_message_risks.py b/examples/evaluate_granite_guardian_user_message_risks.py index 1a3e901fc..1de22dc2e 100644 --- a/examples/evaluate_granite_guardian_user_message_risks.py +++ b/examples/evaluate_granite_guardian_user_message_risks.py @@ -3,7 +3,7 @@ from unitxt import evaluate from unitxt.api import create_dataset from unitxt.blocks import Task -from unitxt.metrics import GraniteGuardianUserRisk, RiskType +from unitxt.metrics import GraniteGuardianUserRisk from unitxt.templates import NullTemplate print("User prompt risks") diff --git a/prepare/metrics/granite_guardian.py b/prepare/metrics/granite_guardian.py index d7d5f1001..7a83583a5 100644 --- a/prepare/metrics/granite_guardian.py +++ b/prepare/metrics/granite_guardian.py @@ -1,5 +1,5 @@ from unitxt import add_to_catalog -from unitxt.metrics import GraniteGuardianBase, RISK_TYPE_TO_CLASS +from unitxt.metrics import RISK_TYPE_TO_CLASS, GraniteGuardianBase for risk_type, risk_names in GraniteGuardianBase.available_risks.items(): for risk_name in risk_names: diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index 0eaf8faf3..dd1f98b55 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -5868,6 +5868,7 @@ class GraniteGuardianBase(InstanceMetric): main_score = None reduction_map = {} wml_model_name: str = "ibm/granite-guardian-3-8b" + hf_model_name: str = "ibm-granite/granite-guardian-3.1-8b" wml_params = { "decoding_method": "greedy", @@ -5880,8 +5881,6 @@ class GraniteGuardianBase(InstanceMetric): }, } - hf_model_name: str = "ibm-granite/granite-guardian-3.1-8b" - safe_token = "No" unsafe_token = "Yes" @@ -5936,8 +5935,6 @@ def process_input_fields(self, task_data): @classmethod def get_available_risk_names(cls): - print(cls.risk_type) - print(cls.available_risks) return cls.available_risks[cls.risk_type] def set_main_score(self): @@ -5974,7 +5971,6 @@ def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> di messages = self.process_input_fields(task_data) prompt = self.get_prompt(messages) result = self.inference_engine.infer_log_probs([{"source": prompt}]) - print(' '.join([r['text'] for r in result[0]])) generated_tokens_list = result[0] label, prob_of_risk = self.parse_output(generated_tokens_list) confidence_score = ( @@ -6138,8 +6134,8 @@ class GraniteGuardianCustomRisk(GraniteGuardianBase): def verify(self): super().verify() - assert self.risk_type != None, UnitxtError("In a custom risk, risk_type must be defined") - + assert self.risk_type is not None, UnitxtError("In a custom risk, risk_type must be defined") + def verify_granite_guardian_config(self, task_data): # even though this is a custom risks, we will limit the # message roles to be a subset of the roles Granite Guardian @@ -6176,7 +6172,7 @@ def process_input_fields(self, task_data): RISK_TYPE_TO_CLASS: Dict[RiskType, GraniteGuardianBase] = { RiskType.USER_MESSAGE: GraniteGuardianUserRisk, - RiskType.ASSISTANT_MESSAGE: GraniteGuardianAssistantRisk, + RiskType.ASSISTANT_MESSAGE: GraniteGuardianAssistantRisk, RiskType.RAG: GraniteGuardianRagRisk, RiskType.AGENTIC: GraniteGuardianAgenticRisk, }