Skip to content

Commit

Permalink
Add filters to prompt function (#1371)
Browse files Browse the repository at this point in the history
Allow giving custom filters to the prompt decorator

```
def reverses: str) -> str:
    return s[::-1]

@prompt(filters={ 'reverse': reverse })
def reverse_prompt(text):
    '''{{ text | reverse }}'''

prompt = reverse_prompt("Hello")

print(prompt)
>>> "olleH"
```
  • Loading branch information
derfred authored Jan 14, 2025
1 parent 3af1e5a commit 79100b2
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 21 deletions.
71 changes: 50 additions & 21 deletions outlines/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __call__(self, *args, **kwargs) -> str:
return self.template.render(**kwargs)

@classmethod
def from_str(cls, content: str):
def from_str(cls, content: str, filters: Dict[str, Callable] = {}):
"""
Create an instance of the class from a string.
Expand All @@ -53,10 +53,10 @@ def from_str(cls, content: str):
-------
An instance of the class with the provided content as a template.
"""
return cls(cls._template_from_str(content), None)
return cls(cls._template_from_str(content, filters), None)

@classmethod
def from_file(cls, path: Path):
def from_file(cls, path: Path, filters: Dict[str, Callable] = {}):
"""
Create a Prompt instance from a file containing a Jinja template.
Expand All @@ -75,10 +75,12 @@ def from_file(cls, path: Path):
"""
# We don't use a `Signature` here because it seems not feasible to infer one from a Jinja2 environment that is
# split across multiple files (since e.g. we support features like Jinja2 includes and template inheritance)
return cls(cls._template_from_file(path), None)
return cls(cls._template_from_file(path, filters), None)

@classmethod
def _template_from_str(_, content: str) -> jinja2.Template:
def _template_from_str(
_, content: str, filters: Dict[str, Callable] = {}
) -> jinja2.Template:
# Dedent, and remove extra linebreak
cleaned_template = inspect.cleandoc(content)

Expand All @@ -93,12 +95,7 @@ def _template_from_str(_, content: str) -> jinja2.Template:
# used to continue to the next line without linebreak.
cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template)

env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=jinja2.StrictUndefined,
)
env = create_jinja_env(None, filters)
env.filters["name"] = get_fn_name
env.filters["description"] = get_fn_description
env.filters["source"] = get_fn_source
Expand All @@ -109,19 +106,19 @@ def _template_from_str(_, content: str) -> jinja2.Template:
return env.from_string(cleaned_template)

@classmethod
def _template_from_file(_, path: Path) -> jinja2.Template:
def _template_from_file(
_, path: Path, filters: Dict[str, Callable] = {}
) -> jinja2.Template:
file_directory = os.path.dirname(os.path.abspath(path))
env = jinja2.Environment(
loader=jinja2.FileSystemLoader(file_directory),
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=jinja2.StrictUndefined,
)
env = create_jinja_env(jinja2.FileSystemLoader(file_directory), filters)

return env.get_template(os.path.basename(path))


def prompt(fn: Callable) -> Prompt:
def prompt(
fn: Optional[Callable] = None,
filters: Dict[str, Callable] = {},
) -> Callable:
"""Decorate a function that contains a prompt template.
This allows to define prompts in the docstring of a function and simplify their
Expand Down Expand Up @@ -152,11 +149,26 @@ def prompt(fn: Callable) -> Prompt:
...
>>> hal = ft.partial(solve_task, "HAL", "Travel to Jupiter")
Additional Jinja2 filters can be provided as keyword arguments to the decorator.
>>> def reverse(s: str) -> str:
... return s[::-1]
...
>>> @outlines.prompt(filters={ 'reverse': reverse })
... def reverse_prompt(text):
... '''{{ text | reverse }}'''
...
>>> prompt = reverse_prompt("Hello")
>>> print(prompt)
... "olleH"
Returns
-------
A `Prompt` callable class which will render the template when called.
"""
if fn is None:
return lambda fn: prompt(fn, cast(Dict[str, Callable], filters))

signature = inspect.signature(fn)

Expand All @@ -166,11 +178,28 @@ def prompt(fn: Callable) -> Prompt:
if docstring is None:
raise TypeError("Could not find a template in the function's docstring.")

template = Prompt._template_from_str(cast(str, docstring))
template = Prompt._template_from_str(cast(str, docstring), filters)

return Prompt(template, signature)


def create_jinja_env(
loader: Optional[jinja2.BaseLoader], filters: Dict[str, Callable]
) -> jinja2.Environment:
env = jinja2.Environment(
loader=loader,
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=jinja2.StrictUndefined,
)

for name, filter_fn in filters.items():
env.filters[name] = filter_fn

return env


def get_fn_name(fn: Callable):
"""Returns the name of a callable."""
if not callable(fn):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,23 @@ def args_prompt(fn):
)


def test_prompt_with_additional_filters():
def reverse(s: str) -> str:
return s[::-1]

@outlines.prompt(filters=dict(reverse=reverse))
def test_tpl(variable):
"""{{ variable | reverse }} test"""

assert list(test_tpl.signature.parameters) == ["variable"]

p = test_tpl("test")
assert p == "tset test"

p = test_tpl(variable="example")
assert p == "elpmaxe test"


@pytest.fixture
def temp_prompt_file():
test_dir = tempfile.mkdtemp()
Expand Down

0 comments on commit 79100b2

Please sign in to comment.