Encapsulate sharding and hashing logic for IR in XlaNode constructor #4555
Conversation
void XLATensor::ClearShardingSpec() {
  data()->sharding = nullptr;
  torch::lazy::Value ir_value = CurrentIrValue();
  if (ir_value) {
    // This should be a no-op if there is no sharding.
    dynamic_cast<XlaNode*>(ir_value.node.get())->ClearSharding();
    CreateUnshardedIrValue();
ditto
Done
it's a little verbose, but I actually preferred dynamic_cast<XlaNode*>(ir_value.node.get())->XYZ to show that the tensor method ClearShardingSpec calls the underlying node's method.
torch_xla/csrc/tensor.cpp
XLA_CHECK(old_value.node->op() == xla_device_data)
    << "Can only clear sharding for device data";
torch::lazy::Value new_ir =
    dynamic_cast<DeviceData*>(old_value.node.get())->Clone();
Seems like we need to implement a method in each IR Node to return an unsharded IR, hmm...
We only need to add it for the node types which will be cloned at runtime, but I'm not sure if that would be everything. We can always leave it with the base implementation (throwing a not implemented error) otherwise.
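A minimal sketch of that fallback, assuming a virtual hook on XlaNode (the method name here is illustrative, not part of the actual API):
// Hypothetical base-class hook: only node types that actually get cloned at
// runtime (e.g. DeviceData, Generic) override this; everything else fails
// loudly instead of silently keeping a stale sharded IR.
virtual torch::lazy::NodePtr CloneUnsharded() const {
  XLA_CHECK(false) << "CloneUnsharded() is not implemented for "
                   << op().ToString();
  return nullptr;  // Unreachable; keeps the compiler happy.
}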
Mostly LGTM, thanks!
Let's mark the PR (DO NOT MERGE YET) and/or add a comment about the ongoing investigation.
@@ -48,15 +48,30 @@ XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::OpList operands,
     : torch::lazy::Node(op, operands, std::move(shapes), num_outputs),
       xla_shape_(std::move(xla_shape)),
       node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
-      dag_hash_(GetOperandHashes(operands, node_hash_)) {}
+      dag_hash_(GetOperandHashes(operands, node_hash_)),
+      oplist_(std::optional<torch::lazy::OpList>(operands)) {}
Is it possible to construct OpList from operands_ in lazy::Node? We are passing operands as input to the parent constructor and also to oplist_. Hopefully, that will free us from adding this additional constructor argument and possibly be safer, as we don't have to maintain two copies of operands_ in two different formats. operands_ is protected, so it should be accessible.
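A rough sketch of that direction, assuming the parent members keep their upstream layout (operands_ as std::vector<NodePtr>, operands_as_outputs_ as std::vector<Output>); the helper name is made up:
// Hypothetical helper on XlaNode: rebuild operand Values on demand from the
// protected lazy::Node members instead of caching a second copy in oplist_.
std::vector<torch::lazy::Value> XlaNode::RebuildOperandValues() const {
  std::vector<torch::lazy::Value> values;
  values.reserve(operands_.size());
  for (size_t i = 0; i < operands_.size(); ++i) {
    // operands_[i] owns the node; operands_as_outputs_[i] records which output
    // of that node the original Value referred to.
    values.emplace_back(operands_[i], operands_as_outputs_[i].index);
  }
  // A torch::lazy::OpList (c10::ArrayRef<Value>) can be formed over `values`
  // for as long as it stays alive.
  return values;
}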
I followed your suggestion here to move the logic into XlaNode. Should I go back to using lazy::Node like I was originally? https://github.com/pytorch/pytorch/pull/92579/files#r1105075942
No circular dependency here. Instead of doing it in lazy::Node, can we just take operands_ from the parent (lazy::Node) and construct OpList in XlaNode?
Also, re-iterating @alanwaketan's suggestion, changing the parent's operands_ type to std::vector<Value> operands_ could make it simpler as well. But I am not comfortable dropping the NodePtr here... would keeping a copy of Value hold ownership of the original node? I see a move constructor there, so I guess that would be OK...? cc @alanwaketan @JackCaoG
Synced offline with @alanwaketan: node(node) from the torch::lazy::Value ctor would hold the reference to Node, thanks.
@@ -101,11 +121,14 @@ XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::Shape shape,
     : torch::lazy::Node(op, shape, num_outputs),
       xla_shape_(std::move(xla_shape)),
       node_hash_(GetOpHash(op, xla_shape_, hash_seed)),
-      dag_hash_(node_hash_) {}
+      dag_hash_(node_hash_),
+      oplist_(std::nullopt) {}
Along the same lines as my previous comment, we are introducing a gap here between operands_, operands_as_outputs_, and oplist_ (empty). And we probably want to avoid this additional null initialization if possible.
I was planning to close the upstream PR and localize these changes in XlaNode. Do you think we should move it back to lazy::Node like what I had before?
As for the "null initialization", it's a nullopt for optionals, so it should be safe.
  // Experimental sharding annotation attached to the IR node.
  // TODO(yeounoh): make sure that view update doesn't reset this.
  std::shared_ptr<xla::OpSharding> output_sharding_ = nullptr;

  std::optional<torch::lazy::OpList> oplist_;
I guess a major issue we need to work around would be keeping the cross references between ops up-to-date, given that the graph is a tree. It might be doable with additional bookkeeping and traversals, but not sure how expensive that would be.
I think the best way to do it would be to store ops as smart pointers, either shared or weak, so that we can access them within an oplist by index while also being able to replace a node with a new node.
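A sketch of that bookkeeping idea; the member and method names are assumptions, not the current XlaNode layout:
// Hypothetical owned operand storage: Values wrap shared NodePtrs, so holding
// them by index lets us hand out an OpList view and also swap an operand for
// its re-created (e.g. sharded) replacement in place.
std::vector<torch::lazy::Value> owned_operands_;

void ReplaceOperand(size_t i, torch::lazy::Value new_value) {
  XLA_CHECK(i < owned_operands_.size());
  // The old node is released once nothing else references it.
  owned_operands_[i] = std::move(new_value);
}

torch::lazy::OpList operand_list() const { return owned_operands_; }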
DeviceData::DeviceData(std::shared_ptr<torch::lazy::BackendData> data,
                       torch::lazy::OpList ops, xla::Shape xla_shape,
                       xla::OpSharding sharding)
    : XlaNode(xla_device_data, ops, {data->shape()}, xla_shape, sharding,
Do we need xla_shape as input? Could we do this like in the other constructor, UnwrapXlaData(data)->shape()?
I can experiment with it. Any reason to exclude xla_shape?
Generally, if it's not needed or redundant, then I would not add more arguments to the constructor.
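For instance, the sharded DeviceData constructor could derive the shape the same way the other constructor does; a rough sketch (the trailing XlaNode arguments are elided here, as in the quoted diff, so this is not a complete signature):
// Sketch: derive the XLA shape from the backend data instead of taking it as
// a constructor argument, mirroring the existing unsharded DeviceData ctor.
DeviceData::DeviceData(std::shared_ptr<torch::lazy::BackendData> data,
                       torch::lazy::OpList ops, xla::OpSharding sharding)
    : XlaNode(xla_device_data, ops, {data->shape()},
              UnwrapXlaData(data)->shape(), sharding
              /*, remaining arguments elided as in the diff above */),
      data_(std::move(data)) {}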
std::string DeviceData::ToString() const {
  std::stringstream ss;
  ss << XlaNode::ToString() << ", device=" << data_->device();
  return ss.str();
}

torch::lazy::NodePtr DeviceData::Clone() const { return Clone({}); }
Is this calling the following?
// TODO(https://github.com/pytorch/xla/issues/4567) Remove this clone method
virtual torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const;
Let's add some comments on why we are passing an empty oplist argument here. Will this be removed/updated when we address #4567?
I see that you have a similar comment below,
// TODO(steventk) Ideally, we are either passing the actual oplist, or we
// don't pass ops at all and use another XlaNode constructor
I'll add the comment to where we make the call, too.
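Something along these lines at the call site (a sketch of the comment, mirroring the TODO quoted above):
// TODO(https://github.com/pytorch/xla/issues/4567) Remove this clone method.
// We pass an empty OpList here; ideally we would either pass the actual
// oplist or use an XlaNode constructor that takes no ops.
torch::lazy::NodePtr DeviceData::Clone() const { return Clone({}); }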
  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

  torch::lazy::NodePtr CloneWithSharding(
What happens to the other node types? We can attach sharding to any.
Unless they have this method implemented, it will crash at runtime. So far, only generic and device data are affected. If we need to support sharding for any node, we will need to add these functions to each node class.
I see... I think we have data & linear model sharding test cases. Could you try with a conv layer as well, to see if that would also require touching other nodes? I hope we can cover most of the common cases with generic & data.
    xla::OpSharding sharding) const {
  // TODO(steventk) Right now, we drop the operands on clone, because the oplist
  // memory becomes unsafe when we clone the other nodes and they go out of
  // scope. Instead of initializing the oplist below to nullopt, we want to use
Thinking out loud, constructing the oplist from the parent's operands_, if possible, could simplify things. As long as the operands are alive, they shouldn't go out of scope. For the recreated node, we should propagate that change (old->new), though.
"constructing the oplist from the parent's operands_"
That makes sense to me, but I'm not sure how to get a reference to the parent.
@@ -284,6 +284,26 @@ void XLATensor::SetXlaData(torch::lazy::BackendDataPtr handle, bool sync) {
   }
 }
 
+void XLATensor::CreateShardedIrValue(const ShardingSpecPtr sharding_spec) {
+  torch::lazy::Value old_value = GetIrValue();
+  XLA_CHECK(old_value && old_value.node != nullptr)
You can do if (old_value) instead.
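i.e., something like the following, assuming torch::lazy::Value's bool conversion already checks the node pointer:
torch::lazy::Value old_value = GetIrValue();
// Value's operator bool covers the node != nullptr case, so the explicit
// comparison is redundant.
if (old_value) {
  // ... clone with sharding and swap in the new IR value ...
}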
@@ -284,6 +284,26 @@ void XLATensor::SetXlaData(torch::lazy::BackendDataPtr handle, bool sync) {
   }
 }
 
+void XLATensor::CreateShardedIrValue(const ShardingSpecPtr sharding_spec) {
+  torch::lazy::Value old_value = GetIrValue();
Use CurrentIrValue()?
I was getting crashes with CurrentIrValue(); it looks like ir_value is not always defined when we call this method.
      "exist";
  torch::lazy::Value new_ir = dynamic_cast<XlaNode*>(old_value.node.get())
                                  ->CloneWithSharding(sharding_spec->sharding);
  data()->ir_value = std::move(new_ir);
Some nodes' operands could be pointing to the old_ir.
Yeah I think this is the root cause of the problem I was seeing before. If I can keep the operands up to date, the clone should work smoothly.
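One way to picture that propagation step, building on the smart-pointer sketch above (names like PropagateClone and ReplaceOperand are hypothetical; how consumers are discovered is exactly the open question in this thread):
// Hypothetical rewiring pass: any consumer whose operand still points at
// old_node gets switched over to new_node after the clone.
void PropagateClone(const torch::lazy::NodePtr& old_node,
                    const torch::lazy::NodePtr& new_node,
                    const std::vector<XlaNode*>& consumers) {
  for (XlaNode* consumer : consumers) {
    const auto& outputs = consumer->operands();
    for (size_t i = 0; i < outputs.size(); ++i) {
      if (outputs[i].node == old_node.get()) {
        consumer->ReplaceOperand(
            i, torch::lazy::Value(new_node, outputs[i].index));
      }
    }
  }
}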
@@ -512,8 +512,7 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
       // sharding takes the precedence as the source of the truth.
       XLATensor::ShardingSpecPtr sharding = tensors[i]->sharding_spec();
       if (sharding) {
-        dynamic_cast<XlaNode*>(ir_value.node.get())
-            ->SetSharding(sharding->sharding);
+        tensors[i]->CreateShardedIrValue(sharding);
@steventk-g Can you remind me what problems we are trying to solve? And is there a test case to reproduce the problem?
@alanwaketan The original runtime bug is solved and tested in #4287; this PR is intended to clean up the IR hash logic so that the hash for a specific node isn't mutable.
You mentioned some bugs are exposed if we are not taking the new approach. That's what I'm curious about. Thanks for the briefing.
Building on #4286 and #4287, this PR moves all hash creation logic into XlaNode constructors, and exposes the functionality via Clone APIs defined on the DeviceData and Generic node classes. This is required to appropriately set the IR node hashes for sharded data, as we are exposed to some bugs if we modify IR nodes (and particularly hashes) in-place.
For testing, I've added some SPMD tests, including linear model training, to CI and verified that everything passes in #4645.
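As a rough illustration of the intended shape of the API (signatures and helper names below are approximations, not the exact diff):
// Sketch: the sharding contribution is folded into the hash at construction
// time, so a node's hash never changes after it is built. Re-sharding a
// tensor then means building a new node via Clone / CloneWithSharding rather
// than mutating the existing one.
XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::OpList operands,
                 std::vector<torch::lazy::Shape>&& shapes, xla::Shape xla_shape,
                 xla::OpSharding sharding, size_t num_outputs,
                 torch::lazy::hash_t hash_seed)
    // Delegate to the unsharded constructor with a sharding-aware seed;
    // ShardingHash() stands in for whatever hashes the OpSharding proto.
    : XlaNode(op, operands, std::move(shapes), std::move(xla_shape),
              num_outputs,
              torch::lazy::HashCombine(hash_seed, ShardingHash(sharding))) {
  output_sharding_ = std::make_shared<xla::OpSharding>(std::move(sharding));
}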