Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,11 @@ Minja supports the following subset of the [Jinja2/3 template syntax](https://ji
- Full expression syntax
- Statements `{{% … %}}`, variable sections `{{ … }}`, and comments `{# … #}` with pre/post space elision `{%- … -%}` / `{{- … -}}` / `{#- … -#}`
- `if` / `elif` / `else` / `endif`
- `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring
- `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring)
- `break`, `continue` (aka [loop controls extensions](https://github.com/google/minja/pull/39))
- `set` w/ namespaces & destructuring
- `macro` / `endmacro`
- `call` / `endcall` - for calling macro (w/ macro arguments and `caller()` syntax) and passing a macro to another macro (w/o passing arguments back to the call block)
- `filter` / `endfilter`
- Extensible filters collection: `count`, `dictsort`, `equalto`, `e` / `escape`, `items`, `join`, `joiner`, `namespace`, `raise_exception`, `range`, `reject` / `rejectattr` / `select` / `selectattr`, `tojson`, `trim`

Expand Down
86 changes: 77 additions & 9 deletions include/minja/minja.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };

class TemplateToken {
public:
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue };
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue, Call, EndCall };

static std::string typeToString(Type t) {
switch (t) {
Expand All @@ -729,6 +729,8 @@ class TemplateToken {
case Type::EndGeneration: return "endgeneration";
case Type::Break: return "break";
case Type::Continue: return "continue";
case Type::Call: return "call";
case Type::EndCall: return "endcall";
}
return "Unknown";
}
Expand Down Expand Up @@ -846,6 +848,17 @@ struct LoopControlTemplateToken : public TemplateToken {
LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {}
};

struct CallTemplateToken : public TemplateToken {
std::shared_ptr<Expression> expr;
CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e)
: TemplateToken(Type::Call, loc, pre, post), expr(std::move(e)) {}
};

struct EndCallTemplateToken : public TemplateToken {
EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post)
: TemplateToken(Type::EndCall, loc, pre, post) {}
};

class TemplateNode {
Location location_;
protected:
Expand Down Expand Up @@ -1050,31 +1063,36 @@ class MacroNode : public TemplateNode {
void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
if (!name) throw std::runtime_error("MacroNode.name is null");
if (!body) throw std::runtime_error("MacroNode.body is null");
auto callable = Value::callable([&](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
auto call_context = macro_context;
auto callable = Value::callable([this, macro_context](const std::shared_ptr<Context> & call_context, ArgumentsValue & args) {
auto execution_context = Context::make(Value::object(), macro_context);

if (call_context->contains("caller")) {
execution_context->set("caller", call_context->get("caller"));
}

std::vector<bool> param_set(params.size(), false);
for (size_t i = 0, n = args.args.size(); i < n; i++) {
auto & arg = args.args[i];
if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name());
param_set[i] = true;
auto & param_name = params[i].first;
call_context->set(param_name, arg);
execution_context->set(param_name, arg);
}
for (auto & [arg_name, value] : args.kwargs) {
auto it = named_param_positions.find(arg_name);
if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name);

call_context->set(arg_name, value);
execution_context->set(arg_name, value);
param_set[it->second] = true;
}
// Set default values for parameters that were not passed
for (size_t i = 0, n = params.size(); i < n; i++) {
if (!param_set[i] && params[i].second != nullptr) {
auto val = params[i].second->evaluate(context);
call_context->set(params[i].first, val);
auto val = params[i].second->evaluate(call_context);
execution_context->set(params[i].first, val);
}
}
return body->render(call_context);
return body->render(execution_context);
});
macro_context->set(name->get_name(), callable);
}
Expand Down Expand Up @@ -1611,6 +1629,40 @@ class CallExpr : public Expression {
}
};

