Skip to content

Commit 18ccbd2

Browse files
Revert "Handle Jinja error messages (#292)" (#293)
This reverts commit 6fdfa9e. Co-authored-by: Anthony DePasquale <[email protected]>
1 parent d93354d commit 18ccbd2

File tree

1 file changed

+29
-36
lines changed

1 file changed

+29
-36
lines changed

Sources/Tokenizers/Tokenizer.swift

Lines changed: 29 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -787,48 +787,41 @@ public class PreTrainedTokenizer: @unchecked Sendable, Tokenizer {
787787
throw TokenizerError.missingChatTemplate
788788
}
789789

790-
let renderedTemplate: String
791-
do {
792-
let template = try compiledTemplate(for: selectedChatTemplate)
793-
794-
var context: [String: Jinja.Value] = try [
795-
"messages": .array(messages.map { try Value(any: $0) }),
796-
"add_generation_prompt": .boolean(addGenerationPrompt),
797-
]
798-
if let tools {
799-
context["tools"] = try .array(tools.map { try Value(any: $0) })
800-
}
801-
if let additionalContext {
802-
// Additional keys and values to be added to the context provided to the prompt templating engine.
803-
// For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
804-
// The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
805-
for (key, value) in additionalContext {
806-
context[key] = try Value(any: value)
807-
}
790+
let template = try compiledTemplate(for: selectedChatTemplate)
791+
var context: [String: Jinja.Value] = try [
792+
"messages": .array(messages.map { try Value(any: $0) }),
793+
"add_generation_prompt": .boolean(addGenerationPrompt),
794+
]
795+
if let tools {
796+
context["tools"] = try .array(tools.map { try Value(any: $0) })
797+
}
798+
if let additionalContext {
799+
// Additional keys and values to be added to the context provided to the prompt templating engine.
800+
// For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
801+
// The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
802+
for (key, value) in additionalContext {
803+
context[key] = try Value(any: value)
808804
}
805+
}
809806

810-
for (key, value) in tokenizerConfig.dictionary(or: [:]) {
811-
if specialTokenAttributes.contains(key.string), !value.isNull() {
812-
if let stringValue = value.string() {
813-
context[key.string] = .string(stringValue)
814-
} else if let dictionary = value.dictionary() {
815-
if let addedTokenString = addedTokenAsString(Config(dictionary)) {
816-
context[key.string] = .string(addedTokenString)
817-
}
818-
} else if let array: [String] = value.get() {
819-
context[key.string] = .array(array.map { .string($0) })
820-
} else {
821-
context[key.string] = try Value(any: value)
807+
for (key, value) in tokenizerConfig.dictionary(or: [:]) {
808+
if specialTokenAttributes.contains(key.string), !value.isNull() {
809+
if let stringValue = value.string() {
810+
context[key.string] = .string(stringValue)
811+
} else if let dictionary = value.dictionary() {
812+
if let addedTokenString = addedTokenAsString(Config(dictionary)) {
813+
context[key.string] = .string(addedTokenString)
822814
}
815+
} else if let array: [String] = value.get() {
816+
context[key.string] = .array(array.map { .string($0) })
817+
} else {
818+
context[key.string] = try Value(any: value)
823819
}
824820
}
825-
826-
renderedTemplate = try template.render(context)
827-
} catch let error as JinjaError {
828-
let description = (error as? LocalizedError)?.errorDescription ?? "\(error)"
829-
throw TokenizerError.chatTemplate(description)
830821
}
831-
var encodedTokens = encode(text: renderedTemplate, addSpecialTokens: false)
822+
823+
let rendered = try template.render(context)
824+
var encodedTokens = encode(text: rendered, addSpecialTokens: false)
832825
var maxLength = maxLength ?? encodedTokens.count
833826
maxLength = min(maxLength, tokenizerConfig.modelMaxLength.integer() ?? maxLength)
834827
if encodedTokens.count > maxLength {

0 commit comments

Comments
 (0)