Skip to content

Commit 341939a

Browse files
committed
feat(functions): add jaro_winkler string similarity function
1 parent ab08029 commit 341939a

File tree

4 files changed

+263
-1
lines changed

4 files changed

+263
-1
lines changed

src/query/functions/src/scalars/other.rs

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ use databend_common_expression::types::TimestampType;
4848
use databend_common_expression::types::ValueType;
4949
use databend_common_expression::vectorize_with_builder_1_arg;
5050
use databend_common_expression::vectorize_with_builder_2_arg;
51+
use databend_common_expression::vectorize_2_arg;
5152
use databend_common_expression::Column;
5253
use databend_common_expression::Domain;
5354
use databend_common_expression::EvalContext;
@@ -241,6 +242,17 @@ pub fn register(registry: &mut FunctionRegistry) {
241242
Value::Column(col)
242243
},
243244
);
245+
246+
registry
247+
.register_passthrough_nullable_2_arg::<StringType, StringType, Float64Type, _, _>(
248+
"jaro_winkler",
249+
|_, _, _| FunctionDomain::Full,
250+
vectorize_2_arg::<StringType, StringType, Float64Type>(
251+
|s1, s2, _ctx| {
252+
jaro_winkler::jaro_winkler(s1, s2).into()
253+
},
254+
),
255+
);
244256
}
245257

246258
fn register_inet_aton(registry: &mut FunctionRegistry) {
@@ -486,3 +498,187 @@ pub fn compute_grouping(cols: &[usize], grouping_id: u32) -> u32 {
486498
}
487499
grouping
488500
}
501+
//
502+
// this implementation comes from https://github.com/joshuaclayton/jaro_winkler
503+
pub(crate) mod jaro_winkler {
504+
#![deny(missing_docs)]
505+
506+
//! `jaro_winkler` is a crate for calculating Jaro-Winkler distance of two strings.
507+
//!
508+
//! # Examples
509+
//!
510+
//! ```
511+
//! use jaro_winkler::jaro_winkler;
512+
//!
513+
//! assert_eq!(jaro_winkler("martha", "marhta"), 0.9611111111111111);
514+
//! assert_eq!(jaro_winkler("", "words"), 0.0);
515+
//! assert_eq!(jaro_winkler("same", "same"), 1.0);
516+
//! ```
517+
518+
enum DataWrapper {
519+
Vec(Vec<bool>),
520+
Bitwise(u128),
521+
}
522+
523+
impl DataWrapper {
524+
fn build(len: usize) -> Self {
525+
if len <= 128 {
526+
DataWrapper::Bitwise(0)
527+
} else {
528+
let mut internal = Vec::with_capacity(len);
529+
internal.extend(std::iter::repeat(false).take(len));
530+
DataWrapper::Vec(internal)
531+
}
532+
}
533+
534+
fn get(&self, idx: usize) -> bool {
535+
match self {
536+
DataWrapper::Vec(v) => v[idx],
537+
DataWrapper::Bitwise(v1) => (v1 >> idx) & 1 == 1,
538+
}
539+
}
540+
541+
fn set_true(&mut self, idx: usize) {
542+
match self {
543+
DataWrapper::Vec(v) => v[idx] = true,
544+
DataWrapper::Bitwise(v1) => *v1 |= 1 << idx,
545+
}
546+
}
547+
}
548+
549+
/// Calculates the Jaro-Winkler distance of two strings.
550+
///
551+
/// The return value is between 0.0 and 1.0, where 1.0 means the strings are equal.
552+
pub fn jaro_winkler(left_: &str, right_: &str) -> f64 {
553+
let llen = left_.len();
554+
let rlen = right_.len();
555+
556+
let (left, right, s1_len, s2_len) = if llen < rlen {
557+
(right_, left_, rlen, llen)
558+
} else {
559+
(left_, right_, llen, rlen)
560+
};
561+
562+
match (s1_len, s2_len) {
563+
(0, 0) => return 1.0,
564+
(0, _) | (_, 0) => return 0.0,
565+
(_, _) => (),
566+
}
567+
568+
if left == right {
569+
return 1.0;
570+
}
571+
572+
let range = matching_distance(s1_len, s2_len);
573+
let mut s1m = DataWrapper::build(s1_len);
574+
let mut s2m = DataWrapper::build(s2_len);
575+
let mut matching: f64 = 0.0;
576+
let mut transpositions: f64 = 0.0;
577+
let left_as_bytes = left.as_bytes();
578+
let right_as_bytes = right.as_bytes();
579+
580+
for i in 0..s2_len {
581+
let mut j = (i as isize - range as isize).max(0) as usize;
582+
let l = (i + range + 1).min(s1_len);
583+
while j < l {
584+
if right_as_bytes[i] == left_as_bytes[j] && !s1m.get(j) {
585+
s1m.set_true(j);
586+
s2m.set_true(i);
587+
matching += 1.0;
588+
break;
589+
}
590+
591+
j += 1;
592+
}
593+
}
594+
595+
if matching == 0.0 {
596+
return 0.0;
597+
}
598+
599+
let mut l = 0;
600+
601+
for i in 0..s2_len - 1 {
602+
if s2m.get(i) {
603+
let mut j = l;
604+
605+
while j < s1_len {
606+
if s1m.get(j) {
607+
l = j + 1;
608+
break;
609+
}
610+
611+
j += 1;
612+
}
613+
614+
if right_as_bytes[i] != left_as_bytes[j] {
615+
transpositions += 1.0;
616+
}
617+
}
618+
}
619+
transpositions = (transpositions / 2.0).ceil();
620+
621+
let jaro = (matching / (s1_len as f64)
622+
+ matching / (s2_len as f64)
623+
+ (matching - transpositions) / matching)
624+
/ 3.0;
625+
626+
let prefix_length = left_as_bytes
627+
.iter()
628+
.zip(right_as_bytes)
629+
.take(4)
630+
.take_while(|(l, r)| l == r)
631+
.count() as f64;
632+
633+
jaro + prefix_length * 0.1 * (1.0 - jaro)
634+
}
635+
636+
fn matching_distance(s1_len: usize, s2_len: usize) -> usize {
637+
let max = s1_len.max(s2_len) as f32;
638+
((max / 2.0).floor() - 1.0) as usize
639+
}
640+
641+
#[cfg(test)]
642+
mod tests {
643+
use super::*;
644+
645+
#[test]
646+
fn different_is_zero() {
647+
assert_eq!(jaro_winkler("foo", "bar"), 0.0);
648+
}
649+
650+
#[test]
651+
fn same_is_one() {
652+
assert_eq!(jaro_winkler("foo", "foo"), 1.0);
653+
assert_eq!(jaro_winkler("", ""), 1.0);
654+
}
655+
656+
#[test]
657+
fn test_hello() {
658+
assert_eq!(jaro_winkler("hell", "hello"), 0.96);
659+
}
660+
661+
macro_rules! assert_within {
662+
($x:expr, $y:expr, delta=$d:expr) => {
663+
assert!(($x - $y).abs() <= $d)
664+
};
665+
}
666+
667+
#[test]
668+
fn test_boundary() {
669+
let long_value = "test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s Doc-tests jaro running 0 tests test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s";
670+
let longer_value = "test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s Doc-tests jaro running 0 tests test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s";
671+
let result = jaro_winkler(long_value, longer_value);
672+
assert_within!(result, 0.82, delta = 0.01);
673+
}
674+
675+
#[test]
676+
fn test_close_to_boundary() {
677+
let long_value = "test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s Doc-tests jaro running 0 tests test";
678+
assert_eq!(long_value.len(), 129);
679+
let longer_value = "test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured;test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s Doc-tests jaro running 0 tests test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s";
680+
let result = jaro_winkler(long_value, longer_value);
681+
assert_within!(result, 0.8, delta = 0.001);
682+
}
683+
}
684+
}