class CallNode : public TemplateNode {
std::shared_ptr<Expression> expr;
std::shared_ptr<TemplateNode> body;

public:
CallNode(const Location & loc, std::shared_ptr<Expression> && e, std::shared_ptr<TemplateNode> && b)
: TemplateNode(loc), expr(std::move(e)), body(std::move(b)) {}

void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
if (!expr) throw std::runtime_error("CallNode.expr is null");
if (!body) throw std::runtime_error("CallNode.body is null");

auto caller = Value::callable([this, context](const std::shared_ptr<Context> &, ArgumentsValue &) -> Value {
return Value(body->render(context));
});

context->set("caller", caller);

auto call_expr = dynamic_cast<CallExpr*>(expr.get());
if (!call_expr) {
throw std::runtime_error("Invalid call block syntax - expected function call");
}

Value function = call_expr->object->evaluate(context);
if (!function.is_callable()) {
throw std::runtime_error("Call target must be callable: " + function.dump());
}
ArgumentsValue args = call_expr->args.evaluate(context);

Value result = function.call(context, args);
out << result.to_str();
}
};

class FilterExpr : public Expression {
std::vector<std::shared_ptr<Expression>> parts;
public:
Expand Down Expand Up @@ -2320,7 +2372,7 @@ class Parser {
static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
static std::regex expr_open_regex(R"(\{\{([-~])?)");
static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue|call|endcall)\b)");
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
static std::regex block_close_regex(R"(\s*([-~])?%\})");
Expand Down Expand Up @@ -2443,6 +2495,15 @@ class Parser {
} else if (keyword == "endmacro") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<EndMacroTemplateToken>(location, pre_space, post_space));
} else if (keyword == "call") {
auto expr = parseExpression();
if (!expr) throw std::runtime_error("Expected expression in call block");

auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<CallTemplateToken>(location, pre_space, post_space, std::move(expr)));
} else if (keyword == "endcall") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<EndCallTemplateToken>(location, pre_space, post_space));
} else if (keyword == "filter") {
auto filter = parseExpression();
if (!filter) throw std::runtime_error("Expected expression in filter block");
Expand Down Expand Up @@ -2575,6 +2636,12 @@ class Parser {
throw unterminated(**start);
}
children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
} else if (auto call_token = dynamic_cast<CallTemplateToken*>(token.get())) {
auto body = parseTemplate(begin, it, end);
if (it == end || (*(it++))->type != TemplateToken::Type::EndCall) {
throw unterminated(**start);
}
children.emplace_back(std::make_shared<CallNode>(token->location, std::move(call_token->expr), std::move(body)));
} else if (auto filter_token = dynamic_cast<FilterTemplateToken*>(token.get())) {
auto body = parseTemplate(begin, it, end);
if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) {
Expand All @@ -2588,6 +2655,7 @@ class Parser {
} else if (dynamic_cast<EndForTemplateToken*>(token.get())
|| dynamic_cast<EndSetTemplateToken*>(token.get())
|| dynamic_cast<EndMacroTemplateToken*>(token.get())
|| dynamic_cast<EndCallTemplateToken*>(token.get())
|| dynamic_cast<EndFilterTemplateToken*>(token.get())
|| dynamic_cast<EndIfTemplateToken*>(token.get())
|| dynamic_cast<ElseTemplateToken*>(token.get())
Expand Down
2 changes: 1 addition & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ set(MODEL_IDS
OnlyCheeini/greesychat-turbo
onnx-community/DeepSeek-R1-Distill-Qwen-1.5B-ONNX
open-thoughts/OpenThinker-7B
openbmb/MiniCPM3-4B
openchat/openchat-3.5-0106
Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2
OrionStarAI/Orion-14B-Chat
Expand Down Expand Up @@ -261,7 +262,6 @@ set(MODEL_IDS
prithivMLmods/Qwen2.5-7B-DeepSeek-R1-1M
prithivMLmods/QwQ-Math-IO-500M
prithivMLmods/Triangulum-v2-10B
qingy2024/Falcon3-2x10B-MoE-Instruct
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this removal accidental?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Qwen/QVQ-72B-Preview
Qwen/Qwen1.5-7B-Chat
Qwen/Qwen2-7B-Instruct
Expand Down
56 changes: 56 additions & 0 deletions tests/test-syntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,54 @@ TEST(SyntaxTest, SimpleCases) {
{%- endmacro -%}
{{- foo() }} {{ foo() -}})", {}, {}));

EXPECT_EQ(
"x,x",
render(R"(
{%- macro test() -%}{{ caller() }},{{ caller() }}{%- endmacro -%}
{%- call test() -%}x{%- endcall -%}
)", {}, {}));

EXPECT_EQ(
"Outer[Inner(X)]",
render(R"(
{%- macro outer() -%}Outer[{{ caller() }}]{%- endmacro -%}
{%- macro inner() -%}Inner({{ caller() }}){%- endmacro -%}
{%- call outer() -%}{%- call inner() -%}X{%- endcall -%}{%- endcall -%}
)", {}, {}));

EXPECT_EQ(
"<ul><li>A</li><li>B</li></ul>",
render(R"(
{%- macro test(prefix, suffix) -%}{{ prefix }}{{ caller() }}{{ suffix }}{%- endmacro -%}
{%- set items = ["a", "b"] -%}
{%- call test("<ul>", "</ul>") -%}
{%- for item in items -%}
<li>{{ item | upper }}</li>
{%- endfor -%}
{%- endcall -%}
)", {}, {}));

EXPECT_EQ(
"\\n\\nclass A:\\n b: 1\\n c: 2\\n",
render(R"(
{%- macro recursive(obj) -%}
{%- set ns = namespace(content = caller()) -%}
{%- for key, value in obj.items() %}
{%- if value is mapping %}
{%- call recursive(value) -%}
{{ '\\n\\nclass ' + key.title() + ':\\n' }}
{%- endcall -%}
{%- else -%}
{%- set ns.content = ns.content + ' ' + key + ': ' + value + '\\n' -%}
{%- endif -%}
{%- endfor -%}
{{ ns.content }}
{%- endmacro -%}

{%- call recursive({"a": {"b": "1", "c": "2"}}) -%}
{%- endcall -%}
)", {}, {}));

if (!getenv("USE_JINJA2")) {
EXPECT_EQ(
"Foo",
Expand Down Expand Up @@ -576,6 +624,8 @@ TEST(SyntaxTest, SimpleCases) {
EXPECT_THAT([]() { render("{% elif 1 %}", {}, {}); }, ThrowsWithSubstr("Unexpected elif"));
EXPECT_THAT([]() { render("{% endfor %}", {}, {}); }, ThrowsWithSubstr("Unexpected endfor"));
EXPECT_THAT([]() { render("{% endfilter %}", {}, {}); }, ThrowsWithSubstr("Unexpected endfilter"));
EXPECT_THAT([]() { render("{% endmacro %}", {}, {}); }, ThrowsWithSubstr("Unexpected endmacro"));
EXPECT_THAT([]() { render("{% endcall %}", {}, {}); }, ThrowsWithSubstr("Unexpected endcall"));

EXPECT_THAT([]() { render("{% if 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated if"));
EXPECT_THAT([]() { render("{% for x in 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated for"));
Expand All @@ -584,6 +634,12 @@ TEST(SyntaxTest, SimpleCases) {
EXPECT_THAT([]() { render("{% if 1 %}{% else %}{% elif 1 %}{% endif %}", {}, {}); }, ThrowsWithSubstr("Unterminated if"));
EXPECT_THAT([]() { render("{% filter trim %}", {}, {}); }, ThrowsWithSubstr("Unterminated filter"));
EXPECT_THAT([]() { render("{# ", {}, {}); }, ThrowsWithSubstr("Missing end of comment tag"));
EXPECT_THAT([]() { render("{% macro test() %}", {}, {}); }, ThrowsWithSubstr("Unterminated macro"));
EXPECT_THAT([]() { render("{% call test %}", {}, {}); }, ThrowsWithSubstr("Unterminated call"));

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please also add an unterminated call test:

EXPECT_THAT([]() { render("{%- call test -%}", {}, {}); }, ThrowsWithSubstr("Missing end of call tag"));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests for call and macro added

EXPECT_THAT([]() {
render("{%- macro test() -%}content{%- endmacro -%}{%- call test -%}caller_content{%- endcall -%}", {}, {});
}, ThrowsWithSubstr("Invalid call block syntax - expected function call"));
}

EXPECT_EQ(
Expand Down
Loading