Skip to content

Commit

Permalink
Merge pull request #59 from mindsdb/refactor/modules
Browse files Browse the repository at this point in the history
[refactor] Inference engines
  • Loading branch information
paxcema authored Dec 25, 2023
2 parents 42419a8 + 9c63ba0 commit 667e581
Show file tree
Hide file tree
Showing 19 changed files with 817 additions and 739 deletions.
13 changes: 9 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "type_infer"
version = "0.0.17"
version = "0.0.18"
description = "Automated type inference for Machine Learning pipelines."
authors = ["MindsDB Inc. <[email protected]>"]
license = "GPL-3.0"
Expand All @@ -15,12 +15,17 @@ numpy = "^1.15"
pandas = "^2"
dataclasses-json = "^0.6.3"
colorlog = "^6.5.0"
langid = "^1.1.6"
nltk = "^3"
toml = "^0.10.2"
psutil = "^5.9.0"
toml = "^0.10.2"

# rule based deps, part of core
langid = "^1.1.6"
nltk = "^3"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

# TODO: update once this engine is introduced
[tool.poetry.extras]
# bert = ["torch"]
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from datetime import datetime, timedelta

from type_infer.dtype import dtype
from type_infer.infer import infer_types
from type_infer.api import infer_types


class TestTypeInference(unittest.TestCase):
class TestRuleBasedTypeInference(unittest.TestCase):
def test_0_airline_sentiment(self):
df = pd.read_csv("tests/data/airline_sentiment_sample.csv")
inferred_types = infer_types(df, pct_invalid=0)
config = {'engine': 'rule_based', 'pct_invalid': 0, 'seed': 420, 'mp_cutoff': 1e4}
inferred_types = infer_types(df, config=config)

