Encapsulate sharding and hashing logic for IR in XlaNode constructor #4555

Draft: steventk-g wants to merge 16 commits into master from hash-with-device-data

Conversation

@steventk-g (Collaborator) commented Feb 2, 2023

Building on #4286 and #4287, this PR moves all hash creation logic into XlaNode constructors and exposes the functionality via Clone APIs defined on the DeviceData and Generic node classes. This is required to appropriately set the IR node hashes for sharded data, since modifying IR nodes (and particularly their hashes) in place exposes us to bugs.

For testing, I've added some SPMD tests, including linear model training, to CI and verified that everything passes (#4645).
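At a high level, the change replaces in-place mutation of a node's sharding (and therefore its hash) with cloning: the hash is computed once in the XlaNode constructor, and attaching sharding produces a new node instead of editing the old one. A condensed sketch of that flow, stitched together from the diff hunks reviewed below (not the verbatim patch), looks roughly like this:

// Condensed sketch from the hunks below; the real patch streams a fuller
// error message on the check.
void XLATensor::CreateShardedIrValue(const ShardingSpecPtr sharding_spec) {
  torch::lazy::Value old_value = GetIrValue();
  XLA_CHECK(old_value && old_value.node != nullptr);
  // Clone a node whose constructor already folded the sharding into its hash,
  // then swap it in as this tensor's IR value instead of mutating in place.
  torch::lazy::Value new_ir = dynamic_cast<XlaNode*>(old_value.node.get())
                                  ->CloneWithSharding(sharding_spec->sharding);
  data()->ir_value = std::move(new_ir);
}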

@steventk-g changed the title from "Add move hashing to constructor" to "Move hashing to constructor" on Feb 2, 2023
@yeounoh self-requested a review on February 2, 2023 23:21
void XLATensor::ClearShardingSpec() {
  data()->sharding = nullptr;
  torch::lazy::Value ir_value = CurrentIrValue();
  if (ir_value) {
    // This should be a no-op if there is no sharding.
    dynamic_cast<XlaNode*>(ir_value.node.get())->ClearSharding();
    CreateUnshardedIrValue();
Collaborator

ditto

Collaborator Author

Done

Contributor

It's a little verbose, but I actually preferred dynamic_cast<XlaNode*>(ir_value.node.get())->XYZ to show that the tensor method ClearShardingSpec calls the underlying node's method.

  XLA_CHECK(old_value.node->op() == xla_device_data)
      << "Can only clear sharding for device data";
  torch::lazy::Value new_ir =
      dynamic_cast<DeviceData*>(old_value.node.get())->Clone();
Collaborator

Seems like we need to implement a method in each IR node to return an unsharded IR, hmm...

Collaborator Author

We only need to add it for the node types which will be cloned at runtime, but I'm not sure if that would be everything. We can always leave it with the base implementation (throwing a not implemented error) otherwise.
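For illustration, the base-class fallback described here could look something like the sketch below. This is hypothetical code, not taken from this PR; it assumes the repo's existing XLA_ERROR macro and the CloneWithSharding signature added for DeviceData and Generic.

// Hypothetical base-class default: node types that are never cloned at
// runtime keep this implementation and fail loudly if it is ever reached.
torch::lazy::NodePtr XlaNode::CloneWithSharding(xla::OpSharding sharding) const {
  XLA_ERROR() << "CloneWithSharding is not implemented for node "
              << op().ToString();
  return nullptr;  // Not reached; XLA_ERROR reports the failure.
}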

@JackCaoG (Collaborator) left a comment

Mostly LGTM, thanks!

@steventk-g force-pushed the hash-with-device-data branch 4 times, most recently from e75745c to 1f36449 on February 8, 2023 02:18
@steventk-g force-pushed the hash-with-device-data branch 9 times, most recently from 7166064 to 6223063 on February 16, 2023 21:56
@steventk-g changed the title from "Move hashing to constructor" to "Encapsulate sharding and hashing logic for IR in XlaNode constructor" on Feb 16, 2023
@steventk-g force-pushed the hash-with-device-data branch from 6223063 to 531ca82 on February 17, 2023 05:10
@steventk-g requested a review from JackCaoG on February 19, 2023 21:42
@steventk-g force-pushed the hash-with-device-data branch from 531ca82 to e66f323 on February 19, 2023 22:18
@yeounoh (Contributor) commented Feb 22, 2023

Let's mark the PR (DO NOT MERGE YET) and/or add a comment about the ongoing investigation.

@@ -48,15 +48,30 @@ XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::OpList operands,
     : torch::lazy::Node(op, operands, std::move(shapes), num_outputs),
       xla_shape_(std::move(xla_shape)),
       node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
-      dag_hash_(GetOperandHashes(operands, node_hash_)) {}
+      dag_hash_(GetOperandHashes(operands, node_hash_)),
+      oplist_(std::optional<torch::lazy::OpList>(operands)) {}
Contributor

Is it possible to construct the OpList from operands_ in lazy::Node? We are passing operands as input to the parent constructor and also to oplist_. Hopefully that would free us from adding this additional constructor argument and possibly be safer, as we wouldn't have to maintain two copies of operands_ in two different formats. operands_ is protected, so it should be accessible.

Collaborator Author

I followed your suggestion here to move the logic into XlaNode. Should I go back to using lazy::Node like I was originally? https://github.com/pytorch/pytorch/pull/92579/files#r1105075942

@yeounoh (Contributor), Feb 22, 2023

No circular dependency here. Instead of doing it in lazy::Node, can we just take operands_ from the parent (lazy::Node) and construct OpList in XlaNode?

Also, reiterating @alanwaketan's suggestion, changing the parent's operands_ type to std::vector<Value> operands_ could make it simpler as well. But I am not comfortable dropping the NodePtr here... would keeping a copy of Value hold ownership of the original node? I see a move constructor there, so I guess that would be OK? cc @alanwaketan @JackCaoG

@yeounoh (Contributor), Feb 22, 2023

Synced offline with @alanwaketan: node(node) in the torch::lazy::Value ctor holds the reference to the Node, thanks.
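(For reference on the ownership point: torch::lazy::Value stores its node as a NodePtr, which is a shared_ptr, so holding a Value shares ownership of the node. A minimal illustration, assuming the upstream torch/csrc/lazy/core/ir.h definitions:)

#include <torch/csrc/lazy/core/ir.h>

// torch::lazy::Value keeps a NodePtr member (a std::shared_ptr to the node),
// so any Value constructed from the pointer keeps the node alive.
torch::lazy::Value KeepAlive(torch::lazy::NodePtr node) {
  return torch::lazy::Value(std::move(node), /*index=*/0);
}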

@@ -101,11 +121,14 @@ XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::Shape shape,
     : torch::lazy::Node(op, shape, num_outputs),
       xla_shape_(std::move(xla_shape)),
       node_hash_(GetOpHash(op, xla_shape_, hash_seed)),
-      dag_hash_(node_hash_) {}
+      dag_hash_(node_hash_),
+      oplist_(std::nullopt) {}
Contributor

Along the same lines as my previous comment, we are introducing a gap here between operands_, operands_as_outputs_, and oplist_ (empty). And we probably want to avoid this additional null initialization if possible.

Collaborator Author

I was planning to close the upstream PR and localize these changes in XlaNode. Do you think we should move it back to lazy::Node like what I had before?

Collaborator Author

As for the "null initialization", it's a nullopt for an optional, so it should be safe.


  // Experimental sharding annotation attached to the IR node.
  // TODO(yeounoh): make sure that view update doesn't reset this.
  std::shared_ptr<xla::OpSharding> output_sharding_ = nullptr;

  std::optional<torch::lazy::OpList> oplist_;
Contributor

I guess a major issue we need to work around would be keeping the cross-references between ops up to date, given that the graph is a tree. It might be doable with additional bookkeeping and traversals, but I'm not sure how expensive that would be.

Collaborator Author

I think the best way to do it will be to store ops as smart pointers, either shared or weak, so that we can access them within an oplist by index while also being able to replace a node with a new node.
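A rough sketch of that idea (hypothetical types and names, not code from this PR): hold each operand behind a shared pointer so consumers can index into the list and a cloned replacement can be swapped in without leaving dangling references.

#include <cstddef>
#include <memory>
#include <vector>

#include <torch/csrc/lazy/core/ir.h>

// Hypothetical bookkeeping structure: operands stored as shared_ptrs to
// torch::lazy::Value, so replacing the value at an index is visible to every
// holder of that slot.
struct SharedOpList {
  std::vector<std::shared_ptr<torch::lazy::Value>> ops;

  // Swap in a cloned value at `index`; callers holding ops[index] observe the
  // new node on their next dereference.
  void Replace(std::size_t index, torch::lazy::Value new_value) {
    *ops.at(index) = std::move(new_value);
  }
};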

DeviceData::DeviceData(std::shared_ptr<torch::lazy::BackendData> data,
                       torch::lazy::OpList ops, xla::Shape xla_shape,
                       xla::OpSharding sharding)
    : XlaNode(xla_device_data, ops, {data->shape()}, xla_shape, sharding,
Contributor

Do we need xla_shape as an input? Could we do it like in the other constructor, with UnwrapXlaData(data)->shape()?

Collaborator Author

I can experiment with it. Any reason to exclude xla_shape?

Contributor

Generally, if it's not needed or redundant, then I would not add more arguments to the constructor.

std::string DeviceData::ToString() const {
  std::stringstream ss;
  ss << XlaNode::ToString() << ", device=" << data_->device();
  return ss.str();
}

torch::lazy::NodePtr DeviceData::Clone() const { return Clone({}); }
Contributor

Is this calling the following?

// TODO(https://github.com/pytorch/xla/issues/4567) Remove this clone method
  virtual torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const;

Let's add a comment about why we are passing an empty oplist argument here.

Will this be removed/updated when we address #4567?

Contributor

I see that you have a similar comment below:

// TODO(steventk) Ideally, we are either passing the actual oplist, or we
  // don't pass ops at all and use another XlaNode constructor

Collaborator Author

I'll add the comment where we make the call, too.

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

torch::lazy::NodePtr CloneWithSharding(
Contributor

What happens with the other node types? We can attach sharding to any of them.

Collaborator Author

Unless they have this method implemented, it will crash at runtime. So far, only Generic and DeviceData are affected. If we need to support sharding for any node, we will need to add these functions to each node class.

Contributor

I see... I think we have data & linear model sharding test cases. Could you try with a conv layer as well, to see if that would also require touching other nodes? I hope we can cover most of the common cases with Generic & DeviceData.

    xla::OpSharding sharding) const {
  // TODO(steventk) Right now, we drop the operands on clone, because the oplist
  // memory becomes unsafe when we clone the other nodes and they go out of
  // scope. Instead of initializing the oplist below to nullopt, we want to use
Contributor

Thinking out loud, constructing the oplist from the parent's operands_, if possible, could simplify things. As long as the operands are alive, they shouldn't go out of scope. For the recreated node, we should propagate that change (old->new), though.

@steventk-g (Collaborator Author), Feb 22, 2023

constructing the oplist from the parent's operands_

That makes sense to me, but I'm not sure how to get a reference to the parent.

@@ -284,6 +284,26 @@ void XLATensor::SetXlaData(torch::lazy::BackendDataPtr handle, bool sync) {
   }
 }

+void XLATensor::CreateShardedIrValue(const ShardingSpecPtr sharding_spec) {
+  torch::lazy::Value old_value = GetIrValue();
+  XLA_CHECK(old_value && old_value.node != nullptr)
Contributor

You can do

if (old_value)

instead.

@@ -284,6 +284,26 @@ void XLATensor::SetXlaData(torch::lazy::BackendDataPtr handle, bool sync) {
   }
 }

+void XLATensor::CreateShardedIrValue(const ShardingSpecPtr sharding_spec) {
+  torch::lazy::Value old_value = GetIrValue();
Contributor

Use CurrentIrValue() ?

Collaborator Author

I was getting crashes with CurrentIrValue(); it looks like ir_value is not always defined when we call this method.

"exist";
torch::lazy::Value new_ir = dynamic_cast<XlaNode*>(old_value.node.get())
->CloneWithSharding(sharding_spec->sharding);
data()->ir_value = std::move(new_ir);
Contributor

Some nodes' operands could still be pointing to the old IR value.

Collaborator Author

Yeah, I think this is the root cause of the problem I was seeing before. If I can keep the operands up to date, the clone should work smoothly.

@@ -512,8 +512,7 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
       // sharding takes the precedence as the source of the truth.
       XLATensor::ShardingSpecPtr sharding = tensors[i]->sharding_spec();
       if (sharding) {
-        dynamic_cast<XlaNode*>(ir_value.node.get())
-            ->SetSharding(sharding->sharding);
+        tensors[i]->CreateShardedIrValue(sharding);

@alanwaketan (Collaborator)

@steventk-g Can you remind me what problems we are trying to solve? And is there a test case to reproduce the problem?

@steventk-g (Collaborator Author)

> @steventk-g Can you remind me what problems we are trying to solve? And is there a test case to reproduce the problem?

@alanwaketan The original runtime bug is solved and tested in #4287; this PR is intended to clean up the IR hash logic so that the hash for a specific node isn't mutable.

@alanwaketan (Collaborator)

> @steventk-g Can you remind me what problems we are trying to solve? And is there a test case to reproduce the problem?
>
> @alanwaketan The original runtime bug is solved and tested in #4287; this PR is intended to clean up the IR hash logic so that the hash for a specific node isn't mutable.

You mentioned some bugs are exposed if we are not taking the new approach. That's what I'm curious about. Thanks for the briefing.

@ysiraichi added the DO_NOT_MERGE (Not for merging.) label and removed the DO_NOT_MERGE_YET label on Mar 5, 2025