src/query/functions/src/scalars/string.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ pub fn register(registry: &mut FunctionRegistry) {
776776
output.commit_row();
777777
},
778778
),
779-
)
779+
);
780780
}
781781

782782
pub(crate) mod soundex {

src/query/functions/tests/it/scalars/testdata/function_list.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2288,6 +2288,8 @@ Functions overloads:
22882288
1 is_string(Variant NULL) :: Boolean NULL
22892289
0 is_true(Boolean) :: Boolean
22902290
1 is_true(Boolean NULL) :: Boolean
2291+
0 jaro_winkler(String, String) :: Float64
2292+
1 jaro_winkler(String NULL, String NULL) :: Float64 NULL
22912293
0 jq FACTORY
22922294
0 json_array FACTORY
22932295
0 json_array_distinct(Variant) :: Variant
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
query T
2+
SELECT jaro_winkler(NULL, 'hello')
3+
----
4+
NULL
5+
6+
query T
7+
SELECT jaro_winkler('hello', NULL)
8+
----
9+
NULL
10+
11+
query T
12+
SELECT jaro_winkler(NULL, NULL)
13+
----
14+
NULL
15+
16+
query T
17+
SELECT jaro_winkler('', '')
18+
----
19+
1.0
20+
21+
query T
22+
SELECT jaro_winkler('hello', 'hello')
23+
----
24+
1.0
25+
26+
query T
27+
SELECT jaro_winkler('hello', 'helo')
28+
----
29+
0.9533333333333333
30+
31+
query T
32+
SELECT jaro_winkler('martha', 'marhta')
33+
----
34+
0.9611111111111111
35+
36+
query T
37+
SELECT jaro_winkler('你好', '你好啊')
38+
----
39+
0.9333333333333333
40+
41+
query T
42+
SELECT jaro_winkler('🦀hello', '🦀helo')
43+
----
44+
0.9777777777777777
45+
46+
query T
47+
SELECT jaro_winkler('dixon', 'dicksonx')
48+
----
49+
0.8133333333333332
50+
51+
query T
52+
SELECT jaro_winkler('duane', 'dwayne')
53+
----
54+
0.8400000000000001
55+
56+
query T
57+
select jaro_winkler('asdf', 'as x c f');
58+
----
59+
0.6592592592592592
60+
61+
query T
62+
SELECT jaro_winkler('', 'hello')
63+
----
64+
0.0

0 commit comments

Comments
 (0)