# Composable triton-to-linalg #81
Replies: 1 comment
I'm 100% on board with breaking up the `triton-to-linalg` pass into more discrete pieces, and this looks like a good approach. For some history, the original design for triton-shared had the pointer/mask analysis separate from the transformation, but we ran into some technical issues that prevented us from implementing it that way. I believe those issues can be addressed with this proposed design, primarily by utilizing the proposed new structured Triton dialect. We may need to iterate on the details of the structured Triton pointers, but since you're close to having a PR, let's handle that when you have it ready; that way we will have all the code and examples in front of us.

Regarding the proposed plan for delivering this work: I want to make sure we don't end up with two versions of this pass. Delivering this work in stages makes sense, as it's going to be quite extensive, so let's commit to completing this (a full functional replacement) and removing the old version before we call it done. My team can help with reviews and landing some of the code along the way.
## TL;DR

`triton-to-linalg` is an extremely powerful pass that does many things (very well, if I may add 🙂). However, it also has a few hard edges, especially in dealing with unsupported IRs. We propose to redesign the pass into a few composable and robust passes, each of which performs a specific set of transformations. The combined effect of these passes will be a better version of today's `triton-to-linalg`. More importantly, these composable passes will hopefully make it easier to take advantage of the codebase, and make future maintenance and improvement easier.

## Introduction
`triton-to-linalg` is an extremely powerful pass for converting Triton MLIR into standard dialects, making a significant portion of Triton kernels target-able in a performant manner by non-GPU backends. However, it also has a few limitations:

- unsupported cases remain, such as `tt.dot`'s operand C when it is a constant splat;
- smaller conversions, such as turning `tt.func` into `func.func`, are bundled with everything else and cannot be applied separately.
The goal of this proposal is to address these limitations. In the rest of this post, we present a high-level technical proposal, some lower-level details, and a tentative execution plan. Feedback is much appreciated!
## Proposal

At a high level, we propose to tease apart the current `triton-to-linalg` into the following three passes (please bear in mind that the exact details may change as we collaboratively make more progress):

### 1. `triton-to-structured`

- Performs pointer and mask analysis, and materializes the results as TritonStructured operations (`tts.make_tptr`, `tts.load`, `tts.store`).
- Converts Triton load (`tt.load`) and store (`tt.store`) operations into TritonStructured load (`tts.load`) and store (`tts.store`) operations.
- In the future, we plan to extend `triton-to-structured` to convert block pointers into `tts.make_tptr`.
### 2. `triton-arith-to-linalg`

- Converts arithmetic operations on tensors into linalg. This pass will also be used to convert `tt.func` into `func.func`, and to convert program IDs into function parameters; see the sketch below.
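To make the function-level rewrite concrete, here is a minimal before/after sketch. It is illustrative only: the position of the appended program-ID arguments and the exact `tt.get_program_id` syntax (which varies across Triton versions) are assumptions, not the final design.

```mlir
// Before: a Triton function that reads its program ID.
tt.func @kernel(%out: !tt.ptr<i32>) {
  %pid = tt.get_program_id x : i32
  tt.store %out, %pid : !tt.ptr<i32>
  tt.return
}

// After (sketch): tt.func becomes func.func, the program IDs are appended
// as ordinary i32 arguments, and each use of tt.get_program_id is replaced
// by the matching argument.
func.func @kernel(%out: !tt.ptr<i32>, %pid0: i32, %pid1: i32, %pid2: i32) {
  tt.store %out, %pid0 : !tt.ptr<i32>
  return
}
```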
### 3. `structured-to-memref`

- Converts `tt.ptr` into unranked memrefs.
- Converts `tts.make_tptr` into `memref.reinterpret_cast`, and `tts.load` / `tts.store` into `memref.copy` with the corresponding bufferization operations, as sketched below.
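To illustrate the intended mapping, here is a rough sketch of what a structured load of 128 contiguous `f32` values could become. The specific ops follow the mapping above, but the strided layout, the `bufferization.to_tensor` step, and the value names (`%base`, `%off`) are assumptions for illustration, not a committed lowering.

```mlir
// %base: unranked memref lowered from tt.ptr; %off: index-typed offset.
// Reinterpret the base as a 128-element strided view at offset %off.
%view  = memref.reinterpret_cast %base to offset: [%off], sizes: [128], strides: [1]
         : memref<*xf32> to memref<128xf32, strided<[1], offset: ?>>
// Materialize the load as an alloc + copy, then hand the result back to
// the tensor world via bufferization.
%alloc = memref.alloc() : memref<128xf32>
memref.copy %view, %alloc : memref<128xf32, strided<[1], offset: ?>> to memref<128xf32>
%val   = bufferization.to_tensor %alloc restrict writable : memref<128xf32>
```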
## Triton Structured Dialect
The goal of this dialect is to cleanly represent the results of pointer and mask analysis without introducing new concepts such as those from the memref dialect. We propose to introduce three operations in this dialect:
### MakeTensorPtrOp (`tts.make_tptr`)

`tts.make_tptr` represents structured pointer patterns. It creates a statically sized multi-dimensional tensor of Triton pointers (a tensor of `tt.ptr`) from a base pointer of the same type, and supports both dynamic and static strides, offsets, and parent_sizes of index type.

"parent_sizes" is used to represent scenarios in which pointers wrap around (expressed by modulo operations in Triton kernels). A static 0 in parent_sizes indicates no wrap-around behavior along that dimension.
This field is also one of the main differentiators between `tts.make_tptr` and the Triton dialect's MakeTensorPtrOp (`tt.make_tensor_ptr`), which does not support such wrap-around behavior. As a side note, `tt.make_tensor_ptr` with row-major order can be treated similarly to `tts.make_tptr` without wrap-around. E.g., consider the following Triton IR:
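(The snippet is an illustrative sketch: the argument names are assumptions, and op syntax varies across Triton versions.) It builds a 1-D tensor of pointers whose offsets wrap around via a modulo.

```mlir
// %arg0: base pointer, %arg1: starting offset, %arg2: parent size.
// Pointer i addresses element (%arg1 + i) % %arg2 of the base.
%range = tt.make_range {start = 0 : i32, end = 4 : i32} : tensor<4xi32>
%start = tt.splat %arg1 : i32 -> tensor<4xi32>
%idx   = arith.addi %start, %range : tensor<4xi32>
%size  = tt.splat %arg2 : i32 -> tensor<4xi32>
%wrap  = arith.remsi %idx, %size : tensor<4xi32>
%base  = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>>
%ptrs  = tt.addptr %base, %wrap : tensor<4x!tt.ptr<f32>>, tensor<4xi32>
```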
Can be represented as:
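(The pretty-printed assembly below is an assumption; the op's exact format is still to be settled.) The offset and parent size become dynamic index-typed operands, the stride is a static 1, and the nonzero parent_sizes entry encodes the wrap-around.

```mlir
%off  = arith.index_cast %arg1 : i32 to index
%size = arith.index_cast %arg2 : i32 to index
// 4 pointers, stride 1, dynamic offset %off, wrapping around a parent
// dimension of dynamic size %size.
%ptrs = tts.make_tptr %arg0 to sizes: [4], strides: [1], offsets: [%off], parent_sizes: [%size]
        : !tt.ptr<f32> to tensor<4x!tt.ptr<f32>>
```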
Note that wrap-around takes effect after the offset. E.g., for a 1D tensor of pointers, the addresses it generates are `(offsets[0] + i * strides[0]) % parent_sizes[0]`; with offsets[0] = 2, strides[0] = 1, and parent_sizes[0] = 4, element i = 3 resolves to (2 + 3) % 4 = 1. This is the same as how `PtrState` represents pointers in the current codebase.

The conditions below should also be true:
Note that the conditions above are typically true. However, in cases where they are not, the new `triton-to-structured` pass will keep the original Triton operations intact instead of crashing, so later passes can choose to lower them differently.
### LoadOp (`tts.load`)

`tts.load` represents loading from a structured pointer, with optional sub-dimensions and an optional constant fill ("other") value. It always takes a pointer produced directly by `tts.make_tptr`. It can take optional arguments to indicate sub-dimension memory loads, and another optional scalar value to indicate what data to use to fill the rest of the tensor.

With sub-dimensions present, we always start loading data from index 0 up to the size specified by the operand if present, or load the full tensor if not.
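For instance (assumed assembly syntax; `%ptrs`, `%m`, and `%n` are hypothetical values), a 2-D partial load where only the leading `%m` rows and `%n` columns are read from memory and the remainder is filled with a scalar could look like:

```mlir
// %ptrs: a 32x64 tensor of pointers produced by tts.make_tptr;
// %m, %n: index-typed bounds. Only indices [0, %m) x [0, %n) are loaded
// from memory; the rest of the 32x64 result is filled with 0.0.
%fill = arith.constant 0.000000e+00 : f32
%tile = tts.load %ptrs, [%m, %n], %fill
        : tensor<32x64x!tt.ptr<f32>> -> tensor<32x64xf32>
```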
### StoreOp (`tts.store`)

`tts.store` represents storing to a structured pointer, with optional sub-dimensions. It is a mirror of `tts.load`.

## TritonToStructured Pass Example
We are happy to discuss more details of the pass itself later. For brevity of this post, we provide an example below of a 1D masked load.

Triton IR:
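A simplified sketch of such a kernel fragment (the exact `tt.load` assembly differs across Triton versions, and the 0.0 fill value is an assumption): 128 contiguous loads where only the first `%arg1` lanes are valid.

```mlir
tt.func @masked_load_1d(%arg0: !tt.ptr<f32>, %arg1: i32) {
  // Pointers to 128 contiguous f32 values starting at %arg0.
  %range = tt.make_range {start = 0 : i32, end = 128 : i32} : tensor<128xi32>
  %base  = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>>
  %ptrs  = tt.addptr %base, %range : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
  // Mask: lane i is valid iff i < %arg1; masked-off lanes read 0.0.
  %bound = tt.splat %arg1 : i32 -> tensor<128xi32>
  %mask  = arith.cmpi slt, %range, %bound : tensor<128xi32>
  %other = arith.constant dense<0.000000e+00> : tensor<128xf32>
  %val   = tt.load %ptrs, %mask, %other : tensor<128x!tt.ptr<f32>>
  tt.return
}
```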
Result of `triton-to-structured`:
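And a sketch of what the pass could produce (assumed assembly): the pointer arithmetic collapses into one `tts.make_tptr`, and the mask `i < %arg1` becomes the dimension bound `min(%arg1, 128)` passed to `tts.load`.

```mlir
%c128  = arith.constant 128 : index
%bound = arith.index_cast %arg1 : i32 to index
%len   = arith.minsi %bound, %c128 : index
%other = arith.constant 0.000000e+00 : f32
// 128 contiguous pointers, stride 1, no wrap-around (parent_sizes = 0).
%ptrs  = tts.make_tptr %arg0 to sizes: [128], strides: [1], offsets: [0], parent_sizes: [0]
         : !tt.ptr<f32> to tensor<128x!tt.ptr<f32>>
// Load the first %len elements; fill the rest with 0.0.
%val   = tts.load %ptrs, [%len], %other : tensor<128x!tt.ptr<f32>> -> tensor<128xf32>
```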
## Tentative Execution Plan
The effort required for this work is non-trivial. Below we discuss our plan for tackling it. Contributions are greatly appreciated!
For stability of the codebase and other ongoing work, we will duplicate some code during this effort.
### 1. TritonStructured dialect and `triton-to-structured`

The initial implementation is code complete. The new pass can now process all LIT tests in the repo, failing gracefully where needed. Please expect a PR soon.
Notably, we have not started working on lowering Triton block pointer operations (`tt.make_tensor_ptr` and `tt.advance`). Help is very much appreciated here.
### 2. `triton-arith-to-linalg`

We plan to take this on after the above work (except for the Triton block pointer items) completes. We expect to finish implementing this pass in Jan or Feb.
### 3. `structured-to-memref`

We have not started this work and are actively looking for collaboration on it. Note that this work is independent of `triton-arith-to-linalg` and can proceed in parallel.