Triton python code

@triton.jit
def mul_func_kernel_rank_2(
    in0_ptr: tl.tensor,  # of tl.pointer_type
    in1_ptr: tl.tensor,  # of tl.pointer_type
    out0_ptr: tl.tensor,  # of tl.pointer_type
    in0_stride0: int, in0_stride1: int,  # strides for in0
    in0_stride_order0: tl.constexpr, in0_stride_order1: tl.constexpr,  # stride order for in0
    in1_stride0: int, in1_stride1: int,  # strides for in1
    in1_stride_order0: tl.constexpr, in1_stride_order1: tl.constexpr,  # stride order for in1
    out0_stride0: int, out0_stride1: int,  # strides for out0
    out0_stride_order0: tl.constexpr, out0_stride_order1: tl.constexpr,  # stride order for out0
    s0: int, s1: int,  # task space
    num_tasks: int,
    tiles_per_cta: int,
    tile_size0: tl.constexpr,
    tile_size1: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
):
    pid = tl.program_id(0)
    num_tiles0 = tl.cdiv(s0, tile_size0)
    num_tiles1 = tl.cdiv(s1, tile_size1)
    if one_tile_per_cta:  # monolithic kernel style
        tile_id = pid
        # pid multi-index reconstruction: C ordering, the rightmost axis changes fastest
        tile_id1 = tile_id % num_tiles1
        tile_id //= num_tiles1
        tile_id0 = tile_id

        # tile offsets
        offset0 = tile_id0 * tile_size0
        offset1 = tile_id1 * tile_size1

        # loads
        in0_bptr = tl.make_block_ptr(in0_ptr, (s0, s1), (in0_stride0, in0_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(in0_stride_order0, in0_stride_order1))
        # workaround for a bug with bool: cast to the original pointer's dtype (not the block pointer's)
        in0 = tl.load(in0_bptr, boundary_check=(in0_stride_order0, in0_stride_order1)).to(in0_ptr.type.element_ty)
        in1_bptr = tl.make_block_ptr(in1_ptr, (s0, s1), (in1_stride0, in1_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(in1_stride_order0, in1_stride_order1))
        # workaround for a bug with bool: cast to the original pointer's dtype (not the block pointer's)
        in1 = tl.load(in1_bptr, boundary_check=(in1_stride_order0, in1_stride_order1)).to(in1_ptr.type.element_ty)

        # compute
        out0 = mul_func(in0, in1)

        # stores; note that a store to a block pointer does not automatically cast the value to the pointer's dtype
        out0_bptr = tl.make_block_ptr(out0_ptr, (s0, s1), (out0_stride0, out0_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(out0_stride_order0, out0_stride_order1))
        tl.store(out0_bptr, out0.to(out0_bptr.type.element_ty), boundary_check=(out0_stride_order0, out0_stride_order1))
    else:  # grid-stride-loop style kernel
        num_ctas = tl.num_programs(0)
        for j in range(0, tiles_per_cta):
            tile_id = pid + j * num_ctas
            # pid multi-index reconstruction: C ordering, the rightmost axis changes fastest
            tile_id1 = tile_id % num_tiles1
            tile_id //= num_tiles1
            tile_id0 = tile_id

            # tile offsets
            offset0 = tile_id0 * tile_size0
            offset1 = tile_id1 * tile_size1

            # loads
            in0_bptr = tl.make_block_ptr(in0_ptr, (s0, s1), (in0_stride0, in0_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(in0_stride_order0, in0_stride_order1))
            # workaround for a bug with bool: cast to the original pointer's dtype (not the block pointer's)
            in0 = tl.load(in0_bptr, boundary_check=(in0_stride_order0, in0_stride_order1)).to(in0_ptr.type.element_ty)
            in1_bptr = tl.make_block_ptr(in1_ptr, (s0, s1), (in1_stride0, in1_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(in1_stride_order0, in1_stride_order1))
            # workaround for a bug with bool: cast to the original pointer's dtype (not the block pointer's)
            in1 = tl.load(in1_bptr, boundary_check=(in1_stride_order0, in1_stride_order1)).to(in1_ptr.type.element_ty)

            # compute
            out0 = mul_func(in0, in1)

            # stores; note that a store to a block pointer does not automatically cast the value to the pointer's dtype
            out0_bptr = tl.make_block_ptr(out0_ptr, (s0, s1), (out0_stride0, out0_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(out0_stride_order0, out0_stride_order1))
            tl.store(out0_bptr, out0.to(out0_bptr.type.element_ty), boundary_check=(out0_stride_order0, out0_stride_order1))
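For context (a reconstruction, not part of the original report): for the torch.mul(torch.randn(4, 1), torch.randn(1, 4)) case described under "Additional information" below, the host-side wrapper that launches this kernel would broadcast both operands to the (4, 4) task space, and the broadcast view of the second operand is what yields the order=(0, 1) block pointer. A minimal sketch of the stride situation, assuming the per-operand stride order is derived from the broadcast views' strides:

import torch

# Hypothetical illustration of the broadcast views the wrapper would build.
a = torch.randn(4, 1).expand(4, 4)   # in0 view: strides (1, 0) -> stride order (1, 0)
b = torch.randn(1, 4).expand(4, 4)   # in1 view: strides (0, 1) -> stride order (0, 1), non-decreasing

print(a.stride())  # (1, 0)
print(b.stride())  # (0, 1)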
Triton IR

module {
  tt.func public @mul_func_kernel_rank_2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c3_i32 = arith.constant 3 : i32
    %c1_i64 = arith.constant 1 : i64
    %c4_i32 = arith.constant 4 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.addi %arg7, %c3_i32 : i32
    %2 = arith.divsi %1, %c4_i32 : i32
    %3 = arith.remsi %0, %2 : i32
    %4 = arith.divsi %0, %2 : i32
    %5 = arith.muli %4, %c4_i32 : i32
    %6 = arith.muli %3, %c4_i32 : i32
    %7 = arith.extsi %arg6 : i32 to i64
    %8 = arith.extsi %arg7 : i32 to i64
    %9 = arith.extsi %arg3 : i32 to i64
    %10 = tt.make_tensor_ptr %arg0, [%7, %8], [%c1_i64, %9], [%5, %6] {order = array<i32: 1, 0>} : <tensor<4x4xf32>>
    %11 = tt.load %10 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<4x4xf32>>
    %12 = arith.extsi %arg4 : i32 to i64
    %13 = tt.make_tensor_ptr %arg1, [%7, %8], [%12, %c1_i64], [%5, %6] {order = array<i32: 0, 1>} : <tensor<4x4xf32>>
    %14 = tt.load %13 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<4x4xf32>>
    %15 = arith.mulf %11, %14 : tensor<4x4xf32>
    %16 = arith.extsi %arg5 : i32 to i64
    %17 = tt.make_tensor_ptr %arg2, [%7, %8], [%16, %c1_i64], [%5, %6] {order = array<i32: 1, 0>} : <tensor<4x4xf32>>
    tt.store %17, %15 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<4x4xf32>>
    tt.return
  }
}
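The op the crash log flags appears to be the second tensor pointer above (an observation, not part of the original report): in1's strides are (%arg4, 1) with order (0, 1), which is the "non-decreasing" order the error message names, while in0 (%10) and out0 (%17) carry order (1, 0):

%13 = tt.make_tensor_ptr %arg1, [%7, %8], [%12, %c1_i64], [%5, %6] {order = array<i32: 0, 1>} : <tensor<4x4xf32>>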
Crash log

test_mul_ops.py:94:104: error: non-decreasing dimension order on tensor pointers are not yet supported
    in1_bptr = tl.make_block_ptr(in1_ptr, (s0, s1), (in1_stride0, in1_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(in1_stride_order0, in1_stride_order1))
    ^
test_mul_ops.py:94:104: error: failed to legalize operation 'tts.make_tptr' that was explicitly marked illegal
    in1_bptr = tl.make_block_ptr(in1_ptr, (s0, s1), (in1_stride0, in1_stride1), (offset0, offset1), (tile_size0, tile_size1), order=(in1_stride_order0, in1_stride_order1))
    ^
Additional information

This is a mul op with broadcast, the same as torch.mul(torch.randn(4, 1), torch.randn(1, 4)).
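For completeness, the eager-mode equivalent restated as a runnable snippet (just the call from the line above; the (4, 4) result shape matches the 4x4 tiles in the IR):

import torch

out = torch.mul(torch.randn(4, 1), torch.randn(1, 4))
print(out.shape)  # torch.Size([4, 4]) -- broadcasting (4, 1) with (1, 4)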