Skip to content

[Frontend] Add support for mutable arguments and make aggregates mutable #7286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 39 commits into
base: main
Choose a base branch
from

Conversation

Mogball
Copy link
Collaborator

@Mogball Mogball commented Jun 24, 2025

This PR adds support for mutable function arguments. Mutability is a type property, and mutable arguments are implemented in the callee by appending implicit returns of the last value of mutable arguments, and in the caller by emitting writeback into the argument expression using assignTarget. This requires distinguishing between LValues and RValues (i.e. self.a.b vs make_value(123)). Mutation of RValues is discarded since the frontend does not model RValue to LValue conversion. The writeback destination is modeled as the AST expression of the argument. This is stored on the value itself, so that mutability is mostly transparent to the rest of the frontend, and so that mutable arguments can interop with builtin functions.

Using mutable arguments, this PR makes all @aggregate types mutable. This allows cleaning up the attention implementation and removing the pretty big footgun of forgetting to return and re-assign variables on mutation. This also enables __init__ to be a @jit function, since it mutates self, further cleaning up the Gluon attention code.

The PR also makes @aggregate make all member functions @jit by default, and exposes it as ttgl.aggregate and tl.aggregate respectively, since it is now feature-complete.

@Mogball Mogball changed the title [Frontend] Add support for mutable arguments [WIP][Frontend] Add support for mutable arguments Jun 24, 2025
@Mogball Mogball changed the title [WIP][Frontend] Add support for mutable arguments [Frontend] Add support for mutable arguments and make aggregates mutable Jun 24, 2025
@Mogball Mogball requested a review from peterbell10 June 24, 2025 22:18
@Mogball Mogball marked this pull request as ready for review June 24, 2025 22:18
@Mogball Mogball requested a review from ptillet as a code owner June 24, 2025 22:18
@Mogball Mogball requested a review from ThomasRaoux June 24, 2025 22:19
Copy link
Contributor

@peterbell10 peterbell10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bit concerned that the semantics are too far from python here.


def _wrap_mutable_argument(arg, node):
if isinstance(arg, base_value) and getattr(arg.type, "__triton_mutable__", False):
setattr(arg, "__ast__", node)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this cause issues with stacked function calls, e.g. f(a) calls g(a).

Copy link
Collaborator Author

@Mogball Mogball Jun 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this works correctly. Each subsequent call attaches the expression for its argument, and these are recursively propagated up to the last caller.

@@ -506,6 +520,15 @@ def visit_List(self, node):
elts = language.tuple([self.visit(elt) for elt in node.elts])
return elts

def emit_return(self, handles):
for i, (name, ty) in enumerate(self.prototype.args_mut):
value = self.dereference_name(name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems wrong. We should track the python object, not the value the name is currently bound to. e.g.

def f(a):
    a = NewAggregate(...)

should not mutate the caller's a value.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be able to fix this.

for i, arg in enumerate(args):
if isinstance(arg, (language.dtype, float, int, bool, JITFunction)):
args[i] = language.core.constexpr(arg)
args_mut = [[name, None]
for (name, arg) in zip(fn.arg_names, args)
if arg is not None and getattr(arg.type, "__triton_mutable__", False)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if __triton_mutable__ isn't on the top level type, e.g. tuple(aggregate)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tuple will take a copy of the aggregate.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about banning putting mutable values inside non-mutable ones, but in fact putting any aggregate inside another will take a copy because they aren't actually reference semantic.

if not isinstance(target, ValueDest):
target = copy.deepcopy(target)
target.ctx = ast.Store()
self.assignTarget(target, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if there are multiple references to the same value, e.g.

a = Aggregate(...)
b = a
f(b) # a and b should be updated

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will only update b. I might be able to fix this one.

@Mogball
Copy link
Collaborator Author

Mogball commented Jun 27, 2025

Bit concerned that the semantics are too far from python here.

Yes, they do differ a bit and that's something I was hoping to discuss. Let me see if there is some simple way of addressing some of your comments, but in general, implementing Python object semantics in Triton will be a significant amount of effort.

Consider:

@triton.jit
def foo():
    x = Obj()
    return x, Box(x)

@triton.jit
def entry_point():
    a, b = foo()
    modify(b)
    use(a)

If we ignore refcounting and just let all objects be immortal, the only reasonable way to codegen this is

tt.func @foo() -> !tt.ptr<Obj>, !tt.ptr<Obj> {
  %x = thread_local_alloc // objects are effectively "scalars" -- they live in shared memory but each thread gets its own
  %b = thread_local_alloc
  store %x, gep %b[0]
  return %x, %b
}

tt.func @entry_point() {
  %a, %b = tt.call @foo()
  tt.call @modify(%b)
  tt.call @use(%a)
}

If the function calls aren't inlined, we have to code generate real reference semantics. Here's another example

@triton.jit
def two_mutables(a, b):
    a.value = 3
    b.value *= 2

@triton.jit
def caller():
    x = Obj()
    two_mutables(x, x)

Without passing in a proper reference to the object, the mutation on the non-exclusive references a and b cannot be properly modeled.

@Mogball
Copy link
Collaborator Author

Mogball commented Jun 27, 2025

It might be possible to fix some of these semantics issues at the function level, because the backing base_value Python objects are still there and can act as a proxy for object references. But when values are passed across function boundaries, they are flattened and unflattened on the other side, making a copy of the backing Python object.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants