Skip to content

Commit 70b5c32

Browse files
authored
Merge pull request #113 from BillFarber/task/refactorVectors
MLE-12345 - Refactors vector utility functions
2 parents 6cc0c09 + 56d4f28 commit 70b5c32

File tree

2 files changed

+39
-41
lines changed

2 files changed

+39
-41
lines changed

marklogic/vectors.py

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,43 @@
1+
"""
2+
Supports encoding and decoding vectors using the same approach as the vec:base64-encode and vec:base64-decode
3+
functions supported by the MarkLogic server.
4+
"""
5+
16
import base64
27
import struct
38
from typing import List
49

510

6-
class VectorUtil:
11+
def base64_encode(vector: List[float]) -> str:
    """
    Encodes a list of floats as a base64 string compatible with MarkLogic's vec:base64-encode.

    The packed buffer is little-endian: a version (int32, always 0), the number of
    dimensions (int32), then one 32-bit float per element of ``vector``.

    :param vector: the float values to encode; an empty list yields a header-only buffer.
    :return: the base64 text (ASCII) of the packed buffer.
    :raises struct.error: if an element is not a number that can be packed as a float.
    """
    dimensions = len(vector)
    # version (int32, 0) + dimensions (int32) + floats (little-endian).
    # A repeat count (e.g. "<ii3f") packs everything in one call rather than
    # building a format string with one "f" per element and joining two buffers.
    buffer = struct.pack(f"<ii{dimensions}f", 0, dimensions, *vector)
    return base64.b64encode(buffer).decode("ascii")
1121

12-
@staticmethod
13-
def base64_encode(vector: List[float]) -> str:
14-
"""
15-
Encodes a list of floats as a base64 string compatible with MarkLogic's vec:base64-encode.
16-
"""
17-
dimensions = len(vector)
18-
# version (int32, 0) + dimensions (int32) + floats (little-endian)
19-
buffer = struct.pack("<ii", 0, dimensions) + struct.pack(
20-
"<" + "f" * dimensions, *vector
21-
)
22-
return base64.b64encode(buffer).decode("ascii")
2322

24-
@staticmethod
25-
def base64_decode(encoded_vector: str) -> List[float]:
26-
"""
27-
Decodes a base64 string to a list of floats compatible with MarkLogic's vec:base64-decode.
28-
"""
29-
buffer = base64.b64decode(encoded_vector)
30-
if len(buffer) < 8:
31-
raise ValueError(
32-
"Buffer is too short to contain version and dimensions."
33-
)
34-
version, dimensions = struct.unpack("<ii", buffer[:8])
35-
if version != 0:
36-
raise ValueError(f"Unsupported vector version: {version}")
37-
expected_length = 8 + 4 * dimensions
38-
if len(buffer) < expected_length:
39-
raise ValueError(
40-
f"Buffer is too short for the specified dimensions: expected {expected_length}, got {len(buffer)}"
41-
)
42-
floats = struct.unpack(
43-
"<" + "f" * dimensions, buffer[8 : 8 + 4 * dimensions]
23+
def base64_decode(encoded_vector: str) -> List[float]:
    """
    Decodes a base64 string to a list of floats compatible with MarkLogic's vec:base64-decode.

    The decoded buffer is read as little-endian: a version (int32, must be 0), the
    number of dimensions (int32), then that many 32-bit floats. Any bytes after the
    declared floats are ignored.

    :param encoded_vector: base64 text of an encoded vector.
    :return: the decoded float values, in order.
    :raises ValueError: if the buffer is shorter than the 8-byte header, the version
        is not 0, the dimensions value is negative, or the buffer holds fewer floats
        than the header declares.
    """
    buffer = base64.b64decode(encoded_vector)
    if len(buffer) < 8:
        raise ValueError(
            "Buffer is too short to contain version and dimensions."
        )
    version, dimensions = struct.unpack_from("<ii", buffer)
    if version != 0:
        raise ValueError(f"Unsupported vector version: {version}")
    # dimensions is a signed int32; a negative value would shrink expected_length
    # below the header size and slip past the length check, yielding an empty list
    # for a corrupt header instead of an error.
    if dimensions < 0:
        raise ValueError(f"Invalid vector dimensions: {dimensions}")
    expected_length = 8 + 4 * dimensions
    if len(buffer) < expected_length:
        raise ValueError(
            f"Buffer is too short for the specified dimensions: expected {expected_length}, got {len(buffer)}"
        )
    # Read the floats in place, starting right after the 8-byte header.
    return list(struct.unpack_from(f"<{dimensions}f", buffer, 8))

tests/test_vectors.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
import ast
3-
from marklogic.vectors import VectorUtil
3+
from marklogic.vectors import base64_encode, base64_decode
44
from marklogic import Client
55

66
VECTOR = [3.14, 1.59, 2.65]
@@ -9,17 +9,17 @@
99

1010

1111
def test_encode_and_decode_with_python():
12-
encoded = VectorUtil.base64_encode(VECTOR)
12+
encoded = base64_encode(VECTOR)
1313
assert encoded == EXPECTED_BASE64
1414

15-
decoded = VectorUtil.base64_decode(encoded)
15+
decoded = base64_decode(encoded)
1616
assert len(decoded) == len(VECTOR)
1717
for a, b in zip(decoded, VECTOR):
1818
assert abs(a - b) < ACCEPTABLE_DELTA
1919

2020

2121
def test_decode_known_base64():
22-
decoded = VectorUtil.base64_decode(EXPECTED_BASE64)
22+
decoded = base64_decode(EXPECTED_BASE64)
2323
assert len(decoded) == len(VECTOR)
2424
for a, b in zip(decoded, VECTOR):
2525
assert abs(a - b) < ACCEPTABLE_DELTA
@@ -29,7 +29,7 @@ def test_encode_and_decode_with_server(client: Client):
2929
"""
3030
Encode a vector in Python, decode it on the MarkLogic server, and check the result.
3131
"""
32-
encoded = VectorUtil.base64_encode(VECTOR)
32+
encoded = base64_encode(VECTOR)
3333
assert encoded == EXPECTED_BASE64
3434

3535
# Use MarkLogic's eval endpoint to decode the vector on the server
@@ -49,7 +49,7 @@ def test_encode_with_server_and_decode_with_python(client: Client):
4949
encoded = client.eval(xquery=xquery)[0]
5050
assert encoded == EXPECTED_BASE64
5151

52-
decoded = VectorUtil.base64_decode(encoded)
52+
decoded = base64_decode(encoded)
5353
assert len(decoded) == len(VECTOR)
5454
for a, b in zip(decoded, VECTOR):
5555
assert math.isclose(a, b, abs_tol=ACCEPTABLE_DELTA)

0 commit comments

Comments
 (0)