Skip to content

Allow customization of the subscript operator for triton values #7239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

Anstow
Copy link
Contributor

@Anstow Anstow commented Jun 19, 2025

This allows users to provide custom setitem and getitem functions in order to override the subscript operator for their triton classes.

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@Anstow Anstow requested a review from ptillet as a code owner June 19, 2025 14:15
@Anstow
Copy link
Contributor Author

Anstow commented Jun 23, 2025

@Mogball are you the right person to look at this?

Copy link
Collaborator

@Mogball Mogball left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not obvious to me why this is necessary. Doesn't lhs.__getattr__ already invoke the right method?

@Anstow
Copy link
Contributor Author

Anstow commented Jun 24, 2025

It's not obvious to me why this is necessary. Doesn't lhs.__getattr__ already invoke the right method?

Currently, it invokes the right method but it doesn't let you customize it with the _semantic and _generator arguments or make it a bound jit function.

@Mogball
Copy link
Collaborator

Mogball commented Jun 24, 2025

It's not obvious to me why this is necessary. Doesn't lhs.__getattr__ already invoke the right method?

Currently, it invokes the right method but it doesn't let you customize it with the _semantic and _generator arguments or make it a bound jit function.

Makes sense. Does this mean that other binop methods will need to go through the same call_Function path?

@Mogball
Copy link
Collaborator

Mogball commented Jun 24, 2025

Thanks for the frontend contributions by the way. They're much appreciated

@Anstow
Copy link
Contributor Author

Anstow commented Jun 25, 2025

It's not obvious to me why this is necessary. Doesn't lhs.__getattr__ already invoke the right method?

Currently, it invokes the right method but it doesn't let you customize it with the _semantic and _generator arguments or make it a bound jit function.

Makes sense. Does this mean that other binop methods will need to go through the same call_Function path?

That's what I was thinking, yes.

return lhs[slices]

def visit_Subscript_Store(self, node, value):
assert isinstance(node.ctx, ast.Store)
lhs = self.visit(node.value)
slices = self.visit(node.slice)
assert isinstance(lhs, language.tuple)
lhs.__setitem__(slices, value)
fn = BoundJITMethod(lhs, lhs.__setitem__) if isinstance(lhs.__setitem__, JITFunction) else lhs.__setitem__
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little frustrated by this. I would have thought that __setitem__ would be a bound method containing a JITFunction however this doesn't appear to be the case (it's just a JITFunction) so I need to convert it into a BoundJITMethod in order to call it. I'm not sure if this is an issue cause by the _argegate wrapper on this is just an issue with python. Do you have any ideas?

Copy link
Collaborator

@Mogball Mogball Jun 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method unfortunately can't be bound to the object until Triton parses the code, because it needs to be bound to the MLIR handle. I had changed visit_Attribute to emit obj.method where method is a JITFunction to bind obj into it, which is different than Python which binds upon object construction.

The methods have to be bound during visit_Attribute because Triton values are not reference semantic. This does mean that something like

f = obj.method
obj.x = 1234
f()

Would not work as one would expect, if it was even valid code to begin with.

You can probably wrap this in a call_Method helper since it could end up being quite common.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense thanks for explaining. I wonder if we can get closer to python by making @jit return a function rather than an object. I'll play around and make a follow up draft PR if that idea goes anywhere. For now I'll make a call_Method helper.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that idea does work. I'll make a follow up draft PR once this lands.

Comment on lines +1257 to 1262
extra_kwargs = dict()
sig = inspect.signature(fn)
if '_semantic' in sig.parameters:
extra_kwargs["_semantic"] = self.semantic
if '_generator' in sig.parameters:
extra_kwargs['_generator'] = self
Copy link
Contributor Author

@Anstow Anstow Jun 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that I've made both _semantic and _generator optional. Do we need the if statement at all? I guess the only differences are:

  • how the results are wrapped (although this probably should be made consistent)
  • whether the error is wrapped in CompilationError (again this should probably be made consistent)
  • Theoretically the user might have used _semantic and _generator as arguments without intending them to be overwritten by us which could lead to weird errors. Although if we were really concerned about this I think the fix would be to mangle the names even more.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants