[Frontend] Add support for mutable arguments and make aggregates mutable #7286
I'm a bit concerned that the semantics are too far from Python here.
def _wrap_mutable_argument(arg, node):
    if isinstance(arg, base_value) and getattr(arg.type, "__triton_mutable__", False):
        setattr(arg, "__ast__", node)
Will this cause issues with stacked function calls, e.g. f(a) calls g(a)?
No, this works correctly. Each subsequent call attaches the expression for its argument, and these are recursively propagated up to the last caller.
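To make the stacked-call case concrete, here is a minimal plain-Python emulation of the copy-in/copy-out scheme. It assumes each callee hands back the final value of its mutable argument and each caller assigns that value to its own argument expression. Agg, f, and g are hypothetical names, and the explicit return and reassignment stand in for what the frontend would emit implicitly; this is not the PR's code.

```python
import copy


class Agg:
    """Hypothetical value-semantic aggregate (not a Triton type)."""

    def __init__(self, value):
        self.value = value


def g(a):
    a = copy.copy(a)  # the callee works on its own copy of the value
    a.value += 1
    return a          # stands in for the implicit return of `a`


def f(a):
    a = copy.copy(a)
    a = g(a)          # writeback into f's argument expression `a`
    return a          # f's implicit return carries g's update upward


x = Agg(0)
x = f(x)              # writeback at the outermost call site
result = x.value
```

Under these assumptions the update made in g reaches x because every level both returns and writes back.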
@@ -506,6 +520,15 @@ def visit_List(self, node):
        elts = language.tuple([self.visit(elt) for elt in node.elts])
        return elts

    def emit_return(self, handles):
        for i, (name, ty) in enumerate(self.prototype.args_mut):
            value = self.dereference_name(name)
This seems wrong. We should track the Python object, not the value the name is currently bound to, e.g.
def f(a):
    a = NewAggregate(...)
should not mutate the caller's a value.
I might be able to fix this.
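The distinction raised in this thread, rebinding a name versus mutating the object it refers to, can be shown in plain Python. The last function is a rough model of a name-based writeback, assuming the callee returns whatever the parameter name is bound to on exit. Box and the function names are hypothetical and are not taken from the PR.

```python
class Box:
    """Hypothetical stand-in for an aggregate; not a Triton type."""

    def __init__(self, value):
        self.value = value


def rebind(a):
    # Rebinding the local name leaves the caller's object untouched.
    a = Box(99)


def mutate(a):
    # Mutating the object itself is visible to the caller.
    a.value = 99


x = Box(1)
rebind(x)
after_rebind = x.value  # still 1 in plain Python
mutate(x)
after_mutate = x.value  # now 99


def rebind_with_name_writeback(a):
    a = Box(99)
    return a  # models returning the value the name is bound to on exit


y = Box(1)
y = rebind_with_name_writeback(y)  # models the caller-side writeback
after_name_writeback = y.value     # 99: the rebinding leaks to the caller
```

In this model the rebinding becomes visible to the caller, which is the divergence from Python that the review comment points out.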
for i, arg in enumerate(args):
    if isinstance(arg, (language.dtype, float, int, bool, JITFunction)):
        args[i] = language.core.constexpr(arg)
args_mut = [[name, None]
            for (name, arg) in zip(fn.arg_names, args)
            if arg is not None and getattr(arg.type, "__triton_mutable__", False)]
What if __triton_mutable__ isn't on the top-level type, e.g. tuple(aggregate)?
The tuple will take a copy of the aggregate.
I thought about banning putting mutable values inside non-mutable ones, but in fact putting any aggregate inside another will take a copy, because aggregates don't actually have reference semantics.
if not isinstance(target, ValueDest):
    target = copy.deepcopy(target)
    target.ctx = ast.Store()
self.assignTarget(target, value)
What if there are multiple references to the same value, e.g.
a = Aggregate(...)
b = a
f(b) # a and b should be updated
This will only update b. I might be able to fix this one.
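The aliasing case can be compared side by side in plain Python. The second half is an emulation that assumes aggregates are copied on assignment and that the writeback targets only the argument expression. Agg, f_python, and f_emulated are hypothetical names used for illustration.

```python
import copy


class Agg:
    """Hypothetical aggregate used only for this comparison."""

    def __init__(self, value):
        self.value = value


def f_python(obj):
    obj.value += 1


# Plain Python: both names refer to one object, so both see the update.
a = Agg(1)
b = a
f_python(b)
python_result = (a.value, b.value)


def f_emulated(obj):
    obj = copy.copy(obj)  # the callee works on a copy
    obj.value += 1
    return obj            # stands in for the implicit return


# Emulated value semantics: `b = a` copies, and only `b` is written back.
a2 = Agg(1)
b2 = copy.copy(a2)
b2 = f_emulated(b2)
emulated_result = (a2.value, b2.value)
```

Plain Python gives (2, 2), while the emulation gives (1, 2), matching the reply that only b is updated.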
Yes, they do differ a bit and that's something I was hoping to discuss. Let me see if there is some simple way of addressing some of your comments, but in general, implementing Python object semantics in Triton will be a significant amount of effort. Consider:

@triton.jit
def foo():
    x = Obj()
    return x, Box(x)

@triton.jit
def entry_point():
    a, b = foo()
    modify(b)
    use(a)

If we ignore refcounting and just let all objects be immortal, the only reasonable way to codegen this is

tt.func @foo() -> !tt.ptr<Obj>, !tt.ptr<Obj> {
  %x = thread_local_alloc // objects are effectively "scalars" -- they live in shared memory but each thread gets its own
  %b = thread_local_alloc
  store %x, gep %b[0]
  return %x, %b
}
tt.func @entry_point() {
  %a, %b = tt.call @foo()
  tt.call @modify(%b)
  tt.call @use(%a)
}

If the function calls aren't inlined, we have to code generate real reference semantics. Here's another example:

@triton.jit
def two_mutables(a, b):
    a.value = 3
    b.value *= 2

@triton.jit
def caller():
    x = Obj()
    two_mutables(x, x)

Without passing in a proper reference to the object, the mutation on the non-exclusive references
It might be possible to fix some of these semantics issues at the function level, because the backing
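The two_mutables(x, x) example can be worked through numerically. The sketch below assumes Obj starts with value 1 and that, in the emulated copy-in/copy-out scheme, writebacks are applied in argument order. Both are assumptions made for illustration; the thread does not specify either.

```python
import copy


class Obj:
    """Hypothetical object; the initial value of 1 is an assumption."""

    def __init__(self, value=1):
        self.value = value


def two_mutables_python(a, b):
    a.value = 3
    b.value *= 2


# Plain Python: a and b are the same object, so the result is 3 * 2.
x = Obj()
two_mutables_python(x, x)
python_result = x.value


def two_mutables_emulated(a, b):
    a, b = copy.copy(a), copy.copy(b)  # each parameter is an independent value
    a.value = 3
    b.value *= 2                       # sees the original 1, not 3
    return a, b                        # stands in for the implicit returns


y = Obj()
new_a, new_b = two_mutables_emulated(y, y)
y = new_a  # first writeback
y = new_b  # second writeback overwrites the first
emulated_result = y.value
```

Plain Python yields 6, whereas this emulation yields 2, so one of the two mutations is lost when the references are not exclusive.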
This PR adds support for mutable function arguments. Mutability is a type property, and mutable arguments are implemented in the callee by appending implicit returns of the last value of mutable arguments, and in the caller by emitting writeback into the argument expression using assignTarget. This requires distinguishing between LValues and RValues (i.e. self.a.b vs make_value(123)). Mutation of RValues is discarded since the frontend does not model RValue to LValue conversion. The writeback destination is modeled as the AST expression of the argument. This is stored on the value itself, so that mutability is mostly transparent to the rest of the frontend, and so that mutable arguments can interop with builtin functions.

Using mutable arguments, this PR makes all @aggregate types mutable. This allows cleaning up the attention implementation and removing the pretty big footgun of forgetting to return and re-assign variables on mutation. This also enables __init__ to be a @jit function, since it mutates self, further cleaning up the Gluon attention code.

The PR also makes @aggregate make all member functions @jit by default, and exposes it as ttgl.aggregate and tl.aggregate respectively, since it is now feature-complete.
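The callee-return and caller-writeback mechanism described above can be sketched with Python's standard ast module. This is an illustration under stated assumptions, not the Triton frontend: Counter, Holder, bump, writeback, and call_with_writeback are hypothetical names, and the isinstance check is a simplified stand-in for LValue detection.

```python
import ast
import copy


class Counter:
    """Hypothetical value type standing in for a mutable aggregate."""

    def __init__(self, n):
        self.n = n


class Holder:
    def __init__(self):
        self.inner = Counter(0)


def bump(c):
    # Callee side: operate on a private copy and hand back its last value,
    # standing in for the implicit return of a mutable argument.
    c = copy.copy(c)
    c.n += 1
    return c


def writeback(arg_src, new_value, env):
    """Assign new_value to the expression arg_src if it is an LValue."""
    node = ast.parse(arg_src, mode="eval").body
    if not isinstance(node, (ast.Name, ast.Attribute, ast.Subscript)):
        return False  # RValue (e.g. a call): the mutation is discarded
    target = copy.deepcopy(node)
    target.ctx = ast.Store()  # same expression, now an assignment target
    assign = ast.Assign(targets=[target],
                        value=ast.Name(id="_new_value", ctx=ast.Load()))
    module = ast.fix_missing_locations(
        ast.Module(body=[assign], type_ignores=[]))
    env["_new_value"] = new_value
    exec(compile(module, "<writeback>", "exec"), env)
    del env["_new_value"]
    return True


def call_with_writeback(fn, arg_src, env):
    # Caller side: evaluate the argument, call, then write the result back.
    result = fn(eval(arg_src, env))
    return writeback(arg_src, result, env)


env = {"holder": Holder(), "Counter": Counter}
lvalue_written = call_with_writeback(bump, "holder.inner", env)
rvalue_written = call_with_writeback(bump, "Counter(5)", env)
```

Here the call on holder.inner updates the holder, while the call on Counter(5) has nowhere to write its result, mirroring the statement that mutation of RValues is discarded.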