Introduce fold-unstructured-ptr pass #210

Draft
wants to merge 6 commits into
base: main

nhat-nguyen
Collaborator

@nhat-nguyen nhat-nguyen commented Jan 2, 2025

This PR introduces the fold-unstructured-ptr pass, the first step towards allowing triton-shared to compile pointer sequences that triton-to-structured cannot analyze (gather / scatter).

Intended lowering pipeline

  • triton-to-structured (no changes):
    • analyzes structured addptr sequences
      • introduces tts.make_tptr %ptr_arg with offsets and strides
      • introduces tts.load and tts.store
    • leaves unstructured addptr sequences and their corresponding tt.load and tt.store intact
  • fold-unstructured-ptr (this PR):
    • converts all unstructured addptr sequences into sequences that compute pointer offsets
      • introduces tts.make_unstructured_tptr %ptr_arg %offsets
    • removes all tt.addptr
  • structured-to-memref (to be updated in a different PR):
    • currently converts everything to memref including scalar addptr and kernel arguments
    • will change to just convert ops in the tts dialect to memref with the exception of tts.make_unstructured_tptr
  • unstructured-to-memref (to be introduced in a different PR):
    • converts the remaining unstructured tt.load, tt.store, and tts.make_unstructured_tptr into memref
  • triton-ptr-to-memref (to be introduced in a different PR):
    • converts kernel arguments with pointer type to memref

Pass design

This pass attempts to simplify the uses of Triton pointers in the IR, which
makes lowering Triton to other MLIR dialects easier.

A triton pointer has two pieces of info:

  • a base pointer which comes from the kernel arguments
  • an offset, which is either a tensor of offsets or a single integer
    offset

Triton pointers are created and manipulated through a sequence of tt.addptr,
tt.splat, or tt.broadcast ops. If a triton pointer is created through
tt.addptr %ptr %offset, the new pointer keeps the same base pointer as the
original pointer, and %offset is added to the original pointer's offset.
Triton pointers created through tt.splat and tt.broadcast retain their base
pointers and offsets. Tensors of pointers cannot have different bases by
design. In other words, the base pointer is fixed throughout a chain of
pointer manipulation ops.
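
To make the base / offset view above concrete, here is a minimal Python
sketch. It is only a toy model, not code from this PR: the TritonPtr class
and the addptr and splat helpers are hypothetical stand-ins for the
corresponding Triton ops, and tensors are modelled as flat Python lists.

```python
from dataclasses import dataclass
from typing import List, Union

# Scalar offset, or a flat list standing in for a tensor of offsets.
Offset = Union[int, List[int]]


@dataclass(frozen=True)
class TritonPtr:
    """Toy model (hypothetical): a fixed base name plus an offset."""
    base: str
    offset: Offset


def addptr(ptr: TritonPtr, delta: Offset) -> TritonPtr:
    """Model of tt.addptr: the base is kept, the offsets accumulate."""
    if isinstance(ptr.offset, int) and isinstance(delta, int):
        return TritonPtr(ptr.base, ptr.offset + delta)
    if isinstance(ptr.offset, int) or isinstance(delta, int):
        raise TypeError("mixing scalar and tensor offsets requires a splat")
    if len(ptr.offset) != len(delta):
        raise ValueError("offset shapes must match")
    return TritonPtr(ptr.base, [a + b for a, b in zip(ptr.offset, delta)])


def splat(ptr: TritonPtr, n: int) -> TritonPtr:
    """Model of tt.splat: replicate a scalar pointer; the base is kept."""
    if not isinstance(ptr.offset, int):
        raise TypeError("splat expects a scalar pointer")
    return TritonPtr(ptr.base, [ptr.offset] * n)


arg0 = TritonPtr("arg0", 0)      # pointer coming from a kernel argument
p = addptr(arg0, 4)              # scalar pointer at offset 4
t = splat(p, 3)                  # tensor of 3 pointers, all at offset 4
q = addptr(t, [7, 0, 2])         # arbitrary (gather-like) offsets
print(q)
```

Every pointer in the chain shares the base "arg0"; only the offset changes,
which is the property the pass relies on.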

Leveraging these insights, we can simplify chains of tt.addptr,
tt.splat, and tt.broadcast which produce triton pointers to just a sequence
of offset manipulation ops and a base pointer.

In essence, this pass transforms all sequences of tt.addptr into sequences of
offset accumulation ops which are then fed into a single op
tts.make_unstructured_tptr that takes:

  • a base pointer from the kernel arguments
  • a tensor of offsets (or single offset) that indicates the offsets from
    the base pointer
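
The transformation can be sketched in Python as follows. This is an
illustration under simplifying assumptions, not the pass itself (which
rewrites MLIR): fold_chain and the ("addptr", ...) / ("splat", ...) tuples
are hypothetical names, and tensors are again flat lists. The returned pair
corresponds to the two operands listed above.

```python
from typing import List, Tuple, Union

Offset = Union[int, List[int]]
# A recorded pointer op: ("addptr", delta) or ("splat", element_count).
Op = Tuple[str, object]


def fold_chain(base: str, ops: List[Op]) -> Tuple[str, Offset]:
    """Replace a pointer chain by offset arithmetic on one fixed base."""
    offset: Offset = 0
    for name, arg in ops:
        if name == "addptr":
            if isinstance(offset, int) and isinstance(arg, int):
                offset = offset + arg
            elif isinstance(offset, list) and isinstance(arg, list):
                if len(offset) != len(arg):
                    raise ValueError("offset shapes must match")
                offset = [a + b for a, b in zip(offset, arg)]
            else:
                raise TypeError("mixing scalar and tensor offsets")
        elif name == "splat":
            if not isinstance(offset, int) or not isinstance(arg, int):
                raise TypeError("splat expects a scalar offset and a size")
            offset = [offset] * arg
        else:
            raise ValueError(f"unsupported op: {name}")
    # The pair plays the role of the (base pointer, offsets) operands.
    return base, offset


chain = [("addptr", 4), ("splat", 3), ("addptr", [7, 0, 2])]
folded = fold_chain("arg0", chain)
print(folded)
```

After folding, no intermediate pointer values remain: the chain is described
entirely by the kernel argument and one offset value.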

This simplification makes it easier for subsequent passes to lower these load
and store ops. The pass unstructured-to-memref will leverage this output to
lower the unstructured triton load / store ops into memref load / store ops
with the appropriate offsets.
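
As a rough picture of what such a lowering computes, the sketch below models
a gather and a scatter as per-element indexed accesses into a flat buffer. It
is an assumption-laden illustration only: gather, scatter, and the list-based
buffer are hypothetical, masks are ignored, and the actual design of
unstructured-to-memref is left to its own PR.

```python
from typing import List


def gather(buffer: List[float], offsets: List[int]) -> List[float]:
    """Unstructured load: one indexed read per offset."""
    return [buffer[o] for o in offsets]


def scatter(buffer: List[float], offsets: List[int],
            values: List[float]) -> None:
    """Unstructured store: one indexed write per offset."""
    if len(offsets) != len(values):
        raise ValueError("offsets and values must have the same length")
    for o, v in zip(offsets, values):
        buffer[o] = v


buffer = [float(i * 10) for i in range(8)]   # stands in for the base memref
offsets = [5, 0, 3]                          # folded offsets from the base

loaded = gather(buffer, offsets)
scatter(buffer, offsets, [1.0, 2.0, 3.0])
print(loaded, buffer)
```

Because the base is a single buffer and the offsets are plain integers, each
element access needs no pointer arithmetic beyond the index itself.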

See the comments in
lib/Conversion/FoldUnstructuredTritonAddPtr/FoldUnstructuredTritonAddPtrPass.cpp
for a more detailed description of the approach.

@nhat-nguyen nhat-nguyen changed the title Introduce fold-unstructured-triton-ptr pass Introduce fold-unstructured-ptr pass Jan 3, 2025
@nhat-nguyen nhat-nguyen marked this pull request as ready for review January 6, 2025 19:21