Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug]: failed to lower tl.make_block_ptr #198

Open
TracyMac1 opened this issue Dec 2, 2024 · 0 comments
Open

[Bug]: failed to lower tl.make_block_ptr #198

TracyMac1 opened this issue Dec 2, 2024 · 0 comments
Labels
bug Something isn't working

Comments

@TracyMac1
Copy link

Triton python code

@triton.jit
def mul_func_kernel_rank_2(
    in0_ptr: tl.tensor, # of tl.pointer_type
    in1_ptr: tl.tensor, # of tl.pointer_type
    out0_ptr: tl.tensor, # of tl.pointer_type
    in0_stride0: int, in0_stride1: int, # strides for in0
    in0_stride_order0: tl.constexpr, in0_stride_order1: tl.constexpr, # stride order for in0
    in1_stride0: int, in1_stride1: int, # strides for in1
    in1_stride_order0: tl.constexpr, in1_stride_order1: tl.constexpr, # stride order for in1
    out0_stride0: int, out0_stride1: int, # strides for out0
    out0_stride_order0: tl.constexpr, out0_stride_order1: tl.constexpr, # stride order for out0
    s0: int, s1: int, # task_space
    num_tasks: int,
    tiles_per_cta: int,
    tile_size0: tl.constexpr, tile_size1: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
):
    """Rank-2 element-wise kernel: out0 = mul_func(in0, in1) over an (s0, s1) task space.

    The task space is partitioned into (tile_size0, tile_size1) tiles.  When
    ``one_tile_per_cta`` is true, each program instance processes exactly one
    tile; otherwise each program instance walks ``tiles_per_cta`` tiles in a
    grid-stride loop.  All loads and stores go through block pointers built by
    ``tl.make_block_ptr``, with each tensor's stride order supplied as
    constexpr arguments (so differently laid-out inputs can share this kernel).

    NOTE(review): ``num_tasks`` is accepted but not referenced in this body —
    presumably kept for a uniform wrapper signature; confirm with the caller.
    """
    pid = tl.program_id(0)
    # Tile counts per axis; cdiv rounds up so ragged edges are covered.
    num_tiles0 = tl.cdiv(s0, tile_size0)
    num_tiles1 = tl.cdiv(s1, tile_size1)
    if one_tile_per_cta: # monolithic style: exactly one tile per program instance
        tile_id = pid
        # Multi-index reconstruction from the linear pid: C ordering,
        # rightmost axis changes fastest.
        tile_id1 = tile_id % num_tiles1
        tile_id //= num_tiles1
        tile_id0 = tile_id

        # Element offsets of this tile's origin in the task space.
        offset0 = tile_id0 * tile_size0
        offset1 = tile_id1 * tile_size1
        # Loads through block pointers; boundary_check masks out-of-range
        # elements on the tile's ragged edges.
        in0_bptr = tl.make_block_ptr(in0_ptr, (s0, s1), (in0_stride0, in0_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(in0_stride_order0, in0_stride_order1))
        in0 = tl.load(in0_bptr, boundary_check=(in0_stride_order0, in0_stride_order1)).to(in0_ptr.type.element_ty) # workaround for a bug with bool: cast to the original pointer's dtype instead of the block pointer's
        in1_bptr = tl.make_block_ptr(in1_ptr, (s0, s1), (in1_stride0, in1_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(in1_stride_order0, in1_stride_order1))
        in1 = tl.load(in1_bptr, boundary_check=(in1_stride_order0, in1_stride_order1)).to(in1_ptr.type.element_ty) # workaround for a bug with bool: cast to the original pointer's dtype instead of the block pointer's

        # Element-wise compute on the loaded tiles.
        out0 = mul_func(in0, in1)

        # Store: writing through a block pointer does not auto-cast the value
        # to the pointee dtype, so cast explicitly before tl.store.
        out0_bptr = tl.make_block_ptr(out0_ptr, (s0, s1), (out0_stride0, out0_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(out0_stride_order0, out0_stride_order1))
        tl.store(out0_bptr, out0.to(out0_bptr.type.element_ty), boundary_check=(out0_stride_order0, out0_stride_order1))
    else: # grid-stride-loop style: each program instance walks tiles_per_cta tiles
        num_ctas = tl.num_programs(0)
        for j in range(0, tiles_per_cta):
            # Linear tile id for this iteration, striding by the grid size.
            tile_id = pid + j * num_ctas
            # Multi-index reconstruction from the linear tile id: C ordering,
            # rightmost axis changes fastest.
            tile_id1 = tile_id % num_tiles1
            tile_id //= num_tiles1
            tile_id0 = tile_id

            # Element offsets of this tile's origin in the task space.
            offset0 = tile_id0 * tile_size0
            offset1 = tile_id1 * tile_size1
            # Loads through block pointers; boundary_check masks out-of-range
            # elements on the tile's ragged edges.
            in0_bptr = tl.make_block_ptr(in0_ptr, (s0, s1), (in0_stride0, in0_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(in0_stride_order0, in0_stride_order1))
            in0 = tl.load(in0_bptr, boundary_check=(in0_stride_order0, in0_stride_order1)).to(in0_ptr.type.element_ty) # workaround for a bug with bool: cast to the original pointer's dtype instead of the block pointer's
            in1_bptr = tl.make_block_ptr(in1_ptr, (s0, s1), (in1_stride0, in1_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(in1_stride_order0, in1_stride_order1))
            in1 = tl.load(in1_bptr, boundary_check=(in1_stride_order0, in1_stride_order1)).to(in1_ptr.type.element_ty) # workaround for a bug with bool: cast to the original pointer's dtype instead of the block pointer's

            # Element-wise compute on the loaded tiles.
            out0 = mul_func(in0, in1)

            # Store: writing through a block pointer does not auto-cast the value
            # to the pointee dtype, so cast explicitly before tl.store.
            out0_bptr = tl.make_block_ptr(out0_ptr, (s0, s1), (out0_stride0, out0_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(out0_stride_order0, out0_stride_order1))
            tl.store(out0_bptr, out0.to(out0_bptr.type.element_ty), boundary_check=(out0_stride_order0, out0_stride_order1))

Triton IR

module {
  tt.func public @mul_func_kernel_rank_2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} , %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} , %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} , %arg3: i32 {tt.divisibility = 16 : i32} , %arg4: i32 {tt.divisibility = 16 : i32} , %arg5: i32 , %arg6: i32 , %arg7: i32 , %arg8: i32 {tt.divisibility = 16 : i32} ) attributes {noinline = false} {
    %c3_i32 = arith.constant 3 : i32
    %c1_i64 = arith.constant 1 : i64
    %c4_i32 = arith.constant 4 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.addi %arg7, %c3_i32 : i32
    %2 = arith.divsi %1, %c4_i32 : i32
    %3 = arith.remsi %0, %2 : i32
    %4 = arith.divsi %0, %2 : i32
    %5 = arith.muli %4, %c4_i32 : i32
    %6 = arith.muli %3, %c4_i32 : i32
    %7 = arith.extsi %arg6 : i32 to i64
    %8 = arith.extsi %arg7 : i32 to i64
    %9 = arith.extsi %arg3 : i32 to i64
    %10 = tt.make_tensor_ptr %arg0, [%7, %8], [%c1_i64, %9], [%5, %6] {order = array<i32: 1, 0>} : <tensor<4x4xf32>>
    %11 = tt.load %10 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<4x4xf32>>
    %12 = arith.extsi %arg4 : i32 to i64
    %13 = tt.make_tensor_ptr %arg1, [%7, %8], [%12, %c1_i64], [%5, %6] {order = array<i32: 0, 1>} : <tensor<4x4xf32>>
    %14 = tt.load %13 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<4x4xf32>>
    %15 = arith.mulf %11, %14 : tensor<4x4xf32>
    %16 = arith.extsi %arg5 : i32 to i64
    %17 = tt.make_tensor_ptr %arg2, [%7, %8], [%16, %c1_i64], [%5, %6] {order = array<i32: 1, 0>} : <tensor<4x4xf32>>
    tt.store %17, %15 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<4x4xf32>>
    tt.return
  }
}

Crash log

test_mul_ops.py:94:104: error: non-decreasing dimension order on tensor pointers are not yet supported
        in1_bptr = tl.make_block_ptr(in1_ptr, (s0, s1), (in1_stride0, in1_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(in1_stride_order0, in1_stride_order1))
                                                                                                       ^
test_mul_ops.py:94:104: error: failed to legalize operation 'tts.make_tptr' that was explicitly marked illegal
        in1_bptr = tl.make_block_ptr(in1_ptr, (s0, s1), (in1_stride0, in1_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(in1_stride_order0, in1_stride_order1))
                                                                                                       ^

Additional information

This is a mul op with broadcasting — equivalent to `torch.mul(torch.randn(4, 1), torch.randn(1, 4))`.

@TracyMac1 TracyMac1 added the bug Something isn't working label Dec 2, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

1 participant