[Bug]: encountered operand produced by an unsupported operation "arith.extsi" #125

Open
colawithsauce opened this issue Apr 5, 2024 · 4 comments
Labels: bug

colawithsauce commented Apr 5, 2024

Triton Python code

#!/usr/bin/env python3

import torch

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel_with_block_pointers(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the matrix multiplication tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create block pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction and accumulate.
    # See above `Make a Block Pointer` section for details.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
                                    order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                    order=(1, 0))

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE_K):
        # Load with boundary checks, so there is no need to calculate the mask manually.
        # For better performance, you may remove an axis from the boundary
        # check if you can guarantee that the access is always in bounds along
        # that axis.
        # See above `Load/Store a Block Pointer` section for details.
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the block pointer to the next K block.
        # See above `Advance a Block Pointer` section for details.
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
    c = accumulator.to(tl.float16)

    # ----------------------------------------------------------------
    # Write back the block of the output matrix C with boundary checks.
    # See above `Load/Store a Block Pointer` section for details.
    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
    tl.store(c_block_ptr, c, boundary_check=(0, 1))


# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
def matmul(a, b):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    compiled_kernel = matmul_kernel_with_block_pointers[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1))

    print(compiled_kernel.asm['ttir'])
    return c

if __name__ == '__main__':
    torch.manual_seed(0)
    a = torch.randn((2048, 2048), device='cuda', dtype=torch.float16)
    b = torch.randn((2048, 2048), device='cuda', dtype=torch.float16)
    matmul(a, b)
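
The grouped program-id ordering used by the kernel above can be checked on the host without a GPU. The following is a minimal plain-Python sketch, not part of the original report: `cdiv` and `grouped_pid_to_tile` are hypothetical helper names that mirror the kernel's integer arithmetic, and the example values assume the 2048x2048 inputs with the 64x128 tile configuration that shows up in the IR dump.

```python
def cdiv(a, b):
    """Ceiling division for non-negative integers, like tl.cdiv."""
    return (a + b - 1) // b


def grouped_pid_to_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    """Mirror the kernel's mapping from a 1D program id to (pid_m, pid_n).

    Hypothetical host-side helper; it repeats the kernel's arithmetic
    line by line so the mapping can be inspected in plain Python.
    """
    num_pid_m = cdiv(M, BLOCK_SIZE_M)
    num_pid_n = cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


if __name__ == "__main__":
    # Assumed values: 2048x2048 matrices, 64x128 tiles, groups of 8 rows.
    for pid in (0, 1, 8, 128):
        print(pid, grouped_pid_to_tile(pid, 2048, 2048, 64, 128, 8))
```

Consecutive program ids walk down a column of up to `GROUP_SIZE_M` row tiles before moving to the next column tile, which is the reuse pattern the kernel comment refers to.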

Triton IR

module {
  tt.func public @matmul_kernel_with_block_pointers_0d1d2d3d4d5d6d7c8d9c10d11c(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c32_i64 = arith.constant 32 : i64
    %cst = arith.constant dense<0> : tensor<1x128xi64>
    %cst_0 = arith.constant dense<0> : tensor<32x1xi64>
    %cst_1 = arith.constant dense<0> : tensor<1x32xi64>
    %cst_2 = arith.constant dense<0> : tensor<64x1xi64>
    %c0_i64 = arith.constant 0 : i64
    %c127_i32 = arith.constant 127 : i32
    %c63_i32 = arith.constant 63 : i32
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x128xf32>
    %c32_i32 = arith.constant 32 : i32
    %c128_i32 = arith.constant 128 : i32
    %c0_i32 = arith.constant 0 : i32
    %c64_i32 = arith.constant 64 : i32
    %c8_i32 = arith.constant 8 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.addi %arg3, %c63_i32 : i32
    %2 = arith.divsi %1, %c64_i32 : i32
    %3 = arith.addi %arg4, %c127_i32 : i32
    %4 = arith.divsi %3, %c128_i32 : i32
    %5 = arith.muli %4, %c8_i32 : i32
    %6 = arith.divsi %0, %5 : i32
    %7 = arith.muli %6, %c8_i32 : i32
    %8 = arith.subi %2, %7 : i32
    %9 = arith.minsi %8, %c8_i32 : i32
    %10 = arith.remsi %0, %9 : i32
    %11 = arith.addi %7, %10 : i32
    %12 = arith.remsi %0, %5 : i32
    %13 = arith.divsi %12, %9 : i32
    %14 = arith.muli %11, %c64_i32 : i32
    %15 = arith.extsi %arg3 : i32 to i64
    %16 = arith.extsi %arg5 : i32 to i64
    %17 = arith.extsi %arg6 : i32 to i64
    %18 = arith.extsi %14 : i32 to i64
    %19 = arith.muli %13, %c128_i32 : i32
    %20 = arith.extsi %arg4 : i32 to i64
    %21 = arith.extsi %arg7 : i32 to i64
    %22 = arith.extsi %19 : i32 to i64
    %23 = tt.splat %arg0 : !tt.ptr<f16, 1> -> tensor<64x32x!tt.ptr<f16, 1>>
    %24 = tt.splat %18 : i64 -> tensor<64xi64>
    %25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %26 = arith.extsi %25 : tensor<64xi32> to tensor<64xi64>
    %27 = arith.addi %24, %26 : tensor<64xi64>
    %28 = tt.expand_dims %27 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
    %29 = tt.splat %17 : i64 -> tensor<64x1xi64>
    %30 = arith.muli %28, %29 : tensor<64x1xi64>
    %31 = tt.broadcast %30 : tensor<64x1xi64> -> tensor<64x32xi64>
    %32 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %33 = arith.extsi %32 : tensor<32xi32> to tensor<32xi64>
    %34 = arith.cmpi sge, %28, %cst_2 : tensor<64x1xi64>
    %35 = tt.splat %15 : i64 -> tensor<64x1xi64>
    %36 = arith.cmpi slt, %28, %35 : tensor<64x1xi64>
    %37 = arith.andi %34, %36 : tensor<64x1xi1>
    %38 = tt.broadcast %37 : tensor<64x1xi1> -> tensor<64x32xi1>
    %39 = tt.splat %16 : i64 -> tensor<1x32xi64>
    %40 = tt.splat %arg1 : !tt.ptr<f16, 1> -> tensor<32x128x!tt.ptr<f16, 1>>
    %41 = tt.splat %21 : i64 -> tensor<32x1xi64>
    %42 = tt.splat %22 : i64 -> tensor<128xi64>
    %43 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
    %44 = arith.extsi %43 : tensor<128xi32> to tensor<128xi64>
    %45 = arith.addi %42, %44 : tensor<128xi64>
    %46 = tt.expand_dims %45 {axis = 0 : i32} : tensor<128xi64> -> tensor<1x128xi64>
    %47 = tt.broadcast %46 : tensor<1x128xi64> -> tensor<32x128xi64>
    %48 = tt.splat %16 : i64 -> tensor<32x1xi64>
    %49 = arith.cmpi sge, %46, %cst : tensor<1x128xi64>
    %50 = tt.splat %20 : i64 -> tensor<1x128xi64>
    %51 = arith.cmpi slt, %46, %50 : tensor<1x128xi64>
    %52 = arith.andi %49, %51 : tensor<1x128xi1>
    %53 = tt.broadcast %52 : tensor<1x128xi1> -> tensor<32x128xi1>
    %54:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst_3, %arg11 = %c0_i64, %arg12 = %c0_i64) -> (tensor<64x128xf32>, i64, i64)  : i32 {
      %85 = tt.splat %arg11 : i64 -> tensor<32xi64>
      %86 = arith.addi %85, %33 : tensor<32xi64>
      %87 = tt.expand_dims %86 {axis = 0 : i32} : tensor<32xi64> -> tensor<1x32xi64>
      %88 = tt.broadcast %87 : tensor<1x32xi64> -> tensor<64x32xi64>
      %89 = arith.addi %31, %88 : tensor<64x32xi64>
      %90 = tt.addptr %23, %89 : tensor<64x32x!tt.ptr<f16, 1>>, tensor<64x32xi64>
      %91 = arith.cmpi sge, %87, %cst_1 : tensor<1x32xi64>
      %92 = arith.cmpi slt, %87, %39 : tensor<1x32xi64>
      %93 = arith.andi %91, %92 : tensor<1x32xi1>
      %94 = tt.broadcast %93 : tensor<1x32xi1> -> tensor<64x32xi1>
      %95 = arith.andi %38, %94 : tensor<64x32xi1>
      %96 = tt.load %90, %95 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32x!tt.ptr<f16, 1>>
      %97 = tt.splat %arg12 : i64 -> tensor<32xi64>
      %98 = arith.addi %97, %33 : tensor<32xi64>
      %99 = tt.expand_dims %98 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64>
      %100 = arith.muli %99, %41 : tensor<32x1xi64>
      %101 = tt.broadcast %100 : tensor<32x1xi64> -> tensor<32x128xi64>
      %102 = arith.addi %101, %47 : tensor<32x128xi64>
      %103 = tt.addptr %40, %102 : tensor<32x128x!tt.ptr<f16, 1>>, tensor<32x128xi64>
      %104 = arith.cmpi sge, %99, %cst_0 : tensor<32x1xi64>
      %105 = arith.cmpi slt, %99, %48 : tensor<32x1xi64>
      %106 = arith.andi %104, %105 : tensor<32x1xi1>
      %107 = tt.broadcast %106 : tensor<32x1xi1> -> tensor<32x128xi1>
      %108 = arith.andi %107, %53 : tensor<32x128xi1>
      %109 = tt.load %103, %108 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128x!tt.ptr<f16, 1>>
      %110 = tt.dot %96, %109, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x32xf16> * tensor<32x128xf16> -> tensor<64x128xf32>
      %111 = arith.addi %arg11, %c32_i64 : i64
      %112 = arith.addi %arg12, %c32_i64 : i64
      scf.yield %110, %111, %112 : tensor<64x128xf32>, i64, i64
    }
    %55 = arith.truncf %54#0 : tensor<64x128xf32> to tensor<64x128xf16>
    %56 = arith.extsi %arg8 : i32 to i64
    %57 = tt.splat %arg2 : !tt.ptr<f16, 1> -> tensor<64x128x!tt.ptr<f16, 1>>
    %58 = tt.splat %18 : i64 -> tensor<64xi64>
    %59 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %60 = arith.extsi %59 : tensor<64xi32> to tensor<64xi64>
    %61 = arith.addi %58, %60 : tensor<64xi64>
    %62 = tt.expand_dims %61 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
    %63 = tt.splat %56 : i64 -> tensor<64x1xi64>
    %64 = arith.muli %62, %63 : tensor<64x1xi64>
    %65 = tt.broadcast %64 : tensor<64x1xi64> -> tensor<64x128xi64>
    %66 = tt.splat %22 : i64 -> tensor<128xi64>
    %67 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
    %68 = arith.extsi %67 : tensor<128xi32> to tensor<128xi64>
    %69 = arith.addi %66, %68 : tensor<128xi64>
    %70 = tt.expand_dims %69 {axis = 0 : i32} : tensor<128xi64> -> tensor<1x128xi64>
    %71 = tt.broadcast %70 : tensor<1x128xi64> -> tensor<64x128xi64>
    %72 = arith.addi %65, %71 : tensor<64x128xi64>
    %73 = tt.addptr %57, %72 : tensor<64x128x!tt.ptr<f16, 1>>, tensor<64x128xi64>
    %74 = arith.cmpi sge, %62, %cst_2 : tensor<64x1xi64>
    %75 = tt.splat %15 : i64 -> tensor<64x1xi64>
    %76 = arith.cmpi slt, %62, %75 : tensor<64x1xi64>
    %77 = arith.andi %74, %76 : tensor<64x1xi1>
    %78 = tt.broadcast %77 : tensor<64x1xi1> -> tensor<64x128xi1>
    %79 = arith.cmpi sge, %70, %cst : tensor<1x128xi64>
    %80 = tt.splat %20 : i64 -> tensor<1x128xi64>
    %81 = arith.cmpi slt, %70, %80 : tensor<1x128xi64>
    %82 = arith.andi %79, %81 : tensor<1x128xi1>
    %83 = tt.broadcast %82 : tensor<1x128xi1> -> tensor<64x128xi1>
    %84 = arith.andi %78, %83 : tensor<64x128xi1>
    tt.store %73, %55, %84 {cache = 1 : i32, evict = 1 : i32} : tensor<64x128x!tt.ptr<f16, 1>>
    tt.return
  }
}
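
To make it easier to see where the widened ranges enter the address computation, here is a plain-Python sketch of what the IR above appears to compute for the A tile. It is an illustration only: `a_tile_offsets_and_mask` is a hypothetical name, Python integers stand in for the i64 values, and the column stride is assumed to be 1 because `stride_ak` does not appear as a runtime argument in the dumped function.

```python
def a_tile_offsets_and_mask(pid_m, k, M, K, stride_am, BLOCK_M=64, BLOCK_K=32):
    """Model the A-tile offsets and boundary mask from the IR dump.

    Hypothetical helper for illustration; comments name the SSA values
    that each step is believed to correspond to.
    """
    # %24..%28: splat(pid_m * BLOCK_M) + extsi(make_range(0, BLOCK_M))
    rows = [pid_m * BLOCK_M + i for i in range(BLOCK_M)]
    # %85..%87: splat(k) + extsi(make_range(0, BLOCK_K))
    cols = [k + j for j in range(BLOCK_K)]
    # %30 and %89: row * stride_am + col (column stride assumed to be 1)
    offsets = [[r * stride_am + c for c in cols] for r in rows]
    # %34..%38 and %91..%95: 0 <= row < M and 0 <= col < K
    mask = [[(0 <= r < M) and (0 <= c < K) for c in cols] for r in rows]
    return offsets, mask


if __name__ == "__main__":
    offsets, mask = a_tile_offsets_and_mask(pid_m=1, k=32, M=2048, K=2048, stride_am=2048)
    print(offsets[0][0], mask[0][0])
```

The two `range(...)` terms are the `tt.make_range` results that the IR sign-extends from i32 to i64 with `arith.extsi` before adding them to the block offsets, and those widened ranges are the operations named in the error messages.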

Crash log

[16:20:18] colawithsauce@SPQR /home/colawithsauce/Projects/Triton
> triton-shared-opt --triton-to-linalg ~/playground/test.mlir
%43 = "arith.extsi"(%42) {MetaUse} : (tensor<64xi32>) -> tensor<64xi64>
encountered addptr operand produced by an unsupported operation
UNREACHABLE executed at /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:701!
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: triton-shared-opt --triton-to-linalg /home/colawithsauce/playground/test.mlir
 #0 0x00005f9befc0b087 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x242a087)
 #1 0x00005f9befc08bae llvm::sys::RunSignalHandlers() (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2427bae)
 #2 0x00005f9befc0b73f SignalHandler(int) Signals.cpp:0:0
 #3 0x000077438ec54eb0 __restore_rt (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x3deb0)
 #4 0x000077438eca407c __pthread_kill_implementation (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x8d07c)
 #5 0x000077438ec54e06 gsignal (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x3de06)
 #6 0x000077438ec3d8f5 abort (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x268f5)
 #7 0x00005f9befbccfc1 (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x23ebfc1)
 #8 0x00005f9bedc26dc0 mlir::triton::PtrAnalysis::visitOperandMul(mlir::arith::MulIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:369:0
 #9 0x00005f9bedc26780 llvm::SmallVectorBase<unsigned int>::size() const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:91:32
#10 0x00005f9bedc26780 mlir::triton::PtrState::getRank() const /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:61:3
#11 0x00005f9bedc26780 mlir::triton::PtrAnalysis::visitOperandAdd(mlir::arith::AddIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:358:17
#12 0x00005f9bedc26af6 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:0:5
#13 0x00005f9bedc27928 mlir::triton::PtrAnalysis::visitOperandExpandDims(mlir::triton::ExpandDimsOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:461:20
#14 0x00005f9bedc26ca5 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:0:5
#15 0x00005f9bedc26e74 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::getFirstEl() const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:143:46
#16 0x00005f9bedc26e74 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::SmallVectorTemplateCommon(unsigned long) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:148:49
#17 0x00005f9bedc26e74 llvm::SmallVectorTemplateBase<mlir::OpFoldResult, true>::SmallVectorTemplateBase(unsigned long) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:500:42
#18 0x00005f9bedc26e74 llvm::SmallVectorImpl<mlir::OpFoldResult>::SmallVectorImpl(unsigned int) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:601:9
#19 0x00005f9bedc26e74 llvm::SmallVector<mlir::OpFoldResult, 6u>::SmallVector() /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:1211:19
#20 0x00005f9bedc26e74 mlir::triton::PtrState::PtrState() /home/colawithsauce/Projects/Triton/triton_shared/include/triton-shared/Analysis/PtrAnalysis.h:46:7
#21 0x00005f9bedc26e74 mlir::triton::PtrAnalysis::visitOperandMul(mlir::arith::MulIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:373:12
#22 0x00005f9bedc26b77 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:0:5
#23 0x00005f9bedc27d0b mlir::triton::PtrAnalysis::visitOperandBroadcast(mlir::triton::BroadcastOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:497:24
#24 0x00005f9bedc26c27 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:0:5
#25 0x00005f9bedc266f4 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::getFirstEl() const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:143:46
#26 0x00005f9bedc266f4 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::SmallVectorTemplateCommon(unsigned long) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:148:49
#27 0x00005f9bedc266f4 llvm::SmallVectorTemplateBase<mlir::OpFoldResult, true>::SmallVectorTemplateBase(unsigned long) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:500:42
#28 0x00005f9bedc266f4 llvm::SmallVectorImpl<mlir::OpFoldResult>::SmallVectorImpl(unsigned int) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:601:9
#29 0x00005f9bedc266f4 llvm::SmallVector<mlir::OpFoldResult, 6u>::SmallVector() /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/SmallVector.h:1211:19
#30 0x00005f9bedc266f4 mlir::triton::PtrState::PtrState() /home/colawithsauce/Projects/Triton/triton_shared/include/triton-shared/Analysis/PtrAnalysis.h:46:7
#31 0x00005f9bedc266f4 mlir::triton::PtrAnalysis::visitOperandAdd(mlir::arith::AddIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:355:12
#32 0x00005f9bedc26af6 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:0:5
#33 0x00005f9bedc28e6e mlir::Value::operator bool() const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/mlir/IR/Value.h:120:43
#34 0x00005f9bedc28e6e mlir::triton::PtrAnalysis::visitOperandAddptr(mlir::triton::AddPtrOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:588:3
#35 0x00005f9bedc299d8 mlir::triton::PtrAnalysis::rewriteAddptrOp(mlir::triton::AddPtrOp, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>>&) /home/colawithsauce/Projects/Triton/triton_shared/lib/Analysis/PtrAnalysis.cpp:749:26
#36 0x00005f9bedcd6767 llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>>::~SmallDenseMap() /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/llvm/ADT/DenseMap.h:960:11
#37 0x00005f9bedcd6767 (anonymous namespace)::LegacyAddPtrConverter::matchAndRewrite(mlir::triton::AddPtrOp, mlir::triton::AddPtrOpAdaptor, mlir::ConversionPatternRewriter&) const /home/colawithsauce/Projects/Triton/triton_shared/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp:242:3
#38 0x00005f9bedc645af mlir::OpConversionPattern<mlir::triton::AddPtrOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/mlir/Transforms/DialectConversion.h:538:3
#39 0x00005f9bef7e4d50 mlir::ConversionPattern::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2003d50)
#40 0x00005f9bef81e194 mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>)::$_2::operator()() const PatternApplicator.cpp:0:0
#41 0x00005f9bef81acaf mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2039caf)
#42 0x00005f9bef7f2315 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) DialectConversion.cpp:0:0
#43 0x00005f9bef7e7f44 (anonymous namespace)::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>, llvm::function_ref<void (mlir::Diagnostic&)>) DialectConversion.cpp:0:0
#44 0x00005f9bef7eb1e0 mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, llvm::DenseSet<mlir::Operation*, llvm::DenseMapInfo<mlir::Operation*, void>>*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x200a1e0)
#45 0x00005f9bedcec085 (anonymous namespace)::TritonToLinalgPass::runOnOperation() /home/colawithsauce/Projects/Triton/triton_shared/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp:184:16
#46 0x00005f9bef0cc6a6 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18eb6a6)
#47 0x00005f9bef0cce41 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18ebe41)
#48 0x00005f9bef0cf27b mlir::PassManager::run(mlir::Operation*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18ee27b)
#49 0x00005f9bef0c959f performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) MlirOptMain.cpp:0:0
#50 0x00005f9bef0c87bd mlir::LogicalResult llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_2>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) MlirOptMain.cpp:0:0
#51 0x00005f9befb99f89 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, bool, bool) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x23b8f89)
#52 0x00005f9bef0c3bfa mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18e2bfa)
#53 0x00005f9bef0c3d96 mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18e2d96)
#54 0x00005f9bef0c4146 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18e3146)
#55 0x00005f9bedd79bcb main /home/colawithsauce/Projects/Triton/triton_shared/tools/triton-shared-opt/triton-shared-opt.cpp:16:33
#56 0x000077438ec3f0ce __libc_start_call_main (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x280ce)
#57 0x000077438ec3f189 __libc_start_main@GLIBC_2.2.5 (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x28189)
#58 0x00005f9bedaadc75 _start (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2ccc75)
fish: Job 1, 'triton-shared-opt --triton-to-l…' terminated by signal SIGABRT (Abort)
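
When triaging a failure like this, it can help to list every tensor-typed `arith.extsi` in the TTIR dump before running the pass. The sketch below is a hypothetical helper, not part of triton-shared: it only pattern-matches the textual form used in the dump above and does not trace whether a given result actually feeds a `tt.addptr`.

```python
import re

# Matches lines such as:
#   %26 = arith.extsi %25 : tensor<64xi32> to tensor<64xi64>
EXTSI_RE = re.compile(
    r'^\s*(%\w+)\s*=\s*arith\.extsi\s+(%\w+)\s*:\s*(\S+)\s+to\s+(\S+)\s*$'
)


def find_tensor_extsi(ttir_text):
    """Return (line_no, result, operand, src_type, dst_type) for each
    arith.extsi whose result type is a tensor.

    Hypothetical triage helper; scalar extsi ops (e.g. i32 to i64) are skipped.
    """
    hits = []
    for line_no, line in enumerate(ttir_text.splitlines(), start=1):
        m = EXTSI_RE.match(line)
        if m and m.group(4).startswith("tensor<"):
            hits.append((line_no, m.group(1), m.group(2), m.group(3), m.group(4)))
    return hits


if __name__ == "__main__":
    sample = (
        "    %25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>\n"
        "    %26 = arith.extsi %25 : tensor<64xi32> to tensor<64xi64>\n"
        "    %18 = arith.extsi %14 : i32 to i64\n"
    )
    for hit in find_tensor_extsi(sample):
        print(hit)
```

Each reported line is only a candidate: whether the analysis reaches it depends on the use-def chain from the `tt.addptr` offset operand, which this sketch does not follow.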

Additional information

colawithsauce added the bug label Apr 5, 2024
colawithsauce (Author) commented:

I used triton-shared-opt --triton-to-linalg-experimental instead. Here is the new error output:

[16:48:00] colawithsauce@SPQR /home/colawithsauce/Projects/Triton/triton_shared
> triton-opt ~/playground/test.mlir| triton-shared-opt --triton-to-linalg-experimental
PtrAnalysis: encountered addptr operand produced by an unsupported operation
%27 = arith.extsi %26 : tensor<64xi32> to tensor<64xi64>
<stdin>:77:13: remark: PtrAnalysis: Failed to rewrite AddPtrOp
      %90 = tt.addptr %23, %89 : tensor<64x32x!tt.ptr<f16, 1>>, tensor<64x32xi64>
            ^
<stdin>:77:13: note: see current operation: %92 = tt.addptr %24, %91 : tensor<64x32x!tt.ptr<f16, 1>>, tensor<64x32xi64>
<stdin>:83:13: remark: PtrAnalysis: pointer is not replace with tts.make_tptr so loadOp cannot be rewritten
      %96 = tt.load %90, %95 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
            ^
<stdin>:83:13: note: see current operation: %98 = tt.load %92, %97 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
<stdin>:83:13: remark: PtrAnalysis: Failed to rewrite LoadOp
      %96 = tt.load %90, %95 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
            ^
<stdin>:83:13: note: see current operation: %98 = tt.load %92, %97 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x32xf16>
PtrAnalysis: encountered addptr operand produced by an unsupported operation
%34 = arith.extsi %33 : tensor<32xi32> to tensor<32xi64>
<stdin>:90:14: remark: PtrAnalysis: Failed to rewrite AddPtrOp
      %103 = tt.addptr %40, %102 : tensor<32x128x!tt.ptr<f16, 1>>, tensor<32x128xi64>
             ^
<stdin>:90:14: note: see current operation: %106 = tt.addptr %41, %104 : tensor<32x128x!tt.ptr<f16, 1>>, tensor<32x128xi64>
<stdin>:96:14: remark: PtrAnalysis: pointer is not replace with tts.make_tptr so loadOp cannot be rewritten
      %109 = tt.load %103, %108 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16>
             ^
<stdin>:96:14: note: see current operation: %112 = tt.load %106, %111 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16>
<stdin>:96:14: remark: PtrAnalysis: Failed to rewrite LoadOp
      %109 = tt.load %103, %108 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16>
             ^
<stdin>:96:14: note: see current operation: %112 = tt.load %106, %111 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16>
PtrAnalysis: encountered addptr operand produced by an unsupported operation
%62 = arith.extsi %61 : tensor<64xi32> to tensor<64xi64>
<stdin>:120:11: remark: PtrAnalysis: Failed to rewrite AddPtrOp
    %73 = tt.addptr %57, %72 : tensor<64x128x!tt.ptr<f16, 1>>, tensor<64x128xi64>
          ^
<stdin>:120:11: note: see current operation: %75 = tt.addptr %59, %74 : tensor<64x128x!tt.ptr<f16, 1>>, tensor<64x128xi64>
<stdin>:132:5: remark: PtrAnalysis: pointer is not replace with tts.make_tptr so storeOp cannot be rewritten
    tt.store %73, %55, %84 {cache = 1 : i32, evict = 1 : i32} : tensor<64x128xf16>
    ^
<stdin>:132:5: note: see current operation: tt.store %75, %57, %86 {cache = 1 : i32, evict = 1 : i32} : tensor<64x128xf16>
<stdin>:132:5: remark: PtrAnalysis: Failed to rewrite StoreOp
    tt.store %73, %55, %84 {cache = 1 : i32, evict = 1 : i32} : tensor<64x128xf16>
    ^
<stdin>:132:5: note: see current operation: tt.store %75, %57, %86 {cache = 1 : i32, evict = 1 : i32} : tensor<64x128xf16>
LLVM ERROR: Failed to infer result type(s).
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: triton-shared-opt --triton-to-linalg-experimental
 #0 0x00006282cf875087 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x242a087)
 #1 0x00006282cf872bae llvm::sys::RunSignalHandlers() (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2427bae)
 #2 0x00006282cf87573f SignalHandler(int) Signals.cpp:0:0
 #3 0x000077b037e54eb0 __restore_rt (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x3deb0)
 #4 0x000077b037ea407c __pthread_kill_implementation (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x8d07c)
 #5 0x000077b037e54e06 gsignal (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x3de06)
 #6 0x000077b037e3d8f5 abort (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x268f5)
 #7 0x00006282cf836d00 llvm::report_fatal_error(llvm::Twine const&, bool) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x23ebd00)
 #8 0x00006282cf836b18 (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x23ebb18)
 #9 0x00006282ce3f05b1 (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0xfa55b1)
#10 0x00006282cd97a18f mlir::memref::ExtractStridedMetadataOp mlir::OpBuilder::create<mlir::memref::ExtractStridedMetadataOp, mlir::Value&>(mlir::Location, mlir::Value&) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/mlir/IR/Builders.h:509:16
#11 0x00006282cd978648 (anonymous namespace)::ScalarAddptrConverter::matchAndRewrite(mlir::triton::AddPtrOp, mlir::triton::AddPtrOpAdaptor, mlir::ConversionPatternRewriter&) const /home/colawithsauce/Projects/Triton/triton_shared/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp:766:18
#12 0x00006282cd8ce5af mlir::OpConversionPattern<mlir::triton::AddPtrOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/mlir/Transforms/DialectConversion.h:538:3
#13 0x00006282cf44ed50 mlir::ConversionPattern::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2003d50)
#14 0x00006282cf488194 mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>)::$_2::operator()() const PatternApplicator.cpp:0:0
#15 0x00006282cf484caf mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2039caf)
#16 0x00006282cf45c315 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) DialectConversion.cpp:0:0
#17 0x00006282cf451f44 (anonymous namespace)::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>, llvm::function_ref<void (mlir::Diagnostic&)>) DialectConversion.cpp:0:0
#18 0x00006282cf4551e0 mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, llvm::DenseSet<mlir::Operation*, llvm::DenseMapInfo<mlir::Operation*, void>>*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x200a1e0)
#19 0x00006282cd97b2d3 (anonymous namespace)::StructuredToMemrefPass::runOnOperation() /home/colawithsauce/Projects/Triton/triton_shared/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp:141:16
#20 0x00006282ced366a6 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18eb6a6)
#21 0x00006282ced36e41 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18ebe41)
#22 0x00006282ced3b4e8 mlir::LogicalResult llvm::function_ref<mlir::LogicalResult (mlir::OpPassManager&, mlir::Operation*)>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_6>(long, mlir::OpPassManager&, mlir::Operation*) Pass.cpp:0:0
#23 0x00006282cd961389 mlir::LogicalResult::failed() const /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/mlir/Support/LogicalResult.h:44:33
#24 0x00006282cd961389 mlir::failed(mlir::LogicalResult) /home/colawithsauce/.triton/llvm/llvm-4017f04e-ubuntu-x64/include/mlir/Support/LogicalResult.h:72:58
#25 0x00006282cd961389 (anonymous namespace)::TritonToLinalgExperimentalPass::runOnOperation() /home/colawithsauce/Projects/Triton/triton_shared/lib/Conversion/TritonToLinalgExperimental/TritonToLinalgExperimentalPass.cpp:55:9
#26 0x00006282ced366a6 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18eb6a6)
#27 0x00006282ced36e41 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18ebe41)
#28 0x00006282ced3927b mlir::PassManager::run(mlir::Operation*) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18ee27b)
#29 0x00006282ced3359f performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) MlirOptMain.cpp:0:0
#30 0x00006282ced327bd mlir::LogicalResult llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_2>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) MlirOptMain.cpp:0:0
#31 0x00006282cf803f89 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, bool, bool) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x23b8f89)
#32 0x00006282ced2dbfa mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18e2bfa)
#33 0x00006282ced2dd96 mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18e2d96)
#34 0x00006282ced2e146 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x18e3146)
#35 0x00006282cd9e3bcb main /home/colawithsauce/Projects/Triton/triton_shared/tools/triton-shared-opt/triton-shared-opt.cpp:16:33
#36 0x000077b037e3f0ce __libc_start_call_main (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x280ce)
#37 0x000077b037e3f189 __libc_start_main@GLIBC_2.2.5 (/nix/store/1rm6sr6ixxzipv5358x0cmaw8rs84g2j-glibc-2.38-44/lib/libc.so.6+0x28189)
#38 0x00006282cd717c75 _start (/home/colawithsauce/Projects/Triton/triton/python/build/cmake.linux-x86_64-cpython-3.11/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2ccc75)
fish: Process 424253, 'triton-shared-opt' from job 1, 'triton-opt ~/playground/test.ml…' terminated by signal SIGABRT (Abort)

@nhat-nguyen
Collaborator

Thank you for the report for both the legacy and experimental passes! This looks like a matmul kernel -- perhaps from the official triton tutorial? I will take a closer look. It has been a while since we last updated the tutorial kernel tests. But if this is the official one, it's definitely important that we fix it.

@nhat-nguyen nhat-nguyen self-assigned this Apr 5, 2024
@colawithsauce
Author

@nhat-nguyen Yes, it is indeed from the Triton tutorial. I only changed the kernel's caller and left the kernel itself unmodified.

@colawithsauce
Author

@nhat-nguyen, this code is from Triton's tutorial 08-experimental-block-pointer. However, when I searched for this tutorial page in Triton's online documentation today, I got a 404 Not Found. I don't know when it was deleted. Here is the Wayback Machine link to it.

I searched for block pointer on Triton's issue page and found this comment from one of its contributors: triton-lang/triton#1946 (comment)

I guess block_ptr is still a work in progress, and using it is not recommended.
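For anyone hitting the same error, one thing that may be worth trying is computing the tile addresses by hand instead of through a block pointer, as the regular (non-block-pointer) matmul tutorial does. The sketch below is plain Python rather than Triton code, and it only illustrates the offset arithmetic; the helper name `tile_offsets` and the sizes are made up for illustration, and I have not verified that this formulation avoids the PtrAnalysis failure.

```python
# Illustration only: the flat offsets a (block_m x block_k) tile covers in a
# row-major matrix. `tile_offsets` is a hypothetical helper, not a Triton API.

def tile_offsets(off_m, off_k, block_m, block_k, stride_m, stride_k):
    """Return offsets[i][j] = (off_m + i) * stride_m + (off_k + j) * stride_k."""
    return [
        [(off_m + i) * stride_m + (off_k + j) * stride_k for j in range(block_k)]
        for i in range(block_m)
    ]


# An 8 x 6 row-major matrix stored flat, so stride_m = 6 and stride_k = 1.
M, K = 8, 6
a_flat = list(range(M * K))

# Offsets of the 4 x 3 tile whose top-left corner is at row 2, column 1.
offs = tile_offsets(2, 1, 4, 3, stride_m=K, stride_k=1)

# Gathering through the offsets reads the same elements as slicing the tile.
tile = [[a_flat[o] for o in row] for row in offs]
```

In a Triton kernel the same arithmetic would be written with tl.arange and broadcasting, then added to the base pointer.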

Thanks for your attention!