expected_types = {
'airline_sentiment': 'categorical',
Expand Down Expand Up @@ -44,6 +45,7 @@ def test_0_airline_sentiment(self):

def test_1_stack_overflow_survey(self):
df = pd.read_csv("tests/data/stack_overflow_survey_sample.csv")
config = {'engine': 'rule_based', 'pct_invalid': 0, 'seed': 420, 'mp_cutoff': 1e4}

expected_types = {
'Respondent': 'integer',
Expand All @@ -68,7 +70,7 @@ def test_1_stack_overflow_survey(self):
'Professional': 'No Information'
}

inferred_types = infer_types(df, pct_invalid=0)
inferred_types = infer_types(df, config=config)

for col in expected_types:
self.assertTrue(expected_types[col], inferred_types.dtypes[col])
Expand All @@ -90,7 +92,10 @@ def test_2_simple(self):
# manual tinkering
df['float'].iloc[-n_corrupted:] = 'random string'

inferred_types = infer_types(df, pct_invalid=100 * (n_corrupted) / n_points)
pct_invalid = 100 * (n_corrupted) / n_points
config = {'engine': 'rule_based', 'pct_invalid': pct_invalid, 'seed': 420, 'mp_cutoff': 1e4}

inferred_types = infer_types(df, config=config)
expected_types = {
'date': dtype.date,
'datetime': dtype.datetime,
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest

from type_infer.dtype import dtype
from type_infer.infer import type_check_date
from type_infer.rule_based.core import RuleBasedEngine

type_check_date = RuleBasedEngine.type_check_date


class TestDates(unittest.TestCase):
Expand Down
22 changes: 22 additions & 0 deletions tests/unit_tests/rule_based/test_infer_dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import unittest
import random

import pandas as pd
from type_infer.rule_based.core import RuleBasedEngine
from type_infer.dtype import dtype

get_column_data_type = RuleBasedEngine.get_column_data_type


class TestInferDtypes(unittest.TestCase):
    """Checks that signed numeric columns are typed by the rule-based engine."""

    def test_negative_integers(self):
        """A column of small signed integers must be inferred as `dtype.integer`."""
        # Local seeded generator: same value distribution as before, but the
        # sample is reproducible from run to run.
        rng = random.Random(420)
        data = pd.DataFrame([-rng.randint(-10, 10) for _ in range(100)], columns=['test_col'])
        engine = RuleBasedEngine()
        # Only the inferred dtype is checked; the other four outputs are unpacked
        # so a change in the returned tuple's length still fails loudly.
        dtyp, _dist, _ainfo, _warn, _info = engine.get_column_data_type(data['test_col'], data, 'test_col', 0.0)
        self.assertEqual(dtyp, dtype.integer)

    def test_negative_floats(self):
        """Signed whole-valued floats plus one fractional value must be inferred as `dtype.float`."""
        rng = random.Random(420)
        # The trailing 0.1 guarantees at least one value that is not integer-like.
        values = [float(-rng.randint(-10, 10)) for _ in range(100)] + [0.1]
        data = pd.DataFrame(values, columns=['test_col'])
        engine = RuleBasedEngine()
        dtyp, _dist, _ainfo, _warn, _info = engine.get_column_data_type(data['test_col'], data, 'test_col', 0.0)
        self.assertEqual(dtyp, dtype.float)
12 changes: 12 additions & 0 deletions tests/unit_tests/rule_based/test_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import unittest

from type_infer.rule_based.helpers import tokenize_text


class TestDates(unittest.TestCase):
    """Tests for the text tokenizer helper of the rule-based engine."""

    def test_get_tokens(self):
        """Punctuation is dropped and contractions are expanded by `tokenize_text`."""
        sentences = ['hello, world!', ' !hello! world!!,..#', '#hello!world']
        for sent in sentences:
            # subTest reports which sentence failed instead of stopping at the first one.
            with self.subTest(sentence=sent):
                # unittest assertions survive `python -O` and show a diff on
                # failure, unlike bare `assert` statements.
                self.assertEqual(list(tokenize_text(sent)), ['hello', 'world'])

        self.assertEqual(
            list(tokenize_text("don't wouldn't")),
            ['do', 'not', 'would', 'not'],
        )
18 changes: 0 additions & 18 deletions tests/unit_tests/test_infer_dtypes.py

This file was deleted.

8 changes: 0 additions & 8 deletions tests/unit_tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pathlib import Path

import type_infer
from type_infer.helpers import tokenize_text


class TestDates(unittest.TestCase):
Expand All @@ -19,10 +18,3 @@ def test_versions_are_in_sync(self):
package_init_version = type_infer.__version__

self.assertEqual(package_init_version, pyproject_version)

def test_get_tokens(self):
sentences = ['hello, world!', ' !hello! world!!,..#', '#hello!world']
for sent in sentences:
assert list(tokenize_text(sent)) == ['hello', 'world']

assert list(tokenize_text("don't wouldn't")) == ['do', 'not', 'would', 'not']
10 changes: 6 additions & 4 deletions type_infer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from type_infer import base
from type_infer import dtype
from type_infer import infer
from type_infer import api
from type_infer import helpers

__version__ = '0.0.18'

__version__ = '0.0.17'


__all__ = ['base', 'dtype', 'infer', 'helpers', '__version__']
__all__ = [
'__version__',
'base', 'dtype', 'api', 'helpers',
]
40 changes: 40 additions & 0 deletions type_infer/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Dict, Optional
import pandas as pd

from type_infer.base import TypeInformation, ENGINES
from type_infer.rule_based.core import RuleBasedEngine


def infer_types(
        data: pd.DataFrame,
        config: Optional[Dict] = None
) -> TypeInformation:
    """
    Infers the data types of each column of the dataset by analyzing a small sample of
    each column's items.

    Inputs
    ----------
    data : pd.DataFrame
        The input dataset for which we want to infer data type information.
    config : Optional[Dict]
        Engine configuration. Missing keys are filled with defaults:
        ``engine`` ('rule_based'), ``pct_invalid`` (2), ``seed`` (420) and, for the
        rule-based engine only, ``mp_cutoff`` (1e4). The dictionary passed by the
        caller is not modified.

    Returns
    ----------
    TypeInformation
        Whatever the selected engine's ``infer`` method returns for ``data``.

    Raises
    ----------
    ValueError
        If ``config['engine']`` does not name a known engine.
    """
    # Fill defaults on a shallow copy so the caller's dict is never mutated
    # (callers may reuse one config dict across several calls).
    config = dict(config) if config is not None else {}

    # Global defaults, shared by every engine
    config.setdefault('engine', 'rule_based')
    config.setdefault('pct_invalid', 2)
    config.setdefault('seed', 420)

    if config['engine'] == ENGINES.RULE_BASED:
        # Engine-specific default
        config.setdefault('mp_cutoff', 1e4)

        engine = RuleBasedEngine(config)
        return engine.infer(data)

    # ValueError is still an Exception, so existing `except Exception` handlers keep working.
    raise ValueError(f'Unknown engine {config["engine"]}')
13 changes: 13 additions & 0 deletions type_infer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,16 @@ def __init__(self):
self.dtypes = dict()
self.additional_info = dict()
self.identifiers = dict()


class BaseEngine:
    """Common base for type inference engines; subclasses must override `infer`."""

    def __init__(self, stable=True):
        """
        :param stable: whether the engine is stable or not (i.e. experimental).
        """
        self.stable = stable

    def infer(self, df) -> TypeInformation:
        """Given a dataframe, infer the types of each column and return a TypeInformation object."""
        raise NotImplementedError


class ENGINES:
    """Namespace of engine identifiers, as plain strings comparable to `config['engine']`."""

    RULE_BASED = 'rule_based'
Empty file added type_infer/bert/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions type_infer/bert/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from type_infer.base import BaseEngine


class BERType(BaseEngine):
    """Engine stub: `infer` is not implemented yet, and the engine is flagged as not stable by default."""

    def __init__(self, stable=False):
        """
        :param stable: forwarded to `BaseEngine`; defaults to False for this engine.
        """
        super().__init__(stable=stable)

    def infer(self, df):
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError
3 changes: 3 additions & 0 deletions type_infer/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,6 @@ class dtype:
# Misc (Unk/NaNs)
empty = "empty"
invalid = "invalid"


# TODO: modifier class + system
Loading

0 comments on commit 667e581

Please sign in to comment.