diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index b9787ca6bd..7f70772193 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -254,12 +254,23 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: eot=True, ), ] - if self.new_system_prompt is not None: + + system_column = self.column_map.get("system", "system") + system_prompt = None + + if (system_column in sample and sample[system_column] and sample[system_column].strip()): + system_prompt = sample[system_column] + + elif self.new_system_prompt is not None: + system_prompt = self.new_system_prompt + + if system_prompt is not None: messages = [ Message( - role="system", content=self.new_system_prompt, masked=True, eot=True + role="system", content=system_prompt, masked=True, eot=True ) ] + messages + return {"messages": messages}