[Bug] block ptr out-of-bound access is not handled #203

Open

Nullkooland (Contributor) opened this issue Dec 27, 2024 · 0 comments

#90 introduced Triton block pointer (triton.language.make_block_ptr) support; however, there is no logic to handle the out-of-bound access behavior specified by the boundary_check and padding_option parameters of the triton.language.load/store ops:

If pointer is a block pointer defined by make_block_ptr, a tensor is loaded. In this case:
- mask and other must be None, and
- boundary_check and padding_option can be specified to control the behavior of out-of-bound access.

See the docs for tl.load and tl.store.
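
For context, boundary_check on a block pointer is supposed to give the same semantics as the classic masked form of the same load (a hedged sketch for dim 0; the names match the kernel below, and padding_option corresponds to the `other` value):

offsets = offset_x + tl.arange(0, BLOCK_SIZE_X)
mask = offsets < numel                          # in-bounds lanes only
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)  # out-of-bound lanes read the padding value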

For instance, consider this Triton kernel (launched with BLOCK_SIZE_X=256):

import triton
import triton.language as tl


@triton.jit
def abs_kernel(
    x_ptr,
    out_ptr,
    numel: int,
    BLOCK_SIZE_X: tl.constexpr,
):
    pid_x = tl.program_id(axis=0)
    offset_x = pid_x * BLOCK_SIZE_X
    x_block_ptr = tl.make_block_ptr(
        x_ptr,
        shape=[numel],
        strides=[1],
        block_shape=[BLOCK_SIZE_X],
        order=[0],
        offsets=[offset_x]
    )

    # boundary_check=[0] should clamp accesses on dim 0 to shape=[numel]
    x = tl.load(x_block_ptr, boundary_check=[0])
    x_abs = tl.math.abs(x)

    out_block_ptr = tl.make_block_ptr(
        out_ptr,
        shape=[numel],
        strides=[1],
        block_shape=[BLOCK_SIZE_X],
        order=[0],
        offsets=[offset_x]
    )
    tl.store(out_block_ptr, x_abs, boundary_check=[0])
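
For reference, a minimal host-side launch that hits the out-of-bound case (a sketch; the tensor size and device are illustrative, and abs_kernel is the kernel defined above):

import torch
import triton

BLOCK_SIZE_X = 256
numel = 1000  # not a multiple of BLOCK_SIZE_X, so the last block is partial

x = torch.randn(numel, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)

# grid = ceil(1000 / 256) = 4; program instance 3 covers offsets [768, 1024),
# but only [768, 1000) is in bounds -- boundary_check=[0] must clamp the rest
grid = (triton.cdiv(numel, BLOCK_SIZE_X),)
abs_kernel[grid](x, out, numel, BLOCK_SIZE_X=BLOCK_SIZE_X)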

Triton IR:

module {
  tt.func public @abs_kernel(
  %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},
  %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32},
  %arg2: i32 {tt.divisibility = 16 : i32}
) attributes {noinline = false} {
    %c1_i64 = arith.constant 1 : i64
    %c256_i32 = arith.constant 256 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c256_i32 : i32
    %2 = arith.extsi %arg2 : i32 to i64
    %3 = tt.make_tensor_ptr %arg0, [%2], [%c1_i64], [%1] {order = array<i32: 0>} : <tensor<256xf16>>
    %4 = tt.load %3 {boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256xf16>>
    %5 = math.absf %4 : tensor<256xf16>
    %6 = tt.make_tensor_ptr %arg1, [%2], [%c1_i64], [%1] {order = array<i32: 0>} : <tensor<256xf16>>
    tt.store %6, %5 {boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256xf16>>
    tt.return
  }
}

Triton structured IR:

module {
  tt.func public @abs_kernel(
  %arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},
  %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32},
  %arg2: i32 {tt.divisibility = 16 : i32}
) attributes {noinline = false} {
    %c1_i64 = arith.constant 1 : i64
    %c256_i32 = arith.constant 256 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c256_i32 : i32
    %2 = arith.extsi %arg2 : i32 to i64
    %3 = arith.index_cast %c1_i64 : i64 to index
    %4 = arith.index_cast %1 : i32 to index
    %5 = arith.muli %4, %3 : index
    %6 = arith.index_cast %2 : i64 to index
    %7 = tts.make_tptr %arg0 to sizes: [256], strides: [%3], offsets: [%5], shape: [%6], order: [0] : <f16> to !tt.ptr<tensor<256xf16>>
    %8 = tt.make_tensor_ptr %arg0, [%2], [%c1_i64], [%1] {order = array<i32: 0>} : <tensor<256xf16>>
    %9 = "tts.load"(%7) <{operandSegmentSizes = array<i32: 1, 0, 0>, static_mask_dims = array<i64>}> : (!tt.ptr<tensor<256xf16>>) -> tensor<256xf16>
    %10 = math.absf %9 : tensor<256xf16>
    %11 = arith.index_cast %c1_i64 : i64 to index
    %12 = arith.index_cast %1 : i32 to index
    %13 = arith.muli %12, %11 : index
    %14 = arith.index_cast %2 : i64 to index
    %15 = tts.make_tptr %arg1 to sizes: [256], strides: [%11], offsets: [%13], shape: [%14], order: [0] : <f16> to !tt.ptr<tensor<256xf16>>
    %16 = tt.make_tensor_ptr %arg1, [%2], [%c1_i64], [%1] {order = array<i32: 0>} : <tensor<256xf16>>
    "tts.store"(%15, %10) <{static_mask_dims = array<i64>}> : (!tt.ptr<tensor<256xf16>>, tensor<256xf16>) -> ()
    tt.return
  }
}

Linalg and memref IR:

#map = affine_map<(d0) -> (d0)>
module {
  func.func @abs_kernel(
  %arg0: memref<*xf16> {tt.divisibility = 16 : i32},
  %arg1: memref<*xf16> {tt.divisibility = 16 : i32},
  %arg2: i32 {tt.divisibility = 16 : i32},
  %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32
) {
    %c256_i32 = arith.constant 256 : i32
    %c1 = arith.constant 1 : index
    %0 = arith.muli %arg6, %c256_i32 : i32
    %1 = arith.index_cast %0 : i32 to index
    %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%1], sizes: [256], strides: [%c1] : memref<*xf16> to memref<256xf16, strided<[?], offset: ?>>
    %alloc = memref.alloc() : memref<256xf16>
    memref.copy %reinterpret_cast, %alloc : memref<256xf16, strided<[?], offset: ?>> to memref<256xf16>
    %2 = bufferization.to_tensor %alloc restrict writable : memref<256xf16>
    %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%2 : tensor<256xf16>) outs(%2 : tensor<256xf16>) {
    ^bb0(%in: f16, %out: f16):
      %4 = math.absf %in : f16
      linalg.yield %4 : f16
    } -> tensor<256xf16>
    %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%1], sizes: [256], strides: [%c1] : memref<*xf16> to memref<256xf16, strided<[?], offset: ?>>
    bufferization.materialize_in_destination %3 in writable %reinterpret_cast_0 : (tensor<256xf16>, memref<256xf16, strided<[?], offset: ?>>) -> ()
    return
  }
}

Note that in the structured IR above, tts.load/tts.store already carry an empty static_mask_dims, i.e., the boundaryCheck information has been dropped. Obviously, the first memref.reinterpret_cast plus memref.copy can then read out of bounds in the last program instance when numel % BLOCK_SIZE_X != 0; there should be some sort of subview with an upper bound of min(numel, offset + BLOCK_SIZE_X) to limit the access range.
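
One possible shape of the fix at the memref level, sketched with upstream ops only (the op choices, SSA names, and zero padding value are illustrative, not what triton-shared currently emits): pre-fill the local buffer with the padding value, then copy only the in-bounds prefix through a clamped subview:

%c256 = arith.constant 256 : index
%pad = arith.constant 0.000000e+00 : f16      // assumes padding_option = zero
%numel = arith.index_cast %arg2 : i32 to index
%end = arith.addi %1, %c256 : index           // offset + BLOCK_SIZE_X
%clamped = arith.minsi %end, %numel : index   // min(numel, offset + BLOCK_SIZE_X)
%valid = arith.subi %clamped, %1 : index      // in-bounds element count for this block
linalg.fill ins(%pad : f16) outs(%alloc : memref<256xf16>)
%src = memref.subview %reinterpret_cast[0] [%valid] [1] : memref<256xf16, strided<[?], offset: ?>> to memref<?xf16, strided<[?], offset: ?>>
%dst = memref.subview %alloc[0] [%valid] [1] : memref<256xf16> to memref<?xf16, strided<[1]>>
memref.copy %src, %dst : memref<?xf16, strided<[?], offset: ?>> to memref<?xf16, strided<[1]>>

The store side needs the analogous clamped subview of %reinterpret_cast_0 so the write-back also stays within [0, numel).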

@haishanzzzz Could you take a look? Thanks!
