Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions examples/README.MD
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Text Sentiment (SVM) with Class Rebalancing — imbalanced-learn Example

**What:** A small, runnable demo for 3-class sentiment (negative / neutral / positive) using:

```
TF-IDF → RandomUnderSampler → LinearSVC
```

**Why:** Text features are sparse (TF-IDF). Oversampling methods like SMOTE target dense/continuous data;

**under-sampling** works out-of-the-box for sparse text.

---

## Files

* `examples/text_sentiment_svm_with_resampling.py` — example script (CLI)
* `imblearn/tests/test_text_sentiment_example.py` & `..._cli.py` — fast smoke and unit tests

---

## Setup

```bash
# in a virtual env
pip install -e . # install this repo
pip install datasets matplotlib pytest
# optional: keep dataset cache local
export HF_DATASETS_CACHE="$PWD/.hf_cache"
```

## Run

```bash
python examples/text_sentiment_svm_with_resampling.py --plot --max-samples 6000
```

**Outputs**

* Prints **balanced accuracy** + **classification report**
* Saves `confmat_svm_imblearn.png` when `--plot` is used

**CLI options**

```
--max-samples INT Limit training size (None = full). Default: 6000
--plot Save confusion matrix image
--output PATH Image path (default: confmat_svm_imblearn.png)
```

---

## Tests

```bash
pytest -q imblearn/tests/test_text_sentiment_example.py
pytest -q imblearn/tests/test_text_sentiment_example_cli.py
```

Tests are quick, deterministic, and skipped if `datasets` isn’t installed.

---

## Notes

* Metric focus: **balanced accuracy** & **macro-F1** (better for imbalance)
* Reproducible: fixed `random_state`, controllable `--max-samples`
* Troubleshooting: low disk? use `pip --no-cache-dir`, clear caches, keep only one env active

---
139 changes: 139 additions & 0 deletions examples/text_sentiment_svm_with_resampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
#!/usr/bin/env python
"""
Text Sentiment with LinearSVC and Class Rebalancing
===================================================

What
----
A runnable example for 3-class sentiment (negative/neutral/positive) using a
TF–IDF → RandomUnderSampler → LinearSVC pipeline.

Why
---
Text data is typically represented as **sparse** features. Popular over-sampling
methods (e.g., SMOTE) operate on dense, continuous features, so we demonstrate
**under-sampling** as a practical alternative for sparse TF–IDF.

Usage
-----
$ python examples/text_sentiment_svm_with_resampling.py --plot --max-samples 6000

Outputs
-------
- Prints balanced accuracy and a classification report.
- Saves a confusion-matrix image when `--plot` is supplied (default: confmat_svm_imblearn.png).

Dependencies
------------
pip install datasets matplotlib

Reproducibility
---------------
- `--max-samples` bounds runtime and memory.
- Fixed `random_state` for deterministic sampling.

Limitations
-----------
- Under-sampling discards majority samples (variance can increase on small data).
- SMOTE is not used because it doesn’t support sparse TF–IDF.
"""


from __future__ import annotations
import argparse
import numpy as np
import matplotlib.pyplot as plt

from imblearn.pipeline import Pipeline
from imblearn.under_sampling import RandomUnderSampler
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import (
classification_report,
balanced_accuracy_score,
ConfusionMatrixDisplay,
confusion_matrix,
)
from sklearn.svm import LinearSVC

try:
from datasets import load_dataset
except Exception as e: # pragma: no cover
raise SystemExit(
"This example requires the 'datasets' package.\n"
"Install it with:\n pip install datasets\n"
) from e


def load_tweet_eval(max_samples: int | None = 6000, random_state: int = 42):
"""Load 3-class sentiment from tweet_eval.

Returns X_train, y_train, X_test, y_test.
If max_samples is set, subsamples training data for speed.
"""
ds = load_dataset("tweet_eval", "sentiment")
# labels: 0=negative, 1=neutral, 2=positive
def xy(split):
X = [ex["text"] for ex in split]
y = np.array([ex["label"] for ex in split], dtype=int)
return X, y

X_tr, y_tr = xy(ds["train"])
X_va, y_va = xy(ds["validation"])
X_te, y_te = xy(ds["test"])

# merge train+validation for a larger training pool
X_train = X_tr + X_va
y_train = np.concatenate([y_tr, y_va])

