[WIP] add a canonicalizer before triton-to-linalg #62

Closed

@yuanfz98 yuanfz98 commented Nov 22, 2023

This PR is a work in progress. It adds a canonicalizer that runs before triton-to-linalg, so that mutation of the TTIR is decoupled from the triton-to-linalg conversion itself.
RemsiCanonicalizer postpones expand_dims {axis = 1} until after the remainder, so that PtrAnalysis receives valid input; PtrAnalysis asserts rank == 1 when it visits the operand of a remainder:

void PtrAnalysis::visitOperandRem(
    arith::RemSIOp remOp, PtrState &state, const Location loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, PtrState> &knownPtrs) {
  assert(state.isEmpty());
  visitOperand(remOp.getLhs(), state, loc, rewriter, knownPtrs);
  assert(state.getRank() == 1 && !state.modulos.back().has_value() &&
         "No support for multiple modulos within an expression");

After RemsiCanonicalizer, %11 = arith.remsi %5, %cst_11 : tensor<8x1xi32> becomes:

%11 = arith.remsi %5_clone, %cst_clone : tensor<8xi32>
%12 = tt.expand_dims %11 {axis = 1 : i32} : (tensor<8xi32>) -> tensor<8x1xi32>

Thus the remsi that PtrAnalysis visits no longer operates on a rank-2 tensor.
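The rewrite relies on an elementwise remainder commuting with expand_dims. A minimal sketch of that property in plain Python, assuming nested lists as stand-ins for tensors; the helper names and the sample values are hypothetical and are not part of the PR:

```python
def remsi(a: int, b: int) -> int:
    """Signed remainder that truncates toward zero, like arith.remsi."""
    r = abs(a) % abs(b)
    return -r if a < 0 else r


def expand_dims_axis1(xs):
    """Turn a rank-1 list of length n into an n x 1 nested list."""
    return [[x] for x in xs]


def remsi_rank1(xs, m):
    """Elementwise remainder on a rank-1 list."""
    return [remsi(x, m) for x in xs]


def remsi_rank2(rows, m):
    """Elementwise remainder on a rank-2 nested list."""
    return [[remsi(x, m) for x in row] for row in rows]


offsets = list(range(380, 388))  # hypothetical values standing in for %5
mod = 384                        # hypothetical modulus

before = remsi_rank2(expand_dims_axis1(offsets), mod)  # expand, then remsi
after = expand_dims_axis1(remsi_rank1(offsets, mod))   # remsi, then expand
```

Both orders produce the same n x 1 result, which is why moving expand_dims below the remsi should not change what the kernel computes.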

@yuanfz98
Contributor Author

yuanfz98 commented Nov 22, 2023

Here is an example of the source (AST) and the TTIR we are dealing with:

def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 24576
    rnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex % 384
    x1 = (xindex // 384)
    _tmp5 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    x3 = xindex
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (384*r2) + (98304*x1)), rmask, other=0).to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (x0 + (384*r2) + (98304*x1)), rmask, other=0)
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp1 * tmp2
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp6 = _tmp5 + tmp4
        _tmp5 = tl.where(rmask, tmp6, _tmp5)
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp5, None)

module {
  tt.func public @triton__0d1d2d3de4de(%arg0: !tt.ptr<bf16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<128x8xbf16>
    %c8_i32 = arith.constant 8 : i32
    %c256_i32 = arith.constant 256 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst_0 = arith.constant dense<98304> : tensor<128x1xi32>
    %cst_1 = arith.constant dense<384> : tensor<1x8xi32>
    %cst_2 = arith.constant dense<256> : tensor<1x8xi32>
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x8xf32>
    %cst_4 = arith.constant dense<384> : tensor<128x1xi32>
    %c128_i32 = arith.constant 128 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c128_i32 : i32
    %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
    %3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
    %4 = tt.splat %1 : (i32) -> tensor<128x1xi32>
    %5 = arith.addi %4, %3 : tensor<128x1xi32>
    %6 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32>
    %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<8xi32>) -> tensor<1x8xi32>
    %8 = arith.remsi %5, %cst_4 : tensor<128x1xi32>
    %9 = arith.divsi %5, %cst_4 : tensor<128x1xi32>
    %10 = tt.broadcast %8 : (tensor<128x1xi32>) -> tensor<128x8xi32>
    %11 = arith.muli %9, %cst_0 : tensor<128x1xi32>
    %12 = tt.broadcast %11 : (tensor<128x1xi32>) -> tensor<128x8xi32>
    %13 = tt.splat %arg0 : (!tt.ptr<bf16, 1>) -> tensor<128x8x!tt.ptr<bf16, 1>>
    %14 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<128x8x!tt.ptr<f32, 1>>
    %15 = scf.for %arg5 = %c0_i32 to %c256_i32 step %c8_i32 iter_args(%arg6 = %cst_3) -> (tensor<128x8xf32>)  : i32 {
      %20 = tt.splat %arg5 : (i32) -> tensor<1x8xi32>
      %21 = arith.addi %20, %7 : tensor<1x8xi32>
      %22 = arith.cmpi slt, %21, %cst_2 : tensor<1x8xi32>
      %23 = arith.muli %21, %cst_1 : tensor<1x8xi32>
      %24 = tt.broadcast %23 : (tensor<1x8xi32>) -> tensor<128x8xi32>
      %25 = arith.addi %10, %24 : tensor<128x8xi32>
      %26 = arith.addi %25, %12 : tensor<128x8xi32>
      %27 = tt.addptr %13, %26 : tensor<128x8x!tt.ptr<bf16, 1>>, tensor<128x8xi32>
      %28 = tt.broadcast %22 : (tensor<1x8xi1>) -> tensor<128x8xi1>
      %29 = tt.load %27, %28, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x8xbf16>
      %30 = arith.extf %29 : tensor<128x8xbf16> to tensor<128x8xf32>
      %31 = tt.addptr %14, %26 : tensor<128x8x!tt.ptr<f32, 1>>, tensor<128x8xi32>
      %32 = tt.load %31, %28, %cst_3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x8xf32>
      %33 = arith.mulf %30, %32 : tensor<128x8xf32>
      %34 = arith.addf %arg6, %33 : tensor<128x8xf32>
      %35 = arith.select %28, %34, %arg6 : tensor<128x8xi1>, tensor<128x8xf32>
      scf.yield %35 : tensor<128x8xf32>
    }
    %16 = "tt.reduce"(%15) <{axis = 1 : i32}> ({
    ^bb0(%arg5: f32, %arg6: f32):
      %20 = arith.addf %arg5, %arg6 : f32
      tt.reduce.return %20 : f32
    }) : (tensor<128x8xf32>) -> tensor<128xf32>
    %17 = tt.expand_dims %16 {axis = 1 : i32} : (tensor<128xf32>) -> tensor<128x1xf32>
    %18 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<128x1x!tt.ptr<f32, 1>>
    %19 = tt.addptr %18, %5 : tensor<128x1x!tt.ptr<f32, 1>>, tensor<128x1xi32>
    tt.store %19, %17 {cache = 1 : i32, evict = 1 : i32} : tensor<128x1xf32>
    tt.return
  }
}
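To make the address arithmetic concrete, here is a rough sketch that recomputes the load offsets x0 + 384*r2 + 98304*x1 in plain Python and splits them into a loop-invariant row term and a loop-varying column term. The block sizes 128 and 8 are read off the tensor shapes in the TTIR; the function names are hypothetical.

```python
XBLOCK, RBLOCK = 128, 8  # block sizes taken from the tensor shapes in the TTIR


def xindex(pid):
    """Row indices handled by program `pid`."""
    return [pid * XBLOCK + i for i in range(XBLOCK)]


def full_offsets(pid, roffset):
    """x0 + 384*r2 + 98304*x1 for every (row, column) of one loop iteration."""
    return [[x % 384 + 384 * (roffset + j) + 98304 * (x // 384)
             for j in range(RBLOCK)]
            for x in xindex(pid)]


def row_terms(pid):
    """Loop-invariant part: x0 + 98304*x1 (the contributions of %10 and %12)."""
    return [x % 384 + 98304 * (x // 384) for x in xindex(pid)]


def col_terms(roffset):
    """Loop-varying part: 384*r2 (corresponds to %23)."""
    return [384 * (roffset + j) for j in range(RBLOCK)]
```

Every offset is the sum of one row term and one column term, so the modulo and division only ever touch the row dimension.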


@yuanfz98
Contributor Author

The indexing pattern can be modeled like this:

def f(programId, x, factor0, factor1):
    a = list(range(0, x))
    result = []
    for e in a:
        offset = programId * x
        o = offset + e
        result.append((o // factor0) * factor1 + o % factor0)
    print(result)

With f(0, 256, 7, 72) we get:

[0, 1, 2, 3, 4, 5, 6, 72, 73, 74, 75, 76...]

While with f(1, 256, 7, 72) we get:

[2596, 2597, 2598, 2664, 2665, 2666, 2667, 2668, 2669, 2670, 2736, 2737, 2738...]

a pattern that is not obvious to determine.
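One way to see why the second case is harder is to look at the lengths of the contiguous runs in each result. The sketch below is only an illustration: it restates f so that it returns the list instead of printing it, and run_lengths is a hypothetical helper.

```python
def f(program_id, x, factor0, factor1):
    """Same arithmetic as above, but returning the list instead of printing it."""
    offset = program_id * x
    return [((offset + e) // factor0) * factor1 + (offset + e) % factor0
            for e in range(x)]


def run_lengths(values):
    """Lengths of the maximal runs of consecutive integers in `values`."""
    if not values:
        return []
    runs, current = [], 1
    for prev, cur in zip(values, values[1:]):
        if cur == prev + 1:
            current += 1
        else:
            runs.append(current)
            current = 1
    runs.append(current)
    return runs


runs_pid0 = run_lengths(f(0, 256, 7, 72))
runs_pid1 = run_lengths(f(1, 256, 7, 72))
```

With programId 0 every run starts on a multiple of factor0, so the runs have length 7 until the block ends. With programId 1 the block starts at 256, and 256 % 7 == 4, so the first run has only 3 elements.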

In fact, if we fix programId to 0, the function simplifies to:

def f(programId, x, factor0, factor1):
    a = list(range(0, x))
    result = []
    for e in a:
        result.append((e // factor0) * factor1 + e % factor0)
    print(result)

I think we should compromise here and emit a tt.assert. If you have a better solution, please don't hesitate to share it.
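As an assumption about what such an assertion could check (the PR does not specify it): one sufficient condition is that the block's starting offset is a multiple of factor0. In that case the pattern is the programId 0 pattern shifted by a constant. A minimal sketch with hypothetical function names:

```python
def pattern(offset, x, factor0, factor1):
    """General indexing pattern for a block that starts at `offset`."""
    return [((offset + e) // factor0) * factor1 + (offset + e) % factor0
            for e in range(x)]


def shifted_base_pattern(offset, x, factor0, factor1):
    """Simplified form, valid only when `offset` is a multiple of factor0."""
    # Stand-in for the proposed tt.assert; the exact condition is an assumption.
    assert offset % factor0 == 0, "block start must be a multiple of factor0"
    base = (offset // factor0) * factor1
    return [base + v for v in pattern(0, x, factor0, factor1)]


aligned = pattern(252, 256, 7, 72)                # 252 == 36 * 7
simplified = shifted_base_pattern(252, 256, 7, 72)
```

When the offset is not aligned, for example 256 with factor0 = 7, the difference to the programId 0 pattern is not constant, so the simplification does not apply.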

@nhat-nguyen
Collaborator

Thank you for the contribution. I have a small patch to support the modulo pattern that torch inductor generates together with some other small fixes. I will take a look at your proposal for canonicalizing the division operator. Thanks!

@nhat-nguyen
Collaborator

nhat-nguyen commented Nov 24, 2023

I really like the idea of canonicalizing before TritonToLinalg so that PtrAnalysis only needs to take care of one pattern. For now, though, the code needed to support the case you're interested in is quite short, so we can still keep it in PtrAnalysis to keep the complexity low. If this keeps growing, let's revisit the idea. I have a working branch over at nhat/modulo if you're interested in checking it out early. I just need to do a bit of cleanup before publishing the PR. Here's the code to support your case:

  if (state.getRank() == 1) {
    // Apply the modulo before expanding shape, the common pattern is
    // offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    // a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] *
    // stride_ak)

    assert(!state.modulos.back().has_value() &&
           "No support for multiple modulos within an expression");

    state.modulos.back() = ModuloState{rhsState.scalar};

  } else if (state.getRank() == 2) {
    // torch inductor expands the tensor shape before applying the modulo
    // operator.
    //
    // We only support either:
    // - (tl.arange(0, end)[:, None] % mod), or
    // - (tl.arange(0, end)[None, :] % mod)
    //
    // In both cases, we apply the modulo to the non-singleton dimension.
    auto shape = cast<TensorType>(remOp.getResult().getType()).getShape();
    if (shape[0] == 1) {
      state.modulos[1] = ModuloState{rhsState.scalar};
    } else if (shape[1] == 1) {
      state.modulos[0] = ModuloState{rhsState.scalar};
    } else {
      assert(false && "Do not support taking modulo on a 2D tensor with no "
                      "singleton dimension");
    }
  } else {
    assert(false && "Unsupported modulo pattern");
  }
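The dimension selection in the snippet above can be summarized in a few lines of Python. This is only a sketch of the branching logic: modulo_dim is a hypothetical helper, it works from the result shape alone, and it raises exceptions where the C++ code asserts.

```python
def modulo_dim(shape):
    """Return the index of the dimension that receives the modulo."""
    rank = len(shape)
    if rank == 1:
        # Modulo applied before expanding the shape.
        return 0
    if rank == 2:
        # Modulo applied after expand_dims: pick the non-singleton dimension.
        if shape[0] == 1:
            return 1
        if shape[1] == 1:
            return 0
        raise ValueError("2D tensor with no singleton dimension")
    raise ValueError("unsupported modulo pattern")
```

For the kernel in this thread, the remainder result has shape (128, 1), so the modulo would be recorded on dimension 0.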

Also, I would very much appreciate it if you could add me as a reviewer on future PRs so I can take a look at them in a timely manner. I don't get notifications otherwise. Thanks again!
