Skip to content

Commit

Permalink
Chase down unwrap() in utils (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
tyb0807 authored Nov 16, 2024
1 parent f5f8fa9 commit ee7ca20
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 18 deletions.
2 changes: 1 addition & 1 deletion rs/aggregator/src/node_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl NodeManager {
}

pub async fn check_for_update(&self) {
let latest_version = get_latest_version(&self.config_path);
let latest_version = get_latest_version(&self.config_path).unwrap();
if latest_version > self.nodes.read().await.clone().version {
self.load_version(latest_version).await;
}
Expand Down
2 changes: 1 addition & 1 deletion rs/aggregator/src/shard_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl ShardManager {
}

pub async fn check_for_update(&self) {
let latest_version = get_latest_version(&self.config_directory);
let latest_version = get_latest_version(&self.config_directory).unwrap();
if latest_version > self.config.read().await.version {
self.load_version(latest_version).await;
} else {
Expand Down
2 changes: 1 addition & 1 deletion rs/index_server/src/index_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl IndexManager {
}

pub async fn check_for_update(&mut self) {
let latest_version = get_latest_version(&self.config_path);
let latest_version = get_latest_version(&self.config_path).unwrap();
if latest_version > self.latest_version {
info!("New version available: {}", latest_version);
let latest_config_path = format!("{}/version_{}", self.config_path, latest_version);
Expand Down
25 changes: 16 additions & 9 deletions rs/utils/src/io.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fs::{read_dir, File};
use std::io::{BufReader, BufWriter, Read, Write};

use anyhow::Result;
use anyhow::{anyhow, Result};

/// Convenient wrapper for going from io::Result<usize> to Result<usize, String>
pub fn wrap_write(writer: &mut BufWriter<&mut File>, buf: &[u8]) -> Result<usize> {
Expand All @@ -10,12 +10,12 @@ pub fn wrap_write(writer: &mut BufWriter<&mut File>, buf: &[u8]) -> Result<usize

/// Read file and append to the writer
pub fn append_file_to_writer(path: &str, writer: &mut BufWriter<&mut File>) -> Result<usize> {
let input_file = File::open(path).unwrap();
let input_file = File::open(path)?;
let mut buffer_reader = BufReader::new(&input_file);
let mut buffer: [u8; 4096] = [0; 4096];
let mut written = 0;
loop {
let read = buffer_reader.read(&mut buffer).unwrap();
let read = buffer_reader.read(&mut buffer)?;
written += wrap_write(writer, &buffer[0..read])?;
if read < 4096 {
break;
Expand All @@ -24,22 +24,29 @@ pub fn append_file_to_writer(path: &str, writer: &mut BufWriter<&mut File>) -> R
Ok(written)
}

pub fn get_latest_version(config_path: &str) -> u64 {
pub fn get_latest_version(config_path: &str) -> Result<u64> {
// List all files in the directory
let mut latest_version = 0;
for entry in read_dir(config_path).unwrap() {
for entry in read_dir(config_path)? {
let entry = entry.unwrap();
let path = entry.path();
let filename = path.file_name().unwrap().to_str().unwrap();
let filename = path
.file_name()
.unwrap()
.to_str()
.ok_or_else(|| anyhow!("Cannot get filename"))?;
if filename.starts_with("version_") {
let version = filename.split("_").last().unwrap();
let version = version.parse::<u64>().unwrap();
let version = filename
.split("_")
.last()
.ok_or_else(|| anyhow!("Cannot get version"))?;
let version = version.parse::<u64>()?;
if version > latest_version {
latest_version = version;
}
}
}
latest_version
Ok(latest_version)
}

// Test
Expand Down
12 changes: 9 additions & 3 deletions rs/utils/src/kmeans_builder/kmeans_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ mod tests {
.collect();

let kmeans = KMeansBuilder::new(3, 100, 1e-4, 2, KMeansVariant::Lloyd);
let result = kmeans.fit(flattened_data).unwrap();
let result = kmeans
.fit(flattened_data)
.expect("KMeans run should succeed");

assert_eq!(kmeans.num_cluters, 3);
assert_eq!(kmeans.max_iter, 100);
Expand Down Expand Up @@ -256,7 +258,9 @@ mod tests {
.cloned()
.collect();
let kmeans = KMeansBuilder::new(3, 100, 10000.0, 2, KMeansVariant::Lloyd);
let result = kmeans.fit(flattened_data).unwrap();
let result = kmeans
.fit(flattened_data)
.expect("KMeans run should succeed");

assert_eq!(result.centroids.len(), 3 * 2);
assert_eq!(result.assignments[0], result.assignments[3]);
Expand Down Expand Up @@ -288,7 +292,9 @@ mod tests {
.cloned()
.collect();
let kmeans = KMeansBuilder::new(3, 100, 0.0, 2, KMeansVariant::Lloyd);
let result = kmeans.fit(flattened_data).unwrap();
let result = kmeans
.fit(flattened_data)
.expect("KMeans run should succeed");

assert_eq!(result.centroids.len(), 3 * 2);
assert_eq!(result.assignments[0], result.assignments[3]);
Expand Down
7 changes: 4 additions & 3 deletions rs/utils/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ mod tests {
}

// Write to /tmp/dataset.bin
let mut file = std::fs::File::create("/tmp/dataset.bin").unwrap();
let mut file = std::fs::File::create("/tmp/dataset.bin").expect("File should be created");
for result in results {
for i in 0..dimension {
file.write_all(&result[i].to_le_bytes()).unwrap();
file.write_all(&result[i].to_le_bytes())
.expect("Write should succeed");
}
}
file.flush().unwrap();
file.flush().expect("Flush should succeed");
}
}

0 comments on commit ee7ca20

Please sign in to comment.