diff --git a/autoparallel/apply_sharding.py b/autoparallel/apply_sharding.py
index 1f50abc..51de0f1 100644
--- a/autoparallel/apply_sharding.py
+++ b/autoparallel/apply_sharding.py
@@ -250,6 +250,12 @@ def shard_node_given_placements(node, sharding_placement, *, meta: bool):
     mesh = tgt_spec.mesh
     # all tensors start as replicated
     curr_placement = (Replicate(),) * mesh.ndim
+    if "val" not in node.meta:
+        # non-tensor inputs are considered to be baked into the
+        # graph, so we don't need to do anything and can just
+        # return a dummy value
+        assert len(node.users) == 0
+        return "arbitrary value"
     tensor = node.meta["val"]
 
     ctx: Any
@@ -303,7 +309,7 @@ def _get_inductor_decomp_table():
 
 def apply_sharding_to_model(gm, sharding_placement, params_spec, buffers_spec):
     args = shard_nodes_given_placements(gm, sharding_placement)
-    local_args = [arg.to_local() for arg in args]
+    local_args = tree_map_only(DTensor, lambda x: x.to_local(), args)
 
     decomp_table = _get_inductor_decomp_table()
     # run with DTensor to apply the collectives given the graph
diff --git a/autoparallel/cast_parametrization.py b/autoparallel/cast_parametrization.py
index 36574ed..24c1fdf 100644
--- a/autoparallel/cast_parametrization.py
+++ b/autoparallel/cast_parametrization.py
@@ -187,6 +187,8 @@ def apply_dtype_cast(model, mp_policy: MixedPrecisionPolicy):
     class DTypeCastModule(torch.nn.Module):
         def forward(self, *args, **kwargs):
             def cast_fn(x):
+                if not isinstance(x, torch.Tensor):
+                    return x
                 if not torch.is_floating_point(x):
                     return x
                 return x.to(self._mp_policy.param_dtype)
@@ -196,6 +198,8 @@ def cast_fn(x):
             output = super().forward(*args, **kwargs)
 
             def cast_out_fn(x):
+                if not isinstance(x, torch.Tensor):
+                    return x
                 return x.to(self._mp_policy.output_dtype)
 
             output = tree_map(cast_out_fn, output)
