Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update structured-to-memref pass to support the new pass pipeline #217

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

nhat-nguyen
Copy link
Collaborator

@nhat-nguyen nhat-nguyen commented Jan 7, 2025

This PR simplifies the structured-to-memref pass responsible for converting structured triton load / store ops to memref load / store ops. This is part of the work to allow triton-shared to lower gather / scatter pointer sequences. Previously, this pass is also responsible for converting scalar pointer load and store into memref; that transformation has now been moved to unstructured-to-memref.

In addition, the PR also updates the triton-to-linalg-experimental pass to fully utilize all the new passes. Once merged, triton-shared now fully supports gather / scatter. An example test (test_gather_scatter.py) is also added to demonstrate this new capability.


Intended lowering pipeline:

  • triton-to-structured (no changes):
    • analyzes structured addptr sequences
      • introduces tts.make_tptr %ptr_arg with offsets and strides
      • introduces tts.load and tts.store
    • leaves unstructured addptr sequences and their corresponding tt.load and tt.store intact
  • fold-unstructured-ptr (Introduce fold-unstructured-ptr pass #210):
    • converts all unstructured addptr sequences into sequences that compute pointer offsets
      • introduces tts.make_unstructured_tptr %ptr_arg %offsets
    • removes all tt.addptr
  • structured-to-memref (this PR):
    • currently converts everything to memref including scalar addptr and kernel arguments
    • will change to just convert ops in the tts dialect to memref with the exception of tts.make_unstructured_tptr
  • unstructured-to-memref (Introduce unstructured-to-memref pass #216 ):
    • converts the remaining unstructured tt.load, tt.store, and tts.make_unstructured_tptr into memref
  • triton-ptr-to-memref (Introduce triton-ptr-to-memref pass #211):
    • converts kernel arguments with pointer type to memref

@nhat-nguyen nhat-nguyen changed the title Update StructuredToMemref pass to support the new pass pipeline Update structured-to-memref pass to support the new pass pipeline Jan 7, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant