Skip to content

Parallelize predict subcommand #117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ backtrace = "0.3"
bytes = { version = "1", optional = true }
clap = { version = "3", features = ["derive"] }
colored = "2"
crossbeam-channel = "0.5"
csv = "1"
dirs = "4"
either = "1"
hyper = { version = "0.14", optional = true }
itertools = "0.10"
num = "0.4"
num_cpus = "1"
once_cell = "1"
rayon = "1.5"
serde = { version = "1", features = ["derive"] }
Expand Down
105 changes: 89 additions & 16 deletions crates/cli/predict.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use crate::PredictArgs;
use anyhow::Result;
use crossbeam_channel::Sender;
use csv::StringRecord;
use either::Either;
use itertools::Itertools;
use std::sync::Arc;
use tangram_core::predict::PredictOutput;
use tangram_core::predict::{PredictInput, PredictInputValue, PredictOptions};
use tangram_zip::zip;

Expand Down Expand Up @@ -68,25 +72,89 @@ pub fn predict(args: PredictArgs) -> Result<()> {
}
}
};

let header = reader.headers()?.to_owned();
for records in &reader.records().chunks(PREDICT_CHUNK_SIZE) {
let input: Vec<PredictInput> = records
.into_iter()
.map(|record| -> Result<PredictInput> {
let record = record?;
let input = zip!(header.iter(), record.into_iter())
.map(|(column_name, value)| {
(
column_name.to_owned(),
PredictInputValue::String(value.to_owned()),
)
let chunk_count = num_cpus::get() * 2;
let (input_tx, input_rx): (
Sender<(
Vec<StringRecord>,
Sender<Result<Vec<PredictOutput>, anyhow::Error>>,
)>,
_,
) = crossbeam_channel::bounded(chunk_count);
let (output_tx, output_rx) = crossbeam_channel::bounded(chunk_count);

let header = Arc::new(header);
let model = Arc::new(model);
let options = Arc::new(options);

let mut threads = Vec::new();

for _ in 0..num_cpus::get() {
let header = header.clone();
let model = model.clone();
let options = options.clone();
let input_rx = input_rx.clone();

threads.push(std::thread::spawn(move || {
while let Ok((records, chunk_tx)) = input_rx.recv() {
let input: Result<Vec<PredictInput>, _> = records
.into_iter()
.map(|record| -> Result<PredictInput> {
let input = zip!(header.iter(), record.into_iter())
.map(|(column_name, value)| {
(
column_name.to_owned(),
PredictInputValue::String(value.to_owned()),
)
})
.collect();
Ok(PredictInput(input))
})
.collect();
Ok(PredictInput(input))
})
.collect::<Result<_, _>>()?;
let output = tangram_core::predict::predict(&model, &input, &options);
for output in output {

let output =
input.map(|input| tangram_core::predict::predict(&model, &input, &options));

if chunk_tx.send(output).is_err() {
break;
};
}
}));
}

threads.push(std::thread::spawn(move || {
for records_chunk in &reader.records().chunks(PREDICT_CHUNK_SIZE) {
let records_chunk: Result<Vec<_>, _> = records_chunk.collect();
let records_chunk = match records_chunk {
Ok(records_chunk) => records_chunk,
Err(error) => {
let error: anyhow::Error = error.into();
let _ = output_tx.send(Err(error));
break;
}
};

// Here we create a single use channel which will allow the CSV writer
// to wait for the prediction results in-order while still allowing
// the prediction for future chunks to run in parallel.
let (chunk_tx, chunk_rx) = crossbeam_channel::bounded(1);
if let Err(error) = input_tx.send((records_chunk, chunk_tx)) {
let error: anyhow::Error = error.into();
let _ = output_tx.send(Err(error));
break;
}
if output_tx.send(Ok(chunk_rx)).is_err() {
break;
}
}
}));

while let Ok(output) = output_rx.recv() {
let chunk_rx = output?;
let outputs = chunk_rx.recv()??;

for output in outputs {
let output = match output {
tangram_core::predict::PredictOutput::Regression(output) => {
vec![output.value.to_string()]
Expand Down Expand Up @@ -129,5 +197,10 @@ pub fn predict(args: PredictArgs) -> Result<()> {
writer.write_record(&output)?;
}
}

for thread in threads {
thread.join().unwrap();
}

Ok(())
}