Support zero-copy training from Python package #128

Draft · wants to merge 3 commits into main

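In short, this draft lets the Python package hand training data to the Rust core as an Arrow C stream instead of round-tripping it through a CSV file. A minimal sketch of the intended flow, condensed from the example added under languages/python/examples/basic/train.py:

import pandas as pd
import pyarrow as pa
from pyarrow.cffi import ffi as arrow_c
import modelfox

df = pd.read_csv("heart_disease.csv")
batch = pa.RecordBatch.from_pandas(df)
reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])

# Export the batches as an Arrow C stream and pass the raw pointer to Rust.
with arrow_c.new("struct ArrowArrayStream*") as c_stream:
    c_stream_ptr = int(arrow_c.cast("uintptr_t", c_stream))
    reader._export_to_c(c_stream_ptr)
    model = modelfox.Model.train(c_stream_ptr, "diagnosis", "heart_disease.modelfox")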
73 changes: 73 additions & 0 deletions Cargo.lock


1 change: 1 addition & 0 deletions Cargo.toml
@@ -23,6 +23,7 @@ version = "0.8.0"

[workspace.dependencies]
anyhow = { version = "1.0", features = ["backtrace"] }
arrow2 = { version = "0.14" }
backtrace = "0.3"
base64 = "0.13"
bitvec = "1.0"
8 changes: 5 additions & 3 deletions crates/cli/train.rs
@@ -53,12 +53,14 @@ pub fn train(args: TrainArgs) -> Result<()> {
let input = match (&args.file, &args.file_train, &args.file_test, args.stdin) {
(None, None, None, true) => modelfox_core::train::TrainingDataSource::Stdin,
(Some(file_path), None, None, false) => {
modelfox_core::train::TrainingDataSource::File(file_path.to_owned())
modelfox_core::train::TrainingDataSource::Train(
modelfox_core::train::FileOrArrow::File(file_path.to_owned()),
)
}
(None, Some(file_path_train), Some(file_path_test), false) => {
modelfox_core::train::TrainingDataSource::TrainAndTest {
train: file_path_train.to_owned(),
test: file_path_test.to_owned(),
train: modelfox_core::train::FileOrArrow::File(file_path_train.to_owned()),
test: modelfox_core::train::FileOrArrow::File(file_path_test.to_owned()),
}
}
_ => bail!("Must use the stdin flag or provide training data files."),
1 change: 1 addition & 0 deletions crates/core/Cargo.toml
@@ -18,6 +18,7 @@ path = "lib.rs"

[dependencies]
anyhow = { workspace = true }
arrow2 = { workspace = true }
bitvec = { workspace = true }
buffalo = { workspace = true }
chrono = { workspace = true }
109 changes: 65 additions & 44 deletions crates/core/train.rs
@@ -18,6 +18,7 @@ use crate::{
test,
};
use anyhow::{anyhow, bail, Result};
use arrow2::ffi::ArrowArrayStream;
use modelfox_id::Id;
use modelfox_kill_chip::KillChip;
use modelfox_progress_counter::ProgressCounter;
@@ -35,12 +36,17 @@ use std::{
unreachable,
};

pub enum FileOrArrow {
File(std::path::PathBuf),
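// A raw Arrow C stream handed across an FFI boundary (e.g. from the Python
// package); the producer must keep the stream alive until training has consumed it.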
Arrow(*const ArrowArrayStream),
}

pub enum TrainingDataSource {
Stdin,
File(std::path::PathBuf),
Train(FileOrArrow),
TrainAndTest {
train: std::path::PathBuf,
test: std::path::PathBuf,
train: FileOrArrow,
test: FileOrArrow,
},
}

@@ -82,7 +88,7 @@ impl Trainer {
target_column_name,
handle_progress_event,
)?),
TrainingDataSource::File(file_path) => Dataset::Train(load_and_shuffle_dataset_train(
TrainingDataSource::Train(file_path) => Dataset::Train(load_and_shuffle_dataset_train(
&file_path,
&config,
target_column_name,
@@ -729,25 +735,31 @@ fn load_and_shuffle_dataset_stdin(
}

fn load_and_shuffle_dataset_train(
file_path: &Path,
file_path: &FileOrArrow,
config: &Config,
target_column_name: &str,
handle_progress_event: &mut dyn FnMut(ProgressEvent),
) -> Result<DatasetTrain> {
let mut handle_progress_event_inner = |progress_event| {
handle_progress_event(ProgressEvent::Load(LoadProgressEvent::Train(
progress_event,
)))
};
// Get the column types from the config, if set.
let mut table = Table::from_path(
file_path,
modelfox_table::FromCsvOptions {
column_types: column_types_from_config(config),
infer_options: Default::default(),
..Default::default()
},
&mut |progress_event| {
handle_progress_event(ProgressEvent::Load(LoadProgressEvent::Train(
progress_event,
)))
},
)?;
let mut table = match file_path {
FileOrArrow::File(file_path) => Table::from_path(
file_path,
modelfox_table::FromCsvOptions {
column_types: column_types_from_config(config),
infer_options: Default::default(),
..Default::default()
},
&mut handle_progress_event_inner,
)?,
FileOrArrow::Arrow(stream_ptr) => {
Table::from_arrow(*stream_ptr, &mut handle_progress_event_inner)?
}
};
// Drop any rows with invalid data in the target column
drop_invalid_target_rows(&mut table, target_column_name, handle_progress_event);
// Shuffle the table if enabled.
@@ -761,27 +773,33 @@ fn load_and_shuffle_dataset_train(
}

fn load_and_shuffle_dataset_train_and_test(
file_path_train: &Path,
file_path_test: &Path,
file_path_train: &FileOrArrow,
file_path_test: &FileOrArrow,
config: &Config,
target_column_name: &str,
handle_progress_event: &mut dyn FnMut(ProgressEvent),
) -> Result<DatasetTrainAndTest> {
let mut handle_progress_event_inner = |progress_event| {
handle_progress_event(ProgressEvent::Load(LoadProgressEvent::Train(
progress_event,
)))
};
// Get the column types from the config, if set.
let column_types = column_types_from_config(config);
let mut table_train = Table::from_path(
file_path_train,
modelfox_table::FromCsvOptions {
column_types,
infer_options: Default::default(),
..Default::default()
},
&mut |progress_event| {
handle_progress_event(ProgressEvent::Load(LoadProgressEvent::Train(
progress_event,
)))
},
)?;
let mut table_train = match file_path_train {
FileOrArrow::File(file_path_train) => Table::from_path(
file_path_train,
modelfox_table::FromCsvOptions {
column_types,
infer_options: Default::default(),
..Default::default()
},
&mut handle_progress_event_inner,
)?,
FileOrArrow::Arrow(stream_ptr_train) => {
Table::from_arrow(*stream_ptr_train, &mut handle_progress_event_inner)?
}
};
// Force the column types for table_test to be the same as table_train.
let column_types = table_train
.columns()
@@ -802,17 +820,20 @@ fn load_and_shuffle_dataset_train_and_test(
TableColumn::Text(column) => (column.name().to_owned().unwrap(), TableColumnType::Text),
})
.collect();
let mut table_test = Table::from_path(
file_path_test,
modelfox_table::FromCsvOptions {
column_types: Some(column_types),
infer_options: Default::default(),
..Default::default()
},
&mut |progress_event| {
handle_progress_event(ProgressEvent::Load(LoadProgressEvent::Test(progress_event)))
},
)?;
let mut handle_progress_event_inner_test = |progress_event| {
handle_progress_event(ProgressEvent::Load(LoadProgressEvent::Test(progress_event)))
};
let mut table_test = match file_path_test {
FileOrArrow::File(file_path_test) => Table::from_path(
file_path_test,
modelfox_table::FromCsvOptions {
column_types: Some(column_types),
infer_options: Default::default(),
..Default::default()
},
&mut handle_progress_event_inner_test,
)?,
FileOrArrow::Arrow(stream_ptr_test) => {
Table::from_arrow(*stream_ptr_test, &mut handle_progress_event_inner_test)?
}
};
if table_train.columns().len() != table_test.columns().len() {
bail!("Training data and test data must contain the same number of columns.")
}
1 change: 1 addition & 0 deletions crates/table/Cargo.toml
@@ -21,6 +21,7 @@ insta = { workspace = true }

[dependencies]
anyhow = { workspace = true }
arrow2 = { workspace = true }
csv = { workspace = true }
fast-float = { workspace = true }
fnv = { workspace = true }
130 changes: 130 additions & 0 deletions crates/table/load.rs
@@ -1,5 +1,10 @@
use super::{Table, TableColumn, TableColumnType};
use anyhow::Result;
use arrow2::{
array::{BooleanArray, PrimitiveArray, StructArray},
datatypes::DataType,
ffi,
};
use modelfox_progress_counter::ProgressCounter;
use modelfox_zip::zip;
// NOTE - this import is actually used, false positive with the lint.
@@ -8,6 +13,7 @@ use num::ToPrimitive;
use std::{
collections::{BTreeMap, BTreeSet},
path::Path,
vec,
};

#[derive(Clone)]
@@ -244,6 +250,130 @@ impl Table {
handle_progress_event(ProgressEvent::LoadDone);
Ok(table)
}

pub fn from_arrow(
stream_ptr: *const ffi::ArrowArrayStream,
handle_progress_event: &mut impl FnMut(ProgressEvent),
) -> Result<Table> {
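// SAFETY: `stream_ptr` must point to a live, initialized ArrowArrayStream.
// Reconstructing the Box lets arrow2 read from it, but the producer (e.g. the
// Python caller) still owns the allocation, so it must outlive this call.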
let stream = unsafe { Box::from_raw(stream_ptr as *mut ffi::ArrowArrayStream) };

// Copy the column fields out of the stream reader.
let mut iter = unsafe { ffi::ArrowArrayStreamReader::try_new(stream) }?;
let mut all_values = vec![];
let mut column_names = vec![];
let mut column_types = vec![];

while let Some(array) = unsafe { iter.next() } {
let array = array.unwrap();
let array = array.as_any().downcast_ref::<StructArray>().unwrap();
let (fields, values, _) = array.clone().into_data();

for (field, value) in zip!(fields, values) {
let column_name = field.name.clone();
let column_type = match field.data_type {
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Float16
| DataType::Float32
| DataType::Float64
| DataType::Boolean => TableColumnType::Number,
DataType::Utf8 => {
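// Collect the distinct string values as the enum's variants.
// NOTE: values_iter() ignores the validity bitmap, so null slots
// contribute their underlying buffer values to the variant set.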
let mut uniques = BTreeSet::new();
if let Some(value) = value
.as_any()
.downcast_ref::<arrow2::array::Utf8Array<i32>>()
{
uniques
.extend(value.values_iter().map(std::string::ToString::to_string));
} else if let Some(value) = value
.as_any()
.downcast_ref::<arrow2::array::Utf8Array<i64>>()
{
uniques
.extend(value.values_iter().map(std::string::ToString::to_string));
} else {
unreachable!();
}
let variants = uniques.into_iter().collect();
TableColumnType::Enum { variants }
}
_ => TableColumnType::Unknown,
};

column_names.push(Some(column_name));
column_types.push(column_type);
all_values.push(value);
}
}
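// Leak the reader instead of dropping it: dropping would run the stream's
// release callback and free the struct through Rust's allocator, but the
// memory belongs to the foreign producer (Python's cffi in the example).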
std::mem::forget(iter);
handle_progress_event(ProgressEvent::InferDone);

// write table data
let mut table = Table::new(column_names, column_types);

for (column, value) in zip!(&mut table.columns, all_values) {
match column {
TableColumn::Unknown(_) => {
unreachable!();
}
TableColumn::Number(column) => {
if let Some(value) = value.as_any().downcast_ref::<PrimitiveArray<f32>>() {
column.data.extend(value.values_iter());
} else if let Some(value) = value.as_any().downcast_ref::<PrimitiveArray<f64>>()
{
column.data.extend(value.values_iter().map(|&x| x as f32));
} else if let Some(value) = value.as_any().downcast_ref::<PrimitiveArray<i32>>()
{
column.data.extend(value.values_iter().map(|&x| x as f32));
} else if let Some(value) = value.as_any().downcast_ref::<PrimitiveArray<i64>>()
{
column.data.extend(value.values_iter().map(|&x| x as f32));
} else if let Some(value) = value.as_any().downcast_ref::<BooleanArray>() {
column
.data
.extend(value.values_iter().map(|x| i32::from(x) as f32));
} else {
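// NOTE: other widths declared as Number above (Int8/16, UInt*,
// Float16) are not downcast here and would land in this branch.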
unreachable!();
}
}
TableColumn::Enum(column) => {
if let Some(value) = value
.as_any()
.downcast_ref::<arrow2::array::Utf8Array<i32>>()
{
let mut v: Vec<Option<std::num::NonZeroUsize>> = Vec::new();
for s in value.values_iter() {
v.push(column.value_for_variant(s));
}
column.data.extend(v);
} else if let Some(value) = value
.as_any()
.downcast_ref::<arrow2::array::Utf8Array<i64>>()
{
let mut v: Vec<Option<std::num::NonZeroUsize>> = Vec::new();
for s in value.values_iter() {
v.push(column.value_for_variant(s));
}
column.data.extend(v);
} else {
unreachable!();
}
}
TableColumn::Text(_column) => {
unreachable!();
}
}
}

handle_progress_event(ProgressEvent::LoadDone);
Ok(table)
}
}

#[derive(Clone, Debug)]
1 change: 1 addition & 0 deletions crates/www/content/Cargo.toml
@@ -15,6 +15,7 @@ version = { workspace = true }

[lib]
path = "lib.rs"
doctest = false

[dependencies]
anyhow = { workspace = true }
5 changes: 4 additions & 1 deletion languages/python/Cargo.toml
@@ -17,6 +17,7 @@ path = "lib.rs"

[dependencies]
anyhow = { workspace = true }
arrow2 = { workspace = true }
chrono = { workspace = true }
memmap = { workspace = true }
pyo3 = { workspace = true }
@@ -26,4 +27,6 @@ serde_json = { workspace = true }
url = { workspace = true }

modelfox_core = { workspace = true }
modelfox_model = { workspace = true }
modelfox_kill_chip = { workspace = true }
modelfox_id = { workspace = true }
12 changes: 12 additions & 0 deletions languages/python/Makefile
@@ -0,0 +1,12 @@
.PHONY: test dev

test: dev
.venv/bin/python examples/basic/train.py

dev: .venv
cargo build -p modelfox_python
cp ../../target/debug/libmodelfox_python.so modelfox/modelfox_python.so
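# NOTE: assumes a Linux target; adjust the library extension for macOS (.dylib) or Windows (.dll).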
.venv/bin/pip install -e .

.venv:
virtualenv .venv
304 changes: 304 additions & 0 deletions languages/python/examples/basic/heart_disease.csv

9 changes: 4 additions & 5 deletions languages/python/examples/basic/main.py
@@ -1,15 +1,13 @@
import os
import json
import modelfox
from typing import cast

# Get the path to the .modelfox file.
model_path = os.path.join(os.path.dirname(__file__), "heart_disease.modelfox")
# Load the model from the path.
model = modelfox.Model.from_path(model_path)

# Create an example input matching the schema of the CSV file the model was trained on. Here the data is just hard-coded, but in your application you will probably get this from a database or user input.
input = {
specimen = {
"age": 63,
"gender": "male",
"chest_pain": "typical angina",
@@ -26,7 +24,8 @@
}

# Make the prediction!
output = model.predict(5)
output = model.predict(specimen)

# Print the output.
print("Output:", output)
print("Output.class_name:", output.class_name)
print("Output.probability:", output.probability)
5 changes: 5 additions & 0 deletions languages/python/examples/basic/pyrightconfig.json
@@ -0,0 +1,5 @@
{
"venvPath": "../..",
"venv": ".venv"
}

47 changes: 47 additions & 0 deletions languages/python/examples/basic/train.py
@@ -0,0 +1,47 @@
import os
import pyarrow as pa
from pyarrow.cffi import ffi as arrow_c
import pandas as pd
import modelfox

# Get the path to the CSV file.
csv_path = os.path.join(os.path.dirname(__file__), "heart_disease.csv")
# Get the path to the .modelfox file.
model_path = os.path.join(os.path.dirname(__file__), "heart_disease.modelfox")

# Read the CSV file into a pandas DataFrame.
df = pd.read_csv(csv_path)

batch = pa.RecordBatch.from_pandas(df)
reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])

with arrow_c.new("struct ArrowArrayStream*") as c_stream:
c_stream_ptr = int(arrow_c.cast("uintptr_t", c_stream))
reader._export_to_c(c_stream_ptr)
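# The exported stream struct must stay alive until training has consumed it,
# so keep the Model.train call inside this `with` block.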

# Train a model.
model = modelfox.Model.train(c_stream_ptr, "diagnosis", model_path)

# Create an example input matching the schema of the CSV file the model was trained on. Here the data is just hard-coded, but in your application you will probably get this from a database or user input.
specimen = {
"age": 63,
"gender": "male",
"chest_pain": "typical angina",
"resting_blood_pressure": 145,
"cholesterol": 233,
"fasting_blood_sugar_greater_than_120": "true",
"resting_ecg_result": "probable or definite left ventricular hypertrophy",
"exercise_max_heart_rate": 150,
"exercise_induced_angina": "no",
"exercise_st_depression": 2.3,
"exercise_st_slope": "downsloping",
"fluoroscopy_vessels_colored": "0",
"thallium_stress_test": "fixed defect",
}

# Make the prediction!
output = model.predict(specimen)

# Print the output.
print("Output.class_name:", output.class_name)
print("Output.probability:", output.probability)
90 changes: 89 additions & 1 deletion languages/python/lib.rs
@@ -1,7 +1,8 @@
use anyhow::anyhow;
use arrow2::ffi::ArrowArrayStream;
use memmap::Mmap;
use pyo3::{prelude::*, type_object::PyTypeObject, types::PyType};
use std::collections::BTreeMap;
use std::{collections::BTreeMap, path::PathBuf};
use url::Url;

#[pymodule]
@@ -118,6 +119,70 @@ impl Model {
self.model.id.clone()
}

/**
Train a model!
Args:
input (Union[str, int, Tuple]): The training data source. Either the path to a CSV file, a pointer to an Arrow C stream passed as an integer, or a tuple of two such values giving separate train and test sources.
target (str): The name of the column to predict.
output (str): The path to write the trained `.modelfox` file to.
config (Optional[str]): An optional path to a training configuration file.
Returns:
`Model`. The trained model, loaded back from the file written to `output`.
*/
#[classmethod]
#[args(input, target, output, config = "None")]
#[pyo3(text_signature = "(input, target, output, config=None)")]
pub fn train(
cls: &PyType,
input: Input,
target: String,
output: String,
config: Option<String>,
) -> PyResult<Model> {
let mut handle_progress_event = |_progress_event| {};
let input = match input {
Input::Train(file) => modelfox_core::train::TrainingDataSource::Train(file.into()),
Input::TrainAndTest((file_train, file_test)) => {
modelfox_core::train::TrainingDataSource::TrainAndTest {
train: file_train.into(),
test: file_test.into(),
}
}
};
// Load the dataset, compute stats, and prepare for training.
let mut trainer = modelfox_core::train::Trainer::prepare(
modelfox_id::Id::generate(),
input,
&target,
config.map(PathBuf::from).as_deref(),
&mut handle_progress_event,
)
.map_err(ModelFoxError)?;
let kill_chip = modelfox_kill_chip::KillChip::new();
let train_grid_item_outputs = trainer
.train_grid(&kill_chip, &mut handle_progress_event)
.map_err(ModelFoxError)?;
let model = trainer
.test_and_assemble_model(train_grid_item_outputs, &mut handle_progress_event)
.map_err(ModelFoxError)?;

// Write the model to the output path.
let output_path = PathBuf::from(output.clone());
model.to_path(&output_path).map_err(ModelFoxError)?;

// Announce that everything worked!
eprintln!("Your model was written to {}.", output_path.display());
eprintln!(
"For help making predictions in your code, read the docs at https://www.modelfox.dev/docs."
);
eprintln!(
"To learn more about how your model works and set up production monitoring, run `modelfox app`."
);

// TODO: load the model more efficiently
Model::from_path(cls, output, None)
}

/**
Make a prediction!
@@ -306,6 +371,29 @@ impl LoadModelOptions {
}
}

#[derive(FromPyObject)]
enum FileOrArrow {
File(String),
Arrow(usize),
}

#[derive(FromPyObject)]
enum Input {
Train(FileOrArrow),
TrainAndTest((FileOrArrow, FileOrArrow)),
}

impl From<FileOrArrow> for modelfox_core::train::FileOrArrow {
fn from(value: FileOrArrow) -> modelfox_core::train::FileOrArrow {
match value {
FileOrArrow::File(file) => modelfox_core::train::FileOrArrow::File(file.into()),
FileOrArrow::Arrow(stream_ptr) => {
modelfox_core::train::FileOrArrow::Arrow(stream_ptr as *const ArrowArrayStream)
}
}
}
}

#[derive(FromPyObject)]
enum PredictInputSingleOrMultiple {
Single(PredictInput),
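For reference, a hedged sketch of the call shapes the derived FromPyObject impls above accept: a path string, an Arrow stream pointer as an integer, or a (train, test) tuple of either. File names here are illustrative:

import modelfox

# A single source; the core applies its own train/test split.
model = modelfox.Model.train("heart_disease.csv", "diagnosis", "heart_disease.modelfox")

# Separate train and test sources.
model = modelfox.Model.train(
    ("train.csv", "test.csv"),
    "diagnosis",
    "heart_disease.modelfox",
)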
9 changes: 8 additions & 1 deletion languages/python/modelfox/tangram_python.pyi
@@ -1,6 +1,5 @@
from typing import (
Any,
cast,
Dict,
List,
Literal,
@@ -26,6 +25,14 @@ class Model:
) -> "Model": ...
@property
def id(self) -> str: ...
@classmethod
def train(
cls,
input: Union[str, int, Tuple[Union[str, int], Union[str, int]]],
target: str,
output: str,
config: Optional[str] = None,
) -> "Model": ...
@overload
def predict(
self,