Introduce fold-unstructured-ptr pass #210 (Draft)
This PR introduces the fold-unstructured-ptr pass, which is the first step
towards allowing triton-shared to compile pointer sequences that cannot be
analyzed by triton-to-structured (gather / scatter).

Intended lowering pipeline
tts.make_tptr %ptr_arg with offsets and strides
tts.load and tts.store
tt.load and tt.store intact
tts.make_unstructured_tptr %ptr_arg %offsets
tt.addptr
tts dialect to memref with the exception of tts.make_unstructured_tptr
tt.load, tt.store, and tts.make_unstructured_tptr into memref

Pass design
This pass attempts to simplify the uses of triton pointers in the IR, which
makes lowering triton to other MLIR dialects easier.
A triton pointer has two pieces of info:
base pointer
offset
Triton pointers are created and manipulated through a sequence of tt.addptr,
tt.splat, or tt.broadcast ops. If a triton pointer is created through
tt.addptr %ptr %offset, the new pointer contains the same base pointer as the
original pointer, and the new offset is accumulated onto the original offset.
Triton pointers created through tt.splat and tt.broadcast retain their base
pointers and offsets. Tensors of pointers cannot have different bases by
design. In other words, the base pointer is fixed throughout a chain of
pointer manipulation ops.
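
As a rough illustration of the rules above, here is a toy Python model, not
code from triton-shared. The ToyPtr class and the addptr / splat helpers are
hypothetical names invented for this sketch; they only mimic how the base
pointer stays fixed while offsets accumulate.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ToyPtr:
    """Hypothetical stand-in for a (tensor of) triton pointer(s):
    one shared base plus one offset per element."""
    base: str
    offsets: Tuple[int, ...]


def addptr(ptr: ToyPtr, deltas: Tuple[int, ...]) -> ToyPtr:
    # Mimics tt.addptr: the base is unchanged, offsets accumulate elementwise.
    if len(deltas) != len(ptr.offsets):
        raise ValueError("offset shape must match pointer shape")
    return ToyPtr(ptr.base, tuple(o + d for o, d in zip(ptr.offsets, deltas)))


def splat(ptr: ToyPtr, n: int) -> ToyPtr:
    # Mimics tt.splat on a scalar pointer: base and offset are retained,
    # the single offset is simply replicated n times.
    if len(ptr.offsets) != 1:
        raise ValueError("splat expects a scalar pointer")
    return ToyPtr(ptr.base, ptr.offsets * n)


scalar = ToyPtr("%arg0", (0,))
chain = addptr(addptr(splat(scalar, 4), (0, 1, 2, 3)), (10, 10, 10, 10))
```

Whatever ops appear in the chain, chain.base is still "%arg0"; only the
offsets differ from the starting pointer.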
Leveraging these insights, we can simplify chains of tt.addptr,
tt.splat, and tt.broadcast which produce triton pointers to just a sequence
of offset manipulation ops and a base pointer.
In essence, this pass transforms all sequences of tt.addptr into sequences of
offset accumulation ops which are then fed into a single op
tts.make_unstructured_tptr that takes:
the base pointer
the accumulated offsets
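
The folding idea can be sketched in a few lines of Python. This is only an
illustration under stated assumptions: the tuple-based op encoding and the
fold_chain function are hypothetical, and the real pass rewrites MLIR ops
rather than computing concrete integers.

```python
from typing import List, Tuple, Union

Op = Tuple[str, Union[int, List[int]]]


def fold_chain(base: str, ops: List[Op]):
    """Fold a chain of pointer ops (hypothetical encoding) into a single
    ("make_unstructured_tptr", base, offsets) triple."""
    offsets = [0]  # a scalar pointer starts at offset 0 from its base
    for name, arg in ops:
        if name == "addptr":
            # Pointer arithmetic becomes plain offset accumulation.
            if len(arg) != len(offsets):
                raise ValueError("offset shape must match pointer shape")
            offsets = [o + d for o, d in zip(offsets, arg)]
        elif name == "splat":
            # Splatting a scalar pointer replicates its offset.
            if len(offsets) != 1:
                raise ValueError("splat expects a scalar pointer")
            offsets = offsets * arg
        else:
            raise ValueError(f"unsupported op: {name}")
    # The base never changes, so it is attached once at the very end.
    return ("make_unstructured_tptr", base, offsets)


folded = fold_chain(
    "%arg0", [("addptr", [5]), ("splat", 4), ("addptr", [3, 0, 7, 1])]
)
```

The offsets here are irregular (gather-style), which is exactly the case that
cannot be described with a fixed stride.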
This simplification makes it easier for subsequent passes to lower these load
and store ops. The unstructured-to-memref pass will leverage this output to
lower the unstructured triton load / store ops into memref load / store ops
with the appropriate offsets.
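
Conceptually, once a pointer is reduced to a base plus offsets, an
unstructured load or store is just indexed access into a flat buffer. The
sketch below models that in Python; gather_load and scatter_store are
hypothetical helpers written for this note, and they say nothing about how
unstructured-to-memref is actually implemented.

```python
from typing import List


def gather_load(buffer: List[int], offsets: List[int]) -> List[int]:
    # One element load per offset, relative to the start of the buffer.
    return [buffer[o] for o in offsets]


def scatter_store(buffer: List[int], offsets: List[int], values: List[int]) -> None:
    # One element store per offset; the buffer is modified in place.
    if len(offsets) != len(values):
        raise ValueError("offsets and values must have the same length")
    for o, v in zip(offsets, values):
        buffer[o] = v


buf = list(range(100, 110))          # stands in for the memory behind the base
loaded = gather_load(buf, [8, 5, 2])
scatter_store(buf, [0, 9], [-1, -2])
```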
See the comments in
lib/Conversion/FoldUnstructuredTritonAddPtr/FoldUnstructuredTritonAddPtrPass.cpp
for a more detailed description of the approach.