Skip to content

[3/n tensor engine] hello tensor engine #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: gh/zdevito/6/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions controller/src/history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ enum RefStatus {
/// borrows, drops etc. directly.
#[derive(Debug)]
#[allow(dead_code)]
pub(crate) struct History {
pub struct History {
/// The first incomplete Seq for each rank. This is used to determine which
/// Seqs are no longer relevant and can be purged from the history.
first_incomplete_seqs: MinVector<Seq>,
Expand Down Expand Up @@ -198,7 +198,7 @@ where
}

impl History {
pub(crate) fn new(world_size: usize) -> Self {
pub fn new(world_size: usize) -> Self {
Self {
first_incomplete_seqs: MinVector::new(vec![Seq::default(); world_size]),
min_incomplete_seq: Seq::default(),
Expand All @@ -213,23 +213,23 @@ impl History {
}

#[cfg(test)]
pub(crate) fn first_incomplete_seqs(&self) -> &[Seq] {
pub fn first_incomplete_seqs(&self) -> &[Seq] {
self.first_incomplete_seqs.vec()
}

pub(crate) fn first_incomplete_seqs_controller(&self) -> &[Seq] {
pub fn first_incomplete_seqs_controller(&self) -> &[Seq] {
self.first_incomplete_seqs_controller.vec()
}

pub(crate) fn min_incomplete_seq_reported(&self) -> Seq {
pub fn min_incomplete_seq_reported(&self) -> Seq {
self.min_incompleted_seq_controller
}

pub(crate) fn world_size(&self) -> usize {
pub fn world_size(&self) -> usize {
self.first_incomplete_seqs.len()
}

pub(crate) fn delete_invocations_for_refs(&mut self, refs: Vec<Ref>) {
pub fn delete_invocations_for_refs(&mut self, refs: Vec<Ref>) {
self.marked_for_deletion.extend(refs);

self.marked_for_deletion
Expand All @@ -251,7 +251,7 @@ impl History {
}

/// Add an invocation to the history.
pub(crate) fn add_invocation(
pub fn add_invocation(
&mut self,
seq: Seq,
uses: Vec<Ref>,
Expand Down Expand Up @@ -306,7 +306,7 @@ impl History {

/// Propagate worker error to the invocation with the given Seq. This will also propagate
/// to all seqs that depend on this seq directly or indirectly.
pub(crate) fn propagate_exception(&mut self, seq: Seq, exception: Exception) {
pub fn propagate_exception(&mut self, seq: Seq, exception: Exception) {
let mut queue = vec![seq];
let mut visited = HashSet::new();

Expand Down Expand Up @@ -364,13 +364,13 @@ impl History {
results
}

pub(crate) fn report_deadline_missed(&mut self) {
pub fn report_deadline_missed(&mut self) {
if let Some((seq, time, _)) = self.deadline {
self.deadline = Some((seq, time, true));
}
}

pub(crate) fn deadline(
pub fn deadline(
&mut self,
expected_progress: u64,
timeout: tokio::time::Duration,
Expand All @@ -397,7 +397,7 @@ impl History {
self.deadline
}

pub(crate) fn update_deadline_tracking(&mut self, rank: usize, seq: Seq) {
pub fn update_deadline_tracking(&mut self, rank: usize, seq: Seq) {
// rank_completed also calls this so that we stay up to date with client request_status messages.
// However, controller request_status messages may be ahead of the client, because the client may
// retain invocations past the time they completed, so we should take the max.
Expand All @@ -411,7 +411,7 @@ impl History {

/// Mark the given rank as completed up to but excluding the given Seq. This will also purge history for
/// any Seqs that are no longer relevant (completed on all ranks).
pub(crate) fn rank_completed(
pub fn rank_completed(
&mut self,
rank: usize,
seq: Seq,
Expand Down
2 changes: 1 addition & 1 deletion controller/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#![allow(unsafe_op_in_unsafe_fn)]

pub mod bootstrap;
mod history;
pub mod history;

use std::collections::HashMap;
use std::collections::HashSet;
Expand Down
70 changes: 70 additions & 0 deletions hyperactor_mesh/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use ndslice::Range;
use ndslice::Selection;
use ndslice::Shape;
use ndslice::ShapeError;
use ndslice::Slice;
use serde::Deserialize;
use serde::Serialize;

Expand Down Expand Up @@ -147,6 +148,75 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> {
/// Open a new `M`-typed port on the proc mesh's client, returning the
/// handle/receiver pair produced by the client's `open_port`.
pub(crate) fn open_port<M: Message>(&self) -> (PortHandle<M>, PortReceiver<M>) {
self.proc_mesh.client().open_port()
}

/// Cast an `M`-typed message to the ranks selected by `selection`
/// in this ActorMesh.
///
/// The message is wrapped in a [`Cast`] carrying this mesh's shape, packed
/// into a [`CastMessageEnvelope`] addressed to the `Cast<M>` port of the
/// actors named `self.name`, and handed to the proc mesh's comm actor
/// together with the mesh slice and `selection`.
///
/// # Errors
///
/// Returns a [`CastError`] if the envelope cannot be constructed or if the
/// send to the comm actor fails.
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
pub fn cast<M: RemoteMessage + Clone>(
    &self,
    selection: Selection,
    message: M,
) -> Result<(), CastError>
where
    A: RemoteHandles<Cast<M>> + RemoteHandles<IndexedErasedUnbound<Cast<M>>>,
{
    // Keep the timer bound for the whole call. `let _ = ...` does not bind
    // the value, so it would be dropped at the end of that statement and the
    // measured duration would not cover the cast.
    // NOTE(review): assumes `start` returns a guard that records on drop — confirm.
    let _timer = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
        "message_type" => M::typename(),
        "message_variant" => message.arm().unwrap_or_default(),
    ));
    let message = Cast {
        // Placeholder rank; presumably replaced per destination during
        // delivery — TODO confirm against the comm actor.
        rank: CastRank(usize::MAX),
        shape: self.shape().clone(),
        message,
    };
    let message = CastMessageEnvelope::new(
        self.proc_mesh.client().actor_id().clone(),
        DestinationPort::new::<A, Cast<M>>(self.name.clone()),
        message,
        None, // TODO: reducer typehash
    )?;

    self.proc_mesh.comm_actor().send(
        self.proc_mesh.client(),
        CastMessage {
            dest: Uslice {
                slice: self.shape().slice().clone(),
                selection,
            },
            message,
        },
    )?;
    Ok(())
}

/// Send `message` directly to every rank contained in each slice of `sel`.
///
/// Until the selection logic is more powerful, we need a way to
/// replicate the send patterns that the worker actor mesh actually does.
/// Unlike [`Self::cast`], this sends one [`Cast`] per rank straight to
/// `self.ranks[rank]`, cloning `message` for each send.
///
/// # Errors
///
/// Stops at the first failed send and returns
/// [`CastError::MailboxSenderError`] carrying the offending rank; ranks
/// visited before the failure have already been sent to.
///
/// # Panics
///
/// Indexes `self.ranks` with each rank yielded by the slices, so a rank
/// outside `self.ranks` will panic.
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
pub fn cast_slices<M: RemoteMessage + Clone>(
    &self,
    sel: Vec<Slice>,
    message: M,
) -> Result<(), CastError>
where
    A: RemoteHandles<Cast<M>> + RemoteHandles<IndexedErasedUnbound<Cast<M>>>,
{
    // Keep the timer bound for the whole call. `let _ = ...` does not bind
    // the value, so it would be dropped at the end of that statement and the
    // measured duration would not cover the sends below.
    // NOTE(review): assumes `start` returns a guard that records on drop — confirm.
    let _timer = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
        "message_type" => M::typename(),
        "message_variant" => message.arm().unwrap_or_default(),
    ));
    for slice in &sel {
        for rank in slice.iter() {
            let cast = Cast {
                rank: CastRank(rank),
                shape: self.shape().clone(),
                message: message.clone(),
            };
            self.ranks[rank]
                .send(self.proc_mesh.client(), cast)
                .map_err(|err| CastError::MailboxSenderError(rank, err))?;
        }
    }
    Ok(())
}
}

#[async_trait]
Expand Down
8 changes: 8 additions & 0 deletions hyperactor_mesh/src/proc_mesh/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,14 @@ impl ProcMesh {
&self.client
}

/// Returns a reference to the [`Proc`] held in this mesh's `client_proc` field.
pub fn client_proc(&self) -> &Proc {
&self.client_proc
}

/// Returns the [`ProcId`] of this mesh's client proc (delegates to
/// `client_proc.proc_id()`); note this is the client proc's id, not an id
/// for the mesh as a whole.
pub fn proc_id(&self) -> &ProcId {
self.client_proc.proc_id()
}

/// An event stream of proc events. Each ProcMesh can produce only one such
/// stream, returning None after the first call.
pub fn events(&mut self) -> Option<ProcEvents> {
Expand Down
1 change: 1 addition & 0 deletions monarch_extension/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", "
controller = { version = "0.0.0", path = "../controller" }
hyperactor = { version = "0.0.0", path = "../hyperactor" }
hyperactor_extension = { version = "0.0.0", path = "../hyperactor_extension" }
hyperactor_mesh = { version = "0.0.0", path = "../hyperactor_mesh" }
hyperactor_multiprocess = { version = "0.0.0", path = "../hyperactor_multiprocess" }
monarch_hyperactor = { version = "0.0.0", path = "../monarch_hyperactor" }
monarch_messages = { version = "0.0.0", path = "../monarch_messages" }
Expand Down
10 changes: 8 additions & 2 deletions monarch_extension/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,17 @@ use crate::controller::PyRanks;
use crate::convert::convert;

#[pyclass(frozen, module = "monarch._rust_bindings.monarch_extension.client")]
struct WorkerResponse {
pub struct WorkerResponse {
seq: Seq,
result: Option<Result<Serialized, Exception>>,
}

impl WorkerResponse {
/// Creates a `WorkerResponse` from the given `seq` and optional `result`,
/// storing both values unchanged.
pub fn new(seq: Seq, result: Option<Result<Serialized, Exception>>) -> Self {
Self { seq, result }
}
}

#[pymethods]
impl WorkerResponse {
#[staticmethod]
Expand Down Expand Up @@ -510,7 +516,7 @@ pub struct DebuggerMessage {
impl DebuggerMessage {
#[new]
#[pyo3(signature = (*, debugger_actor_id, action))]
fn new(debugger_actor_id: PyActorId, action: DebuggerAction) -> PyResult<Self> {
pub fn new(debugger_actor_id: PyActorId, action: DebuggerAction) -> PyResult<Self> {
Ok(Self {
debugger_actor_id,
action,
Expand Down
6 changes: 6 additions & 0 deletions monarch_extension/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod client;
mod controller;
pub mod convert;
mod debugger;
mod mesh_controller;
mod panic;
mod simulator_client;
mod tensor_worker;
Expand Down Expand Up @@ -150,5 +151,10 @@ pub fn mod_init(module: &Bound<'_, PyModule>) -> PyResult<()> {
"monarch_extension.panic",
)?)?;

crate::mesh_controller::register_python_bindings(&get_or_add_new_module(
module,
"monarch_extension.mesh_controller",
)?)?;

Ok(())
}
Loading
Loading