Skip to content

Commit 1222399

Browse files
authored
Add functions necessary for NER (#12)
1 parent 1194331 commit 1222399

File tree

6 files changed

+52
-1
lines changed

6 files changed

+52
-1
lines changed

lib/tokenizers/encoding.ex

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@ defmodule Tokenizers.Encoding do
3535
@doc """
Returns the type IDs (segment IDs) of the given encoding.
"""
@spec get_type_ids(Encoding.t()) :: [integer()]
def get_type_ids(encoding) do
  encoding
  |> Native.get_type_ids()
  |> Shared.unwrap()
end
3737

38+
@doc """
Get special tokens mask from an encoding.
"""
@spec get_special_tokens_mask(Encoding.t()) :: [integer()]
def get_special_tokens_mask(encoding) do
  encoding
  |> Native.get_special_tokens_mask()
  |> Shared.unwrap()
end
44+
45+
@doc """
Get offsets from an encoding.
"""
@spec get_offsets(Encoding.t()) :: [{integer(), integer()}]
def get_offsets(encoding) do
  encoding
  |> Native.get_offsets()
  |> Shared.unwrap()
end
50+
3851
@doc """
3952
Truncate the encoding to the given length.
4053

lib/tokenizers/native.ex

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ defmodule Tokenizers.Native do
2020
def get_type_ids(_encoding), do: err()
2121
def get_ids(_encoding), do: err()
2222
def get_tokens(_encoding), do: err()
23+
# NIF placeholder stubs: the native implementations replace these when the
# NIF library loads. NOTE(review): presumably err/0 raises :nif_not_loaded
# (the usual Rustler pattern) — confirm against its definition in this module.
def get_special_tokens_mask(_encoding), do: err()
def get_offsets(_encoding), do: err()
2325
def get_vocab(_tokenizer, _with_added_tokens), do: err()
2426
def get_vocab_size(_tokenizer, _with_added_tokens), do: err()
2527
def id_to_token(_tokenizer, _id), do: err()

native/ex_tokenizers/src/encoding.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,20 @@ pub fn get_type_ids(encoding: ExTokenizersEncoding) -> Result<Vec<u32>, ExTokeni
4545
Ok(encoding.resource.0.get_type_ids().to_vec())
4646
}
4747

48+
#[rustler::nif]
49+
pub fn get_special_tokens_mask(
50+
encoding: ExTokenizersEncoding,
51+
) -> Result<Vec<u32>, ExTokenizersError> {
52+
Ok(encoding.resource.0.get_special_tokens_mask().to_vec())
53+
}
54+
55+
#[rustler::nif]
56+
pub fn get_offsets(
57+
encoding: ExTokenizersEncoding,
58+
) -> Result<Vec<(usize, usize)>, ExTokenizersError> {
59+
Ok(encoding.resource.0.get_offsets().to_vec())
60+
}
61+
4862
#[rustler::nif]
4963
pub fn n_tokens(encoding: ExTokenizersEncoding) -> Result<usize, ExTokenizersError> {
5064
Ok(encoding.resource.0.len())

native/ex_tokenizers/src/error.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ pub enum ExTokenizersError {
2121
Unknown(#[from] anyhow::Error),
2222
}
2323

24-
impl<'a> Encoder for ExTokenizersError {
24+
impl Encoder for ExTokenizersError {
2525
fn encode<'b>(&self, env: Env<'b>) -> Term<'b> {
2626
format!("{:?}", self).encode(env)
2727
}

native/ex_tokenizers/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ rustler::init!(
2929
get_attention_mask,
3030
get_type_ids,
3131
get_ids,
32+
get_special_tokens_mask,
33+
get_offsets,
3234
get_model,
3335
get_model_details,
3436
get_tokens,

test/tokenizers/tokenizer_test.exs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,24 @@ defmodule Tokenizers.TokenizerTest do
8080
assert decoded == text
8181
end
8282
end
83+
84+
describe "encode metadata" do
  test "can return special tokens mask", %{tokenizer: tokenizer} do
    text = ["This is a test", "And so is this"]
    {:ok, encodings} = Tokenizer.encode(tokenizer, text)

    mask = Enum.map(encodings, &Encoding.get_special_tokens_mask/1)

    # Leading/trailing 1s mark the special tokens added around each sequence.
    assert mask == [[1, 0, 0, 0, 0, 1], [1, 0, 0, 0, 0, 1]]
  end

  test "can return offsets", %{tokenizer: tokenizer} do
    text = ["This is a test", "And so is this"]
    {:ok, encodings} = Tokenizer.encode(tokenizer, text)

    offsets = Enum.map(encodings, &Encoding.get_offsets/1)

    # {0, 0} pairs correspond to special tokens, which span no input text.
    assert offsets == [
             [{0, 0}, {0, 4}, {5, 7}, {8, 9}, {10, 14}, {0, 0}],
             [{0, 0}, {0, 3}, {4, 6}, {7, 9}, {10, 14}, {0, 0}]
           ]
  end
end
83103
end

0 commit comments

Comments (0)