[WIP] add a canonicalizer before triton-to-linalg #62
Conversation
Attaching an example ttir & AST we are facing:
We find that the indexing patterns can be represented like:
With f(0, 256, 7, 72) we get:
While with f(1, 256, 7, 72) we get:
whose pattern isn't obvious to determine. In fact, if we fix programId to 0, the function simplifies to:
I think we should make a compromise here and create a tt.assert. If you have a better solution, please don't hesitate to share.
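For context, a minimal Triton sketch (not the attached ttir/ast, which is omitted above) of the kind of kernel that produces a modulo-wrapped offset like the one under discussion; all names, strides, and block sizes here are assumptions for illustration only:

import triton
import triton.language as tl

@triton.jit
def load_a_block(a_ptr, out_ptr, M, stride_am, stride_ak,
                 BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
    pid_m = tl.program_id(0)
    # The row offsets wrap around M via %, so the emitted ttir contains an
    # arith.remsi on the offset tensor that PtrAnalysis has to model.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    a = tl.load(a_ptrs)
    out_offs = (tl.arange(0, BLOCK_SIZE_M)[:, None] * BLOCK_SIZE_K
                + tl.arange(0, BLOCK_SIZE_K)[None, :])
    tl.store(out_ptr + out_offs, a)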
Thank you for the contribution. I have a small patch to support the modulo pattern that torch inductor generates, together with some other small fixes. I will take a look at your proposal for canonicalizing the division operator. Thanks!
I really like the idea of canonicalizing before TritonToLinalg so that PtrAnalysis only needs to take care of one pattern. I think for now, though, to support the case that you're interested in, the code itself is quite short, so we can still have it in PtrAnalysis to keep the complexity low. If this keeps growing, let's revisit the idea. I have a working branch over at:

if (state.getRank() == 1) {
// Apply the modulo before expanding shape, the common pattern is
// offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
// a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] *
// stride_ak)
assert(!state.modulos.back().has_value() &&
"No support for multiple modulos within an expression");
state.modulos.back() = ModuloState{rhsState.scalar};
} else if (state.getRank() == 2) {
// torch inductor expands the tensor shape before applying the modulo
// operator.
//
// We only support either:
// - (tl.arange(0, end)[:, None] % mod), or
// - (tl.arange(0, end)[None, :] % mod)
//
// In both cases, we apply the modulo to the non-singleton dimension.
auto shape = cast<TensorType>(remOp.getResult().getType()).getShape();
if (shape[0] == 1) {
state.modulos[1] = ModuloState{rhsState.scalar};
} else if (shape[1] == 1) {
state.modulos[0] = ModuloState{rhsState.scalar};
} else {
assert(false && "Do not support taking modulo on a 2D tensor with no "
"singleton dimension");
}
} else {
assert(false && "Unsupported modulo pattern");
}

Also, I would very much appreciate it if you could add me as a reviewer in future PRs so I can take a look at them in a timely manner. I don't get notifications otherwise. Thanks again!
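For clarity, a hedged Triton-level sketch (names assumed, not code from this repo) of the two rank-2 forms the branch above is meant to match; in both, the modulo ends up on the non-singleton dimension:

import triton
import triton.language as tl

@triton.jit
def rank2_modulo_patterns(x_ptr, out_ptr, mod, stride, BLOCK: tl.constexpr):
    # Case 1: tl.arange(0, end)[:, None] % mod -> shape (BLOCK, 1), modulo on dim 0.
    rows = tl.arange(0, BLOCK)[:, None] % mod
    # Case 2: tl.arange(0, end)[None, :] % mod -> shape (1, BLOCK), modulo on dim 1.
    cols = tl.arange(0, BLOCK)[None, :] % mod
    # The two broadcast to a (BLOCK, BLOCK) pointer tensor, as torch
    # inductor's expanded index expressions do.
    vals = tl.load(x_ptr + rows * stride + cols)
    out_offs = (tl.arange(0, BLOCK)[:, None] * BLOCK
                + tl.arange(0, BLOCK)[None, :])
    tl.store(out_ptr + out_offs, vals)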
This PR is WIP and aims to add a canonicalizer before triton-to-linalg. It decouples the mutation of ttir from triton-to-linalg.
RemsiCanonicalizer postpones expand_dims{axis=1} and provides valid input for PtrAnalysis, since the latter checks for rank == 1:
After RemsiCanonicalizer, %11 = arith.remsi %5, %cst_11 : tensor<8x1xi32> will become:
Thus it is no longer a rank-2 tensor by the time PtrAnalysis sees it.
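Conceptually, at the Triton level (a sketch only; the actual pass rewrites ttir, and the names below are assumptions), the canonicalization turns the expand-then-modulo form into a modulo on the 1-D range followed by the expansion:

import triton
import triton.language as tl

@triton.jit
def remsi_canonicalization_sketch(x_ptr, out_ptr, mod, stride):
    # Before: torch inductor expands first, so arith.remsi operates on a
    # rank-2 tensor of shape (8, 1).
    offs_before = tl.arange(0, 8)[:, None] % mod
    # After RemsiCanonicalizer: the modulo is applied to the rank-1 range
    # and expand_dims{axis = 1} is postponed, so PtrAnalysis only sees a
    # rank-1 remsi.
    offs_after = (tl.arange(0, 8) % mod)[:, None]
    cols = tl.arange(0, 4)[None, :]
    tl.store(out_ptr + tl.arange(0, 8)[:, None] * 4 + cols,
             tl.load(x_ptr + offs_after * stride + cols))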