Triton python code

```python
import logging

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_stages=s, num_warps=w)
        for m in [32, 64, 128]
        for n in [1, 2, 4, 8]
        for s in [3, 4]
        for w in [4, 8]
    ],
    key=["M", "N"],
)
@triton.jit
def mv_kernel(
    A,
    B,
    C,
    N,
    M,
    stride_an,
    stride_am,
    stride_bm,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid = tl.program_id(0)
    offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    offset_m = tl.arange(0, BLOCK_M)[None, :]
    n_mask = offset_n < N
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
    B_ptrs = B + offset_m * stride_bm
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    for m in range(0, M, BLOCK_M):
        m_mask = m + offset_m < M
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)
        acc += a * b
        A_ptrs += BLOCK_M * stride_am
        B_ptrs += BLOCK_M * stride_bm
    acc = tl.sum(acc, axis=1)
    C_ptrs = C + offset_n * stride_cn
    tl.store(C_ptrs, acc[:, None], mask=n_mask)


def mv(inp, vec):
    logging.debug("GEMS MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"
    N, M = inp.shape
    out = torch.empty((N,), device=inp.device, dtype=inp.dtype)
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    mv_kernel[grid](
        inp,
        vec,
        out,
        N,
        M,
        inp.stride(0),
        inp.stride(1),
        vec.stride(0),
        out.stride(0),
    )
    return out
```
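Since test_mv_ops.py itself is not attached, a minimal driver along these lines should exercise the same path; the shapes are illustrative and the float16 dtype is inferred from the f16 pointer types visible in the crash log below, so treat this as an assumed repro rather than the original test.

```python
# Hypothetical repro driver; shapes are illustrative, float16 matches the
# f16 pointers in the crash log.
import torch

inp = torch.randn(128, 256, device="cuda", dtype=torch.float16)
vec = torch.randn(256, device="cuda", dtype=torch.float16)
out = mv(inp, vec)
# Compare against PyTorch's reference matrix-vector product.
torch.testing.assert_close(out, torch.mv(inp, vec), rtol=1e-3, atol=1e-3)
```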
Triton IR

```mlir
module {
  tt.func public @mv_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c32_i32 = arith.constant 32 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<32> : tensor<1x32xi32>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<1x32xf32>
    %0 = tt.get_program_id x : i32
    %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32>
    %3 = arith.cmpi slt, %0, %arg3 : i32
    %4 = tt.splat %3 : i1 -> tensor<1x1xi1>
    %5 = arith.muli %0, %arg5 : i32
    %6 = tt.addptr %arg0, %5 : !tt.ptr<f32>, i32
    %7 = tt.splat %6 : !tt.ptr<f32> -> tensor<1x32x!tt.ptr<f32>>
    %8 = tt.addptr %7, %2 : tensor<1x32x!tt.ptr<f32>>, tensor<1x32xi32>
    %9 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1x32x!tt.ptr<f32>>
    %10 = tt.addptr %9, %2 : tensor<1x32x!tt.ptr<f32>>, tensor<1x32xi32>
    %11 = tt.splat %arg4 : i32 -> tensor<1x32xi32>
    %12 = tt.splat %3 : i1 -> tensor<1x32xi1>
    %13:3 = scf.for %arg6 = %c0_i32 to %arg4 step %c32_i32 iter_args(%arg7 = %cst_0, %arg8 = %8, %arg9 = %10) -> (tensor<1x32xf32>, tensor<1x32x!tt.ptr<f32>>, tensor<1x32x!tt.ptr<f32>>) : i32 {
      %18 = tt.splat %arg6 : i32 -> tensor<1x32xi32>
      %19 = arith.addi %18, %2 : tensor<1x32xi32>
      %20 = arith.cmpi slt, %19, %11 : tensor<1x32xi32>
      %21 = arith.andi %12, %20 : tensor<1x32xi1>
      %22 = tt.load %arg8, %21, %cst_0 : tensor<1x32x!tt.ptr<f32>>
      %23 = tt.load %arg9, %20, %cst_0 : tensor<1x32x!tt.ptr<f32>>
      %24 = arith.mulf %22, %23 : tensor<1x32xf32>
      %25 = arith.addf %arg7, %24 : tensor<1x32xf32>
      %26 = tt.addptr %arg8, %cst : tensor<1x32x!tt.ptr<f32>>, tensor<1x32xi32>
      %27 = tt.addptr %arg9, %cst : tensor<1x32x!tt.ptr<f32>>, tensor<1x32xi32>
      scf.yield %25, %26, %27 : tensor<1x32xf32>, tensor<1x32x!tt.ptr<f32>>, tensor<1x32x!tt.ptr<f32>>
    }
    %14 = "tt.reduce"(%13#0) <{axis = 1 : i32}> ({
    ^bb0(%arg6: f32, %arg7: f32):
      %18 = arith.addf %arg6, %arg7 : f32
      tt.reduce.return %18 : f32
    }) : (tensor<1x32xf32>) -> tensor<1xf32>
    %15 = tt.addptr %arg2, %0 : !tt.ptr<f32>, i32
    %16 = tt.splat %15 : !tt.ptr<f32> -> tensor<1x1x!tt.ptr<f32>>
    %17 = tt.expand_dims %14 {axis = 1 : i32} : tensor<1xf32> -> tensor<1x1xf32>
    tt.store %16, %17, %4 : tensor<1x1x!tt.ptr<f32>>
    tt.return
  }
}
```
Crash log

```
test_mv_ops.py:45:20: remark: MaskAnalysis failed
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
                   ^
test_mv_ops.py:45:20: note: see current operation: %29 = tt.load %arg7, %27, %cst_0 : tensor<1x32x!tt.ptr<f16>>
test_mv_ops.py:45:20: remark: PtrAnalysis: Failed to rewrite LoadOp
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
                   ^
test_mv_ops.py:45:20: note: see current operation: %29 = tt.load %arg7, %27, %cst_0 : tensor<1x32x!tt.ptr<f16>>
test_mv_ops.py:53:21: remark: PtrAnalysis: pointer is not replace with tts.make_tptr so storeOp cannot be rewritten
    tl.store(C_ptrs, acc[:, None], mask=n_mask)
                    ^
test_mv_ops.py:53:21: note: see current operation: tt.store %20, %22, %5 : tensor<1x1x!tt.ptr<f16>>
test_mv_ops.py:53:21: remark: PtrAnalysis: Failed to rewrite StoreOp
    tl.store(C_ptrs, acc[:, None], mask=n_mask)
                    ^
test_mv_ops.py:53:21: note: see current operation: tt.store %20, %22, %5 : tensor<1x1x!tt.ptr<f16>>
test_mv_ops.py:40:40: error: failed to legalize unresolved materialization from ('memref<1x32xf16, strided<[?, ?], offset: ?>>') to 'tensor<1x32x!tt.ptr<f16>>' that remained live after conversion
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
                                       ^
test_mv_ops.py:40:40: note: see current operation: %44 = "builtin.unrealized_conversion_cast"(%arg23) : (memref<1x32xf16, strided<[?, ?], offset: ?>>) -> tensor<1x32x!tt.ptr<f16>>
test_mv_ops.py:45:20: note: see existing live user here: %33 = tt.load %25, %32, %3 : tensor<1x32x!tt.ptr<f16>>
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
```
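The diagnostics point at the rank-2 pointer tensors and the combined 2D mask (n_mask & m_mask), which MaskAnalysis/PtrAnalysis fail to rewrite. The sketch below is an untested assumption, not a confirmed fix: a flattened variant that assigns one program per output row and uses only 1D offsets and masks, which may avoid the unsupported pattern.

```python
# Untested workaround sketch (assumption): one program per output row,
# 1D offsets and a single 1D mask instead of the 2D n_mask & m_mask.
import torch
import triton
import triton.language as tl


@triton.jit
def mv_kernel_1d(A, B, C, M, stride_an, stride_am, stride_bm, stride_cn,
                 BLOCK_M: tl.constexpr):
    row = tl.program_id(0)  # one program per output row
    offs_m = tl.arange(0, BLOCK_M)
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)
    for m in range(0, M, BLOCK_M):
        m_mask = m + offs_m < M  # only a 1D mask along the reduced dim
        a = tl.load(A + row * stride_an + (m + offs_m) * stride_am,
                    mask=m_mask, other=0.0).to(tl.float32)
        b = tl.load(B + (m + offs_m) * stride_bm,
                    mask=m_mask, other=0.0).to(tl.float32)
        acc += a * b
    # Scalar store; the grid already bounds `row`, so no row mask is needed.
    tl.store(C + row * stride_cn, tl.sum(acc, axis=0))


def mv_1d(inp, vec, block_m=128):
    N, M = inp.shape
    out = torch.empty((N,), device=inp.device, dtype=inp.dtype)
    mv_kernel_1d[(N,)](inp, vec, out, M, inp.stride(0), inp.stride(1),
                       vec.stride(0), out.stride(0), BLOCK_M=block_m)
    return out
```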
Additional information
No response