[Bug]: Failed lowering mv kernel #197

Open
TracyMac1 opened this issue Dec 2, 2024 · 0 comments
Labels: bug (Something isn't working)
Triton Python code

import logging

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_stages=s, num_warps=w)
        for m in [32, 64, 128]
        for n in [1, 2, 4, 8]
        for s in [3, 4]
        for w in [4, 8]
    ],
    key=["M", "N"],
)
@triton.jit
def mv_kernel(
    A,
    B,
    C,
    N,
    M,
    stride_an,
    stride_am,
    stride_bm,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
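    # Each program computes BLOCK_N rows of the matrix-vector product C = A @ B.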
    pid = tl.program_id(0)
    offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    offset_m = tl.arange(0, BLOCK_M)[None, :]
    n_mask = offset_n < N
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
    B_ptrs = B + offset_m * stride_bm
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
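    # March along the M dimension in BLOCK_M-sized chunks, accumulating partial products.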
    for m in range(0, M, BLOCK_M):
        m_mask = m + offset_m < M
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)
        acc += a * b
        A_ptrs += BLOCK_M * stride_am
        B_ptrs += BLOCK_M * stride_bm

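    # Reduce the per-row partial products to a single output value per row.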
    acc = tl.sum(acc, axis=1)
    C_ptrs = C + offset_n * stride_cn
    tl.store(C_ptrs, acc[:, None], mask=n_mask)


def mv(inp, vec):
    logging.debug("GEMS MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"
    N, M = inp.shape
    out = torch.empty((N,), device=inp.device, dtype=inp.dtype)
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    mv_kernel[grid](
        inp,
        vec,
        out,
        N,
        M,
        inp.stride(0),
        inp.stride(1),
        vec.stride(0),
        out.stride(0),
    )
    return out
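
For context, a minimal driver along these lines triggers the failing compilation (a sketch, not the reporter's exact test: the shapes, seed, device, and dtype are illustrative; float16 is assumed because the crash log below shows f16 pointers):

def main():
    torch.manual_seed(0)
    # Any (N, M) matrix and length-M vector exercise the kernel; the values are
    # irrelevant because the failure happens at compile time, during lowering.
    # device="cuda" is an assumption here; adjust for the backend under test.
    inp = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
    vec = torch.randn(512, dtype=torch.float16, device="cuda")
    out = mv(inp, vec)
    # Sanity check against torch.mv once the kernel compiles.
    torch.testing.assert_close(out, torch.mv(inp, vec), rtol=1e-3, atol=1e-3)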

Triton IR

module {
  tt.func public @mv_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c32_i32 = arith.constant 32 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<32> : tensor<1x32xi32>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<1x32xf32>
    %0 = tt.get_program_id x : i32
    %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
    %3 = arith.cmpi slt, %0, %arg3 : i32
    %4 = tt.splat %3 : i1 -> tensor<1x1xi1>
    %5 = arith.muli %0, %arg5 : i32
    %6 = tt.addptr %arg0, %5 : !tt.ptr<f32>, i32
    %7 = tt.splat %6 : !tt.ptr<f32> -> tensor<1x32x!tt.ptr<f32>>
    %8 = tt.addptr %7, %2 : tensor<1x32x!tt.ptr<f32>>, tensor<1x32xi32>
    %9 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1x32x!tt.ptr<f32>>
    %10 = tt.addptr %9, %2 : tensor<1x32x!tt.ptr<f32>>, tensor<1x32xi32>
    %11 = tt.splat %arg4 : i32 -> tensor<1x32xi32>
    %12 = tt.splat %3 : i1 -> tensor<1x32xi1>
    %13:3 = scf.for %arg6 = %c0_i32 to %arg4 step %c32_i32 iter_args(%arg7 = %cst_0, %arg8 = %8, %arg9 = %10) -> (tensor<1x32xf32>, tensor<1x32x!tt.ptr<f32>>, tensor<1x32x!tt.ptr<f32>>)  : i32 {
      %18 = tt.splat %arg6 : i32 -> tensor<1x32xi32>
      %19 = arith.addi %18, %2 : tensor<1x32xi32>
      %20 = arith.cmpi slt, %19, %11 : tensor<1x32xi32>
      %21 = arith.andi %12, %20 : tensor<1x32xi1>
      %22 = tt.load %arg8, %21, %cst_0 : tensor<1x32x!tt.ptr<f32>>
      %23 = tt.load %arg9, %20, %cst_0 : tensor<1x32x!tt.ptr<f32>>
      %24 = arith.mulf %22, %23 : tensor<1x32xf32>
      %25 = arith.addf %arg7, %24 : tensor<1x32xf32>
      %26 = tt.addptr %arg8, %cst : tensor<1x32x!tt.ptr<f32>>, tensor<1x32xi32>
      %27 = tt.addptr %arg9, %cst : tensor<1x32x!tt.ptr<f32>>, tensor<1x32xi32>
      scf.yield %25, %26, %27 : tensor<1x32xf32>, tensor<1x32x!tt.ptr<f32>>, tensor<1x32x!tt.ptr<f32>>
    }
    %14 = "tt.reduce"(%13#0) <{axis = 1 : i32}> ({
    ^bb0(%arg6: f32, %arg7: f32):
      %18 = arith.addf %arg6, %arg7 : f32
      tt.reduce.return %18 : f32
    }) : (tensor<1x32xf32>) -> tensor<1xf32>
    %15 = tt.addptr %arg2, %0 : !tt.ptr<f32>, i32
    %16 = tt.splat %15 : !tt.ptr<f32> -> tensor<1x1x!tt.ptr<f32>>
    %17 = tt.expand_dims %14 {axis = 1 : i32} : tensor<1xf32> -> tensor<1x1xf32>
    tt.store %16, %17, %4 : tensor<1x1x!tt.ptr<f32>>
    tt.return
  }
}
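
Note that the IR above is the float32 specialization of the kernel, while the crash log below refers to tensor<1x32x!tt.ptr<f16>>; the log appears to come from a float16 run of the same kernel, and the reported remarks concern mask and pointer analysis rather than the element type.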

Crash log

test_mv_ops.py:45:20: remark: MaskAnalysis failed
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
                   ^
test_mv_ops.py:45:20: note: see current operation: %29 = tt.load %arg7, %27, %cst_0 : tensor<1x32x!tt.ptr<f16>>
test_mv_ops.py:45:20: remark: PtrAnalysis: Failed to rewrite LoadOp
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
                   ^
test_mv_ops.py:45:20: note: see current operation: %29 = tt.load %arg7, %27, %cst_0 : tensor<1x32x!tt.ptr<f16>>
test_mv_ops.py:53:21: remark: PtrAnalysis: pointer is not replace with tts.make_tptr so storeOp cannot be rewritten
    tl.store(C_ptrs, acc[:, None], mask=n_mask)
                    ^
test_mv_ops.py:53:21: note: see current operation: tt.store %20, %22, %5 : tensor<1x1x!tt.ptr<f16>>
test_mv_ops.py:53:21: remark: PtrAnalysis: Failed to rewrite StoreOp
    tl.store(C_ptrs, acc[:, None], mask=n_mask)
                    ^
test_mv_ops.py:53:21: note: see current operation: tt.store %20, %22, %5 : tensor<1x1x!tt.ptr<f16>>
test_mv_ops.py:40:40: error: failed to legalize unresolved materialization from ('memref<1x32xf16, strided<[?, ?], offset: ?>>') to 'tensor<1x32x!tt.ptr<f16>>' that remained live after conversion
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
                                       ^
test_mv_ops.py:40:40: note: see current operation: %44 = "builtin.unrealized_conversion_cast"(%arg23) : (memref<1x32xf16, strided<[?, ?], offset: ?>>) -> tensor<1x32x!tt.ptr<f16>>
test_mv_ops.py:45:20: note: see existing live user here: %33 = tt.load %25, %32, %3 : tensor<1x32x!tt.ptr<f16>>
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
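
Possible workaround

The first remark indicates that MaskAnalysis gives up on the combined two-dimensional mask n_mask & m_mask, after which PtrAnalysis cannot rewrite the load and store. A sketch of a workaround (untested; it assumes the analyses handle one-dimensional masks and scalar stores) is to assign one row per program so that every mask stays one-dimensional:

@triton.jit
def mv_kernel_1d(
    A,
    B,
    C,
    N,
    M,
    stride_an,
    stride_am,
    stride_bm,
    stride_cn,
    BLOCK_M: tl.constexpr,
):
    # One program per row: the row index is a scalar, so no 2-D n_mask is
    # needed and the reduction mask stays one-dimensional.
    row = tl.program_id(0)
    offset_m = tl.arange(0, BLOCK_M)
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)
    for m in range(0, M, BLOCK_M):
        m_mask = m + offset_m < M
        a = tl.load(A + row * stride_an + (m + offset_m) * stride_am, mask=m_mask, other=0.0).to(tl.float32)
        b = tl.load(B + (m + offset_m) * stride_bm, mask=m_mask, other=0.0).to(tl.float32)
        acc += a * b
    # Scalar store: one output element per program.
    tl.store(C + row * stride_cn, tl.sum(acc, axis=0))

launched with grid = (N,), so row < N always holds. This trades the BLOCK_N-row tiling (and its autotuning space) for simpler masks; whether the original two-dimensional form should be supported by MaskAnalysis is the actual question of this issue.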

Additional information

No response
