Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 94 additions & 3 deletions stratum/tests/adapters/test_string_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@

from stratum import StringEncoder
from stratum import set_config
from stratum.adapters.string_encoder import RustyStringEncoder, _clean_strings
from stratum.adapters.string_encoder import (
RustyStringEncoder,
_clean_strings,
_rust_supported_subset,
_prep_strings,
_prep_strings_transform,
)

set_config(rust_backend=True, debug_timing=True, num_threads=8)

Expand All @@ -18,19 +24,104 @@ def capture_std_out(capfd):
return combined_output

@pytest.mark.parametrize("analyzer", ["char", "char_wb"])
def test_string_encoder_result(analyzer, capfd):
@pytest.mark.parametrize("vectorizer", ["hashing", "tfidf"])
def test_string_encoder_result(analyzer, vectorizer, capfd):
s = pd.Series(["foo", "bar", None, "lorem ipsum dolor"]) # nulls handled upstream

# StringEncoder should point to our subclass
assert StringEncoder is RustyStringEncoder

enc = StringEncoder(vectorizer='hashing', analyzer=analyzer, ngram_range=(3,5), n_components=2)
enc = StringEncoder(vectorizer=vectorizer, analyzer=analyzer, ngram_range=(3,5), n_components=2)
Z = enc.fit_transform(s)
assert Z.shape[0] == len(s)

# Round-trip through transform to cover the transform code path
s2 = pd.Series(["foo", "baz", "lorem"])
Z2 = enc.transform(s2)
assert Z2.shape[0] == len(s2)
assert Z2.shape[1] == enc.n_components_

# Assert if rust timing appeared (verifies that rust code is executed)
assert "[rust]" in capture_std_out(capfd)


def test_string_encoder_padding():
"""n_components larger than achievable rank triggers the padding branch."""
s = pd.Series(["a", "b"])
enc = StringEncoder(vectorizer="hashing", analyzer="char",
ngram_range=(3, 5), n_components=50)
Z = enc.fit_transform(s)
assert Z.shape == (2, 50)
Z2 = enc.transform(pd.Series(["a", "c"]))
assert Z2.shape == (2, 50)


def test_string_encoder_fallback_when_rust_disabled():
"""Disabling rust_backend routes through sklearn fallback path."""
set_config(rust_backend=False)
try:
s = pd.Series(["foo", "bar", "baz", "lorem ipsum"])
enc = StringEncoder(vectorizer="hashing", analyzer="char",
ngram_range=(3, 5), n_components=2)
Z = enc.fit_transform(s)
assert Z.shape[0] == len(s)
Z2 = enc.transform(pd.Series(["foo", "bar"]))
assert Z2.shape[0] == 2
finally:
set_config(rust_backend=True)


def test_transform_without_fit_falls_back():
"""Calling transform before fit_transform should take the sklearn fallback."""
enc = StringEncoder(vectorizer="hashing", analyzer="char",
ngram_range=(3, 5), n_components=2)
# Force sklearn fit so internal sklearn state exists, but _rust_state_ stays None
set_config(rust_backend=False)
try:
enc.fit(pd.Series(["foo", "bar", "baz"]))
finally:
set_config(rust_backend=True)
# _rust_state_ is None → transform should fall back
Z = enc.transform(pd.Series(["foo", "bar"]))
assert Z.shape[0] == 2


class _Dummy:
def __init__(self, **kw):
for k, v in kw.items():
setattr(self, k, v)


def test_rust_supported_subset_rejections():
ok, _ = _rust_supported_subset(_Dummy(vectorizer="hashing", stop_words=None,
analyzer="char", ngram_range=(3, 5)))
assert ok

ok, msg = _rust_supported_subset(_Dummy(vectorizer="count"))
assert not ok and "vectorizer" in msg

ok, msg = _rust_supported_subset(_Dummy(vectorizer="tfidf", stop_words="english"))
assert not ok and "stop_words" in msg

ok, msg = _rust_supported_subset(_Dummy(vectorizer="tfidf", stop_words=None,
analyzer="word"))
assert not ok and "analyzer" in msg

ok, msg = _rust_supported_subset(_Dummy(vectorizer="tfidf", stop_words=None,
analyzer="char", ngram_range=(5, 3)))
assert not ok and "ngram_range" in msg


def test_prep_strings_roundtrip():
s = pd.Series(["foo", None, "bar"])
out = _prep_strings(s)
assert len(out) == 3
assert all(isinstance(v, str) for v in out)

out2 = _prep_strings_transform(s)
assert len(out2) == 3
assert all(isinstance(v, str) for v in out2)

def test_clean_strings_with_nan():
"""Test _clean_strings with NaN values."""
x_list = ["foo", None, np.nan, float('nan'), "bar", 123]
Expand Down
Loading