if max_samples:
rng = np.random.default_rng(random_state)
idx = rng.choice(len(X_train), size=min(max_samples, len(X_train)), replace=False)
X_train = [X_train[i] for i in idx]
y_train = y_train[idx]
# also downsample test a bit for quick runs
idx_t = rng.choice(len(X_te), size=min(max_samples // 3 + 300, len(X_te)), replace=False)
X_test = [X_te[i] for i in idx_t]
y_test = y_te[idx_t]
else:
X_test, y_test = X_te, y_te

return X_train, y_train, X_test, y_test


def main(argv=None):
parser = argparse.ArgumentParser(
description="3-class sentiment with LinearSVC and RandomUnderSampler."
)
parser.add_argument("--max-samples", type=int, default=6000,
help="Max training samples for speed (set None for full).")
parser.add_argument("--plot", action="store_true", help="Save confusion matrix PNG.")
parser.add_argument("--output", type=str, default="confmat_svm_imblearn.png",
help="Output path for confusion matrix.")
args = parser.parse_args(argv)

X_train, y_train, X_test, y_test = load_tweet_eval(max_samples=args.max_samples)

# Note: SMOTE does not support sparse input; use an under-sampler for text
pipe = Pipeline(steps=[
("tfidf", TfidfVectorizer(min_df=2, ngram_range=(1, 2))),
("balance", RandomUnderSampler(random_state=0)),
("clf", LinearSVC()),
])

pipe.fit(X_train, y_train)
y_pred = pipe.predict(X_test)

bal_acc = balanced_accuracy_score(y_test, y_pred)
print(f"Balanced accuracy: {bal_acc:.3f}")
print(classification_report(y_test, y_pred, target_names=["negative", "neutral", "positive"]))

if args.plot:
cm = confusion_matrix(y_test, y_pred, labels=[0, 1, 2])
ConfusionMatrixDisplay(cm, display_labels=["neg", "neu", "pos"]).plot(values_format="d")
plt.tight_layout()
plt.savefig(args.output, dpi=150)
print(f"Saved confusion matrix to {args.output}")


if __name__ == "__main__":
main()
30 changes: 30 additions & 0 deletions imblearn/tests/test_text_sentiment_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
import pytest

datasets = pytest.importorskip("datasets")
from datasets import load_dataset
from imblearn.pipeline import Pipeline
from imblearn.under_sampling import RandomUnderSampler
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC


def _small_split(n_train=900, n_test=300):
ds = load_dataset("tweet_eval", "sentiment")
X = [x["text"] for x in ds["train"]][: n_train + n_test]
y = np.array([x["label"] for x in ds["train"]][: n_train + n_test], dtype=int)
return X[:n_train], y[:n_train], X[n_train:], y[n_train:]


def test_pipeline_trains_and_predicts():
Xtr, ytr, Xte, yte = _small_split()
pipe = Pipeline([
("tfidf", TfidfVectorizer(max_features=20000)),
("balance", RandomUnderSampler(random_state=0)),
("clf", LinearSVC()),
])
pipe.fit(Xtr, ytr)
pred = pipe.predict(Xte)
assert len(pred) == len(yte)
# predictions should be 0/1/2 labels
assert set(np.unique(pred)).issubset({0, 1, 2})
59 changes: 59 additions & 0 deletions imblearn/tests/test_text_sentiment_example_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import sys
import numpy as np
import pytest

datasets = pytest.importorskip("datasets")

# Import the example as a module (pytest adds repo root to sys.path)
from examples.text_sentiment_svm_with_resampling import (
main,
load_tweet_eval,
)

@pytest.mark.filterwarnings("ignore::UserWarning")
def test_loader_reproducible_small():
"""Same seed -> identical splits (reproducibility)."""
X1, y1, Xt1, Yt1 = load_tweet_eval(max_samples=800, random_state=42)
X2, y2, Xt2, Yt2 = load_tweet_eval(max_samples=800, random_state=42)
assert X1 == X2
assert np.array_equal(y1, y2)
assert Xt1 == Xt2
assert np.array_equal(Yt1, Yt2)

@pytest.mark.filterwarnings("ignore::UserWarning")
def test_smoke_predicts_labels_small():
"""End-to-end: pipeline trains and predicts on a tiny slice."""
Xtr, ytr, Xte, yte = load_tweet_eval(max_samples=800, random_state=0)
# Build the same pipeline as in the example
from imblearn.pipeline import Pipeline
from imblearn.under_sampling import RandomUnderSampler
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC

pipe = Pipeline([
("tfidf", TfidfVectorizer(min_df=2, ngram_range=(1, 2))),
("balance", RandomUnderSampler(random_state=0)),
("clf", LinearSVC()),
])
pipe.fit(Xtr, ytr)
pred = pipe.predict(Xte)
assert len(pred) == len(yte)
# Predictions must be in the expected label set {0,1,2}
assert set(np.unique(pred)).issubset({0, 1, 2})

@pytest.mark.filterwarnings("ignore::UserWarning")
def test_cli_saves_plot(tmp_path):
"""CLI: --plot should create the confusion matrix image."""
out = tmp_path / "cm.png"
main(["--plot", "--max-samples", "800", "--output", str(out)])
assert out.exists() and out.stat().st_size > 0

@pytest.mark.filterwarnings("ignore::UserWarning")
def test_cli_no_plot_no_file(tmp_path):
"""CLI: without --plot, no image should be created."""
out = tmp_path / "cm.png"
if out.exists():
os.remove(out)
main(["--max-samples", "500", "--output", str(out)])
assert not out.exists()