diff --git a/autoparallel/graph_utils.py b/autoparallel/graph_utils.py
index d5b1eb7..30064b4 100644
--- a/autoparallel/graph_utils.py
+++ b/autoparallel/graph_utils.py
@@ -52,7 +52,10 @@ def update_joint_with_descriptors(
     """
     # TODO: should we upstream a util like this?
     placeholders = [n for n in updated_gm.graph.nodes if n.op == "placeholder"]
-    new_local_args = [n.meta["val"] for n in placeholders]
+    # assume that if "val" is not present in meta, then it's a non-tensor
+    # input with no sharding associated with it, so we can just forward
+    # the original input
+    new_local_args = [n.meta.get("val", None) for n in placeholders]
 
     joint_with_descriptors.graph_module = updated_gm
     joint_with_descriptors._aot_graph_capture.graph_module = updated_gm
@@ -60,8 +63,10 @@
     for orig, new in zip(joint_with_descriptors._aot_state.flat_args, new_local_args):
         if isinstance(orig, torch.nn.Parameter):
             new_flat_args.append(torch.nn.Parameter(new))
-        else:
+        elif new is not None:
             new_flat_args.append(new)
+        else:
+            new_flat_args.append(orig)
 
     tangent_idx = len(joint_with_descriptors._aot_state.flat_args)
     new_local_tangents = new_local_args[tangent_idx:]
diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py
index 5e1dca5..c0e2092 100644
--- a/autoparallel/optimize_sharding.py
+++ b/autoparallel/optimize_sharding.py
@@ -157,6 +157,18 @@ def build_sharding_metadata(self):
         strats = {}
         for node in self.graph.nodes:
            if node.op == "placeholder":
+                if node.meta.get("val", None) is None:
+                    # Non-tensor inputs are considered to be replicated
+                    # across all ranks. Given that those inputs seem to
+                    # have been baked into the graph, we won't actually
+                    # use this OpStrategy
+                    strats[node] = _create_all_options(self.mesh, ())
+                    # for now, it seems like non-tensor inputs are baked into
+                    # the graph, so let's assert that this is indeed the case
+                    assert (
+                        len(node.users) == 0
+                    ), f"{node} has {len(node.users)} users, expected 0"
+                    continue
                 strats[node] = _create_all_options(
                     self.mesh, node.meta["val"].shape, tensor=node.meta["val"]
                 )
@@ -828,9 +840,14 @@ def add_sharded_input_constraint(
         if input_placements is not None:
             mut_ips = {i: p for i, p in enumerate(input_placements)}
 
-        for desc, (node, grad_node) in get_plain_input_and_grad_nodes(
-            self.graph
-        ).items():
+        inputs_and_grads = get_plain_input_and_grad_nodes(self.graph)
+        if mut_ips is not None and len(mut_ips) != len(inputs_and_grads):
+            raise ValueError(
+                f"Expected to have {len(inputs_and_grads)} "
+                f"input placements, got {len(mut_ips)}"
+            )
+
+        for desc, (node, grad_node) in inputs_and_grads.items():
             if input_placements is None:
                 placement = None
             else:
@@ -838,6 +855,10 @@
                 assert mut_ips is not None
                 placement = mut_ips.pop(desc.idx)
 
+            if placement is None and "val" not in node.meta:
+                # this is a non-tensor input, so we don't do anything about it
+                continue
+
             self.add_node_constraint(
                 node, placement, constraint_name="input_constraint"
             )
diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py
index 3815569..e8e8375 100644
--- a/autoparallel/propagation_rules.py
+++ b/autoparallel/propagation_rules.py
@@ -660,31 +660,6 @@ def constant_pad_nd_rule(mesh, op_schema):
     return OpStrategy(filtered_strats)
 
 
-@register_opschema_rule(torch.ops.aten.split.Tensor)
-def split_rule(mesh, op_schema):
-    strat = op_schema.args_schema
-    op = torch.ops.aten.split.Tensor
-    from torch.distributed.tensor._ops._tensor_ops import split_rule
-
-    res = []
-    oo = []
-    for i, ss in enumerate(strat[0].strategies):
-        ispec = ss.input_specs[0]
-        assert ss.output_spec == ispec
-        o = split_rule(OpSchema(op, (ispec, strat[1], strat[2]), {}))
-        # res.append(o)
-        oo.append(o)
-        if o.output_spec is not None:
-            s = OpSpec(o.output_spec, input_specs=(ispec,))
-            s.redistribute_cost = [[math.inf] * len(ss.redistribute_cost[0])]
-            # s.redistribute_cost = [[0.0] * len(ss.redistribute_cost[0])]
-            s.redistribute_cost[0][i] = 0.0
-            res.append(s)
-
-    out_strat = OpStrategy(res)
-    return out_strat
-
-
 @register_opschema_rule(torch.ops.aten._unsafe_index.Tensor)
 def _unsafe_index_rule(mesh, op_schema):
     raise NotImplementedError()
diff --git a/tests/test_api.py b/tests/test_api.py
index 29e010f..29e9c54 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -117,6 +117,51 @@ def input_fn():
     )
 
 
+def test_non_tensor_input(device_mesh_1d):
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim
+            self.linear = nn.Linear(dim, dim)
+
+        def forward(self, x, input_dim: int):
+            return self.linear(x).chunk(2, dim=input_dim)
+
+        def init_weights(self):
+            dim = self.dim
+            self.linear.weight = torch.nn.Parameter(torch.ones(dim, dim) * 9.0)
+            with torch.no_grad():
+                self.linear.bias.fill_(98.6)
+
+    def input_fn():
+        b = 512
+        inputs = torch.rand(b, dim, device="cuda")
+        input_dim = 1
+        return (inputs, input_dim)
+
+    with torch.device("meta"):
+        model = Model(dim)
+    with AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+    ) as autop:
+        x_sharding = (Shard(0),)
+        autop.add_input_constraints([x_sharding, None])
+        sharding_placement = autop.optimize_placement()
+
+        parallel_mod = autop.apply_placement(sharding_placement)
+    parallel_mod.to_empty(device="cuda")
+    parallel_mod.init_weights()
+    placeholders = autop.gm.graph.find_nodes(op="placeholder")
+    non_tensor_input = placeholders[3]
+    assert sharding_placement[non_tensor_input].output_specs.placements == (
+        Replicate(),
+    )
+
+
 def test_fx_graph_annotate(device_mesh_1d):
     dim = 128