
Errors:No support for multiple modulos within an expression #48

Closed
paretoth opened this issue Nov 10, 2023 · 6 comments
paretoth commented Nov 10, 2023

module {
  tt.func public @triton__0d1d2d3d4d5d6d7d(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<1x512xf32>
    %cst_1 = arith.constant dense<384> : tensor<1x1xi64>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<1x512xbf16>
    %cst_3 = arith.constant dense<65> : tensor<1x1xi64>
    %cst_4 = arith.constant dense<0> : tensor<1x1xi64>
    %cst_5 = arith.constant dense<9.99999974E-6> : tensor<1x1xf32>
    %cst_6 = arith.constant dense<3.840000e+02> : tensor<1x1xf32>
    %cst_7 = arith.constant dense<384> : tensor<1x1xi32>
    %cst_8 = arith.constant dense<384> : tensor<1x512xi32>
    %cst_9 = arith.constant dense<256> : tensor<1x1xi32>
    %0 = tt.get_program_id x : i32
    %1 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32>
    %2 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<1xi32>) -> tensor<1x1xi32>
    %3 = tt.splat %0 : (i32) -> tensor<1x1xi32>
    %4 = arith.addi %3, %2 : tensor<1x1xi32>
    %5 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
    %6 = tt.expand_dims %5 {axis = 0 : i32} : (tensor<512xi32>) -> tensor<1x512xi32>
    %7 = tt.splat %arg0 : (!tt.ptr<i64>) -> tensor<1x1x!tt.ptr<i64>>
    %8 = tt.addptr %7, %4 : tensor<1x1x!tt.ptr<i64>>, tensor<1x1xi32>
    %9 = tt.load %8 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x1xi64>
    %10 = arith.remsi %4, %cst_9 : tensor<1x1xi32>
    %11 = arith.cmpi slt, %6, %cst_8 : tensor<1x512xi32>
    %12 = arith.muli %10, %cst_7 : tensor<1x1xi32>
    %13 = tt.broadcast %12 : (tensor<1x1xi32>) -> tensor<1x512xi32>
    %14 = arith.addi %6, %13 : tensor<1x512xi32>
    %15 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<1x512x!tt.ptr<f32>>
    %16 = tt.addptr %15, %14 : tensor<1x512x!tt.ptr<f32>>, tensor<1x512xi32>
    %17 = tt.load %16, %11, %cst_0 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x512xf32>
    %18 = arith.muli %4, %cst_7 : tensor<1x1xi32>
    %19 = tt.broadcast %18 : (tensor<1x1xi32>) -> tensor<1x512xi32>
    %20 = arith.addi %6, %19 : tensor<1x512xi32>
    %21 = tt.splat %arg3 : (!tt.ptr<bf16>) -> tensor<1x512x!tt.ptr<bf16>>
    %22 = tt.addptr %21, %20 : tensor<1x512x!tt.ptr<bf16>>, tensor<1x512xi32>
    %23 = tt.load %22, %11, %cst_2 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x512xbf16>
    %24 = arith.extf %23 : tensor<1x512xbf16> to tensor<1x512xf32>
    %25 = arith.cmpi sge, %9, %cst_4 : tensor<1x1xi64>
    %26 = arith.cmpi slt, %9, %cst_3 : tensor<1x1xi64>
    %27 = arith.andi %25, %26 : tensor<1x1xi1>
    tt.assert %27, "index out of bounds: 0 <= tmp0 < 65", "<frozen importlib._bootstrap_external>", "_call_with_frames_removed", 883 : tensor<1x1xi1>
    %28 = arith.muli %9, %cst_1 : tensor<1x1xi64>
    %29 = tt.broadcast %28 : (tensor<1x1xi64>) -> tensor<1x512xi64>
    %30 = arith.extsi %6 : tensor<1x512xi32> to tensor<1x512xi64>
    %31 = arith.addi %30, %29 : tensor<1x512xi64>
    %32 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1x512x!tt.ptr<f32>>
    %33 = tt.addptr %32, %31 : tensor<1x512x!tt.ptr<f32>>, tensor<1x512xi64>
    %34 = tt.load %33, %11, %cst_0 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x512xf32>
    %35 = arith.addf %34, %17 : tensor<1x512xf32>
    %36 = arith.addf %35, %24 : tensor<1x512xf32>
    %37 = arith.uitofp %11 : tensor<1x512xi1> to tensor<1x512xf32>
    %38 = arith.addf %37, %cst_0 : tensor<1x512xf32>
    %39 = arith.divf %36, %38 : tensor<1x512xf32>
    %40 = arith.addf %39, %cst_0 : tensor<1x512xf32>
    %41 = arith.subf %36, %40 : tensor<1x512xf32>
    %42 = arith.mulf %36, %41 : tensor<1x512xf32>
    %43 = arith.addf %42, %cst_0 : tensor<1x512xf32>
    %44 = arith.select %11, %40, %cst_0 : tensor<1x512xi1>, tensor<1x512xf32>
    %45 = arith.select %11, %43, %cst_0 : tensor<1x512xi1>, tensor<1x512xf32>
    %46 = arith.select %11, %38, %cst_0 : tensor<1x512xi1>, tensor<1x512xf32>
    %47 = arith.addf %36, %cst_0 : tensor<1x512xf32>
    %48 = arith.select %11, %47, %cst_0 : tensor<1x512xi1>, tensor<1x512xf32>
    %49:3 = "tt.reduce"(%44, %45, %46) <{axis = 1 : i32}> ({
    ^bb0(%arg8: f32, %arg9: f32, %arg10: f32, %arg11: f32, %arg12: f32, %arg13: f32):
      %74 = arith.subf %arg11, %arg8 : f32
      %75 = arith.addf %arg10, %arg13 : f32
      %76 = arith.cmpf oeq, %75, %cst : f32
      %77 = arith.divf %arg13, %75 : f32
      %78 = arith.select %76, %cst, %77 : f32
      %79 = arith.mulf %74, %78 : f32
      %80 = arith.addf %arg8, %79 : f32
      %81 = arith.addf %arg9, %arg12 : f32
      %82 = arith.mulf %74, %74 : f32
      %83 = arith.mulf %82, %arg10 : f32
      %84 = arith.mulf %83, %78 : f32
      %85 = arith.addf %81, %84 : f32
      tt.reduce.return %80, %85, %75 : f32, f32, f32
    }) : (tensor<1x512xf32>, tensor<1x512xf32>, tensor<1x512xf32>) -> (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>)
    %50 = tt.expand_dims %49#1 {axis = 1 : i32} : (tensor<1xf32>) -> tensor<1x1xf32>
    %51 = "tt.reduce"(%48) <{axis = 1 : i32}> ({
    ^bb0(%arg8: f32, %arg9: f32):
      %74 = arith.addf %arg8, %arg9 : f32
      tt.reduce.return %74 : f32
    }) : (tensor<1x512xf32>) -> tensor<1xf32>
    %52 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<1xf32>) -> tensor<1x1xf32>
    %53 = tt.load %16, %11, %cst_0 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x512xf32>
    %54 = tt.load %22, %11, %cst_2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x512xbf16>
    %55 = arith.extf %54 : tensor<1x512xbf16> to tensor<1x512xf32>
    %56 = tt.splat %arg4 : (!tt.ptr<f32>) -> tensor<1x512x!tt.ptr<f32>>
    %57 = tt.addptr %56, %6 : tensor<1x512x!tt.ptr<f32>>, tensor<1x512xi32>
    %58 = tt.load %57, %11, %cst_0 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x512xf32>
    %59 = tt.load %33, %11, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x512xf32>
    %60 = arith.addf %59, %53 : tensor<1x512xf32>
    %61 = arith.addf %60, %55 : tensor<1x512xf32>
    %62 = arith.divf %52, %cst_6 : tensor<1x1xf32>
    %63 = tt.broadcast %62 : (tensor<1x1xf32>) -> tensor<1x512xf32>
    %64 = arith.subf %61, %63 : tensor<1x512xf32>
    %65 = arith.divf %50, %cst_6 : tensor<1x1xf32>
    %66 = arith.addf %65, %cst_5 : tensor<1x1xf32>
    %67 = tt.extern_elementwise %66 {pure = true, libname = "libdevice", libpath = "/wkspc/hongzhu/anaconda3/envs/exp_triton_org1/lib/python3.10/site-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", symbol = "__nv_rsqrtf"} : (tensor<1x1xf32>) -> tensor<1x1xf32>
    %68 = tt.broadcast %67 : (tensor<1x1xf32>) -> tensor<1x512xf32>
    %69 = arith.mulf %64, %68 : tensor<1x512xf32>
    %70 = arith.mulf %69, %58 : tensor<1x512xf32>
    %71 = tt.splat %arg5 : (!tt.ptr<bf16>) -> tensor<1x512x!tt.ptr<bf16>>
    %72 = tt.addptr %71, %20 : tensor<1x512x!tt.ptr<bf16>>, tensor<1x512xi32>
    %73 = arith.truncf %70 : tensor<1x512xf32> to tensor<1x512xbf16>
    tt.store %72, %73, %11 {cache = 1 : i32, evict = 1 : i32} : tensor<1x512xbf16>
    tt.return
  }
}

Running triton-shared-opt --triton-to-linalg on the IR throws this error:

triton-shared-opt: /home/lanwo/code/triton-dlc/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:426: static void mlir::triton::PtrAnalysis::visitOperandRem(mlir::arith::RemSIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, const llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState>&): Assertion `state.getRank() == 1 && !state.modulos.back().has_value() && "No support for multiple modulos within an expression"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: triton-shared-opt --triton-to-linalg models_nano/nanoGPT/9c632b4228052e6bbc27c9772f3a7e3b/triton_.ttir
 #0 0x0000564125d484a7 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/home/lanwo/code/triton-dlc/triton/python/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x42fc4a7)
 #1 0x0000564125d45fce llvm::sys::RunSignalHandlers() (/home/lanwo/code/triton-dlc/triton/python/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x42f9fce)
 #2 0x0000564125d48b5f SignalHandler(int) Signals.cpp:0:0
 #3 0x00007f3f0ec86520 (/usr/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x00007f3f0ecdaa7c __pthread_kill_implementation ./nptl/./nptl/pthread_kill.c:44:76
 #5 0x00007f3f0ecdaa7c __pthread_kill_internal ./nptl/./nptl/pthread_kill.c:78:10
 #6 0x00007f3f0ecdaa7c pthread_kill ./nptl/./nptl/pthread_kill.c:89:10
 #7 0x00007f3f0ec86476 gsignal ./signal/../sysdeps/posix/raise.c:27:6
 #8 0x00007f3f0ec6c7f3 abort ./stdlib/./stdlib/abort.c:81:7
 #9 0x00007f3f0ec6c71b _nl_load_domain ./intl/./intl/loadmsgcat.c:1177:9
#10 0x00007f3f0ec7de96 (/usr/lib/x86_64-linux-gnu/libc.so.6+0x39e96)
#11 0x0000564123a26e79 mlir::triton::PtrAnalysis::visitOperandRem(mlir::arith::RemSIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState> > const&) /home/lanwo/code/triton-dlc/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:431:3
#12 0x0000564123a26217 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState> > const&) /home/lanwo/code/triton-dlc/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:695:20
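
For context on why the assertion fires: the failing check (`state.getRank() == 1 && !state.modulos.back().has_value()`) means PtrAnalysis only handles a modulo on a rank-1 index expression that carries no prior modulo, while the offset expression here applies `%` to an already-expanded rank-2 index. A minimal NumPy stand-in for the offset math in the IR above (shapes taken from the IR; NumPy is used purely for illustration, not Triton API):

```python
import numpy as np

# Mirror of the offset computation in the IR (XBLOCK=1, RBLOCK=512):
#   %10 = (xoffset + arange(XBLOCK))[:, None] % 256   <- modulo on a rank-2 index
#   %14 = arange(RBLOCK)[None, :] + %10 * 384         <- folded into one pointer expression
XBLOCK, RBLOCK = 1, 512
xoffset = 0
xindex = (xoffset + np.arange(XBLOCK))[:, None]   # shape (XBLOCK, 1)
rindex = np.arange(RBLOCK)[None, :]               # shape (1, RBLOCK)
offsets = (xindex % 256) * 384 + rindex           # shape (XBLOCK, RBLOCK)
print(offsets.shape)  # (1, 512)
```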
@paretoth paretoth changed the title No support for multiple modulos within an expression Errors:No support for multiple modulos within an expression Nov 10, 2023
@nhat-nguyen (Collaborator)

@paretoth Thanks for reporting this, would you mind also sharing the triton python code that produces this IR?

@manbearian (Collaborator)

@nhat-nguyen is there any way we can change this?

> PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.


paretoth commented Nov 13, 2023

> @paretoth Thanks for reporting this, would you mind also sharing the triton python code that produces this IR?

@nhat-nguyen Thanks for your reply. The triton python code and its corresponding IR are as follows.

The triton python code:

import triton
import triton.language as tl
from torch._inductor import triton_helpers

@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 16384
    rnumel = 384
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x3), None, eviction_policy='evict_last')
    x0 = xindex % 256
    tmp9_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp9_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp9_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp3 = tl.load(in_ptr2 + (r2 + (384*x0)), rmask, eviction_policy='evict_last', other=0)
        tmp5 = tl.load(in_ptr3 + (r2 + (384*x3)), rmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp1 = tl.where(tmp0 < 0, tmp0 + 65, tmp0)
        tl.device_assert((0 <= tmp1) & (tmp1 < 65), "index out of bounds: 0 <= tmp1 < 65")
        tmp2 = tl.load(in_ptr1 + (r2 + (384*tmp1)), rmask, eviction_policy='evict_last', other=0)
        tmp4 = tmp2 + tmp3
        tmp6 = tmp5.to(tl.float32)
        tmp7 = tmp4 + tmp6
        tmp8 = tl.broadcast_to(tmp7, [XBLOCK, RBLOCK])
        tmp9_mean_next, tmp9_m2_next, tmp9_weight_next = triton_helpers.welford_reduce(
            tmp8, tmp9_mean, tmp9_m2, tmp9_weight,
        )
        tmp9_mean = tl.where(rmask, tmp9_mean_next, tmp9_mean)
        tmp9_m2 = tl.where(rmask, tmp9_m2_next, tmp9_m2)
        tmp9_weight = tl.where(rmask, tmp9_weight_next, tmp9_weight)
    tmp9_tmp, tmp10_tmp, tmp11_tmp = triton_helpers.welford(
        tmp9_mean, tmp9_m2, tmp9_weight, 1
    )
    tmp9 = tmp9_tmp[:, None]
    tmp10 = tmp10_tmp[:, None]
    tmp11 = tmp11_tmp[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp14 = tl.load(in_ptr2 + (r2 + (384*x0)), rmask, eviction_policy='evict_last', other=0)
        tmp16 = tl.load(in_ptr3 + (r2 + (384*x3)), rmask, other=0).to(tl.float32)
        tmp26 = tl.load(in_ptr4 + (r2), rmask, eviction_policy='evict_last', other=0)
        tmp12 = tl.where(tmp0 < 0, tmp0 + 65, tmp0)
        tl.device_assert((0 <= tmp12) & (tmp12 < 65), "index out of bounds: 0 <= tmp12 < 65")
        tmp13 = tl.load(in_ptr1 + (r2 + (384*tmp12)), rmask, other=0)
        tmp15 = tmp13 + tmp14
        tmp17 = tmp16.to(tl.float32)
        tmp18 = tmp15 + tmp17
        tmp19 = tmp18 - tmp9
        tmp20 = 384.0
        tmp21 = tmp10 / tmp20
        tmp22 = 1e-05
        tmp23 = tmp21 + tmp22
        tmp24 = tl.math.rsqrt(tmp23)
        tmp25 = tmp19 * tmp24
        tmp27 = tmp25 * tmp26
        tmp28 = tmp27.to(tl.float32)
        tl.store(out_ptr2 + (r2 + (384*x3)), tmp28, rmask)

#'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: '*bf16', 4: '*fp32', 5: '*bf16', 6: 'i32', 7: 'i32'}
ret = triton.compile(triton_, signature="*i64,*fp32,*fp32,*bf16,*fp32,*bf16,i32,i32", constants={"XBLOCK": 64, "RBLOCK": 512})
output_filename = __file__ + ".ttir"
with open(output_filename, "w") as f:
    f.write(ret.asm["ttir"])

print(f"triton_ saved to {output_filename}")
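
As an aside, the `triton_helpers.welford_reduce` call above maintains running (mean, m2, weight) states, and the `tt.reduce` region in the resulting IR performs the pairwise combine of two such states. A plain-Python sketch of that combine (illustrative only; names here are made up, not the torch helper's API):

```python
# Pairwise Welford combine of two (mean, m2, weight) states, mirroring
# the arithmetic inside the tt.reduce region of the IR below.
def welford_combine(mean_a, m2_a, w_a, mean_b, m2_b, w_b):
    delta = mean_b - mean_a
    w = w_a + w_b
    frac = 0.0 if w == 0.0 else w_b / w      # guard against empty states
    mean = mean_a + delta * frac              # weighted-average the means
    m2 = m2_a + m2_b + delta * delta * w_a * frac  # merge sums of squared deviations
    return mean, m2, w

# Combining two single-sample states reproduces mean/variance of [1, 3]:
mean, m2, w = welford_combine(1.0, 0.0, 1.0, 3.0, 0.0, 1.0)
print(mean, m2 / w)  # 2.0 1.0
```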

The corresponding TTIR (the input to triton-shared-opt --triton-to-linalg):

module {
  tt.func public @triton__01234567(%arg0: !tt.ptr<i64, 1>, %arg1: !tt.ptr<f32, 1>, %arg2: !tt.ptr<f32, 1>, %arg3: !tt.ptr<bf16, 1>, %arg4: !tt.ptr<f32, 1>, %arg5: !tt.ptr<bf16, 1>, %arg6: i32, %arg7: i32) attributes {noinline = false} {
    %cst = arith.constant dense<1.000000e+00> : tensor<1x512xf32>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<1x512xf32>
    %cst_1 = arith.constant 0.000000e+00 : f32
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x512xf32>
    %cst_3 = arith.constant dense<384> : tensor<64x1xi64>
    %cst_4 = arith.constant dense<65> : tensor<64x1xi64>
    %cst_5 = arith.constant dense<0> : tensor<64x1xi64>
    %cst_6 = arith.constant dense<0.000000e+00> : tensor<64x512xbf16>
    %cst_7 = arith.constant dense<9.99999974E-6> : tensor<64x1xf32>
    %cst_8 = arith.constant dense<3.840000e+02> : tensor<64x1xf32>
    %cst_9 = arith.constant dense<384> : tensor<64x1xi32>
    %cst_10 = arith.constant dense<384> : tensor<1x512xi32>
    %cst_11 = arith.constant dense<256> : tensor<64x1xi32>
    %c64_i32 = arith.constant 64 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c64_i32 : i32
    %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
    %4 = tt.splat %1 : (i32) -> tensor<64x1xi32>
    %5 = arith.addi %4, %3 : tensor<64x1xi32>
    %6 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
    %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<512xi32>) -> tensor<1x512xi32>
    %8 = tt.splat %arg0 : (!tt.ptr<i64, 1>) -> tensor<64x1x!tt.ptr<i64, 1>>
    %9 = tt.addptr %8, %5 : tensor<64x1x!tt.ptr<i64, 1>>, tensor<64x1xi32>
    %10 = tt.load %9 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<64x1xi64>
    %11 = arith.remsi %5, %cst_11 : tensor<64x1xi32>
    %12 = arith.cmpi slt, %7, %cst_10 : tensor<1x512xi32>
    %13 = arith.muli %11, %cst_9 : tensor<64x1xi32>
    %14 = tt.broadcast %7 : (tensor<1x512xi32>) -> tensor<64x512xi32>
    %15 = tt.broadcast %13 : (tensor<64x1xi32>) -> tensor<64x512xi32>
    %16 = arith.addi %14, %15 : tensor<64x512xi32>
    %17 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<64x512x!tt.ptr<f32, 1>>
    %18 = tt.addptr %17, %16 : tensor<64x512x!tt.ptr<f32, 1>>, tensor<64x512xi32>
    %19 = tt.broadcast %12 : (tensor<1x512xi1>) -> tensor<64x512xi1>
    %20 = tt.load %18, %19, %cst_2 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<64x512xf32>
    %21 = arith.muli %5, %cst_9 : tensor<64x1xi32>
    %22 = tt.broadcast %21 : (tensor<64x1xi32>) -> tensor<64x512xi32>
    %23 = arith.addi %14, %22 : tensor<64x512xi32>
    %24 = tt.splat %arg3 : (!tt.ptr<bf16, 1>) -> tensor<64x512x!tt.ptr<bf16, 1>>
    %25 = tt.addptr %24, %23 : tensor<64x512x!tt.ptr<bf16, 1>>, tensor<64x512xi32>
    %26 = tt.load %25, %19, %cst_6 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<64x512xbf16>
    %27 = arith.extf %26 : tensor<64x512xbf16> to tensor<64x512xf32>
    %28 = arith.cmpi slt, %10, %cst_5 : tensor<64x1xi64>
    %29 = arith.addi %10, %cst_4 : tensor<64x1xi64>
    %30 = arith.select %28, %29, %10 : tensor<64x1xi1>, tensor<64x1xi64>
    %31 = arith.muli %30, %cst_3 : tensor<64x1xi64>
    %32 = tt.broadcast %31 : (tensor<64x1xi64>) -> tensor<64x512xi64>
    %33 = arith.extsi %7 : tensor<1x512xi32> to tensor<1x512xi64>
    %34 = tt.broadcast %33 : (tensor<1x512xi64>) -> tensor<64x512xi64>
    %35 = arith.addi %34, %32 : tensor<64x512xi64>
    %36 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<64x512x!tt.ptr<f32, 1>>
    %37 = tt.addptr %36, %35 : tensor<64x512x!tt.ptr<f32, 1>>, tensor<64x512xi64>
    %38 = tt.load %37, %19, %cst_2 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<64x512xf32>
    %39 = arith.addf %38, %20 : tensor<64x512xf32>
    %40 = arith.addf %39, %27 : tensor<64x512xf32>
    %41 = arith.addf %40, %cst_2 : tensor<64x512xf32>
    %42 = arith.subf %40, %41 : tensor<64x512xf32>
    %43 = arith.mulf %40, %42 : tensor<64x512xf32>
    %44 = arith.addf %43, %cst_2 : tensor<64x512xf32>
    %45 = arith.select %19, %41, %cst_2 : tensor<64x512xi1>, tensor<64x512xf32>
    %46 = arith.select %19, %44, %cst_2 : tensor<64x512xi1>, tensor<64x512xf32>
    %47 = arith.select %12, %cst, %cst_0 : tensor<1x512xi1>, tensor<1x512xf32>
    %48 = tt.broadcast %47 : (tensor<1x512xf32>) -> tensor<64x512xf32>
    %49:3 = "tt.reduce"(%45, %46, %48) <{axis = 1 : i32}> ({
    ^bb0(%arg8: f32, %arg9: f32, %arg10: f32, %arg11: f32, %arg12: f32, %arg13: f32):
      %72 = arith.subf %arg11, %arg8 : f32
      %73 = arith.addf %arg10, %arg13 : f32
      %74 = arith.cmpf oeq, %73, %cst_1 : f32
      %75 = arith.divf %arg13, %73 : f32
      %76 = arith.select %74, %cst_1, %75 : f32
      %77 = arith.mulf %72, %76 : f32
      %78 = arith.addf %arg8, %77 : f32
      %79 = arith.addf %arg9, %arg12 : f32
      %80 = arith.mulf %72, %72 : f32
      %81 = arith.mulf %80, %arg10 : f32
      %82 = arith.mulf %81, %76 : f32
      %83 = arith.addf %79, %82 : f32
      tt.reduce.return %78, %83, %73 : f32, f32, f32
    }) : (tensor<64x512xf32>, tensor<64x512xf32>, tensor<64x512xf32>) -> (tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
    %50 = tt.expand_dims %49#0 {axis = 1 : i32} : (tensor<64xf32>) -> tensor<64x1xf32>
    %51 = tt.expand_dims %49#1 {axis = 1 : i32} : (tensor<64xf32>) -> tensor<64x1xf32>
    %52 = tt.load %25, %19, %cst_6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x512xbf16>
    %53 = arith.extf %52 : tensor<64x512xbf16> to tensor<64x512xf32>
    %54 = tt.splat %arg4 : (!tt.ptr<f32, 1>) -> tensor<1x512x!tt.ptr<f32, 1>>
    %55 = tt.addptr %54, %7 : tensor<1x512x!tt.ptr<f32, 1>>, tensor<1x512xi32>
    %56 = tt.load %55, %12, %cst_0 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x512xf32>
    %57 = tt.load %37, %19, %cst_2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x512xf32>
    %58 = arith.addf %57, %20 : tensor<64x512xf32>
    %59 = arith.addf %58, %53 : tensor<64x512xf32>
    %60 = tt.broadcast %50 : (tensor<64x1xf32>) -> tensor<64x512xf32>
    %61 = arith.subf %59, %60 : tensor<64x512xf32>
    %62 = arith.divf %51, %cst_8 : tensor<64x1xf32>
    %63 = arith.addf %62, %cst_7 : tensor<64x1xf32>
    %64 = tt.extern_elementwise %63 {libname = "libdevice", libpath = "/home/lanwo/code/triton-dlc/triton/python/triton/language/../third_party/cuda/lib/libdevice.10.bc", pure = true, symbol = "__nv_rsqrtf"} : (tensor<64x1xf32>) -> tensor<64x1xf32>
    %65 = tt.broadcast %64 : (tensor<64x1xf32>) -> tensor<64x512xf32>
    %66 = arith.mulf %61, %65 : tensor<64x512xf32>
    %67 = tt.broadcast %56 : (tensor<1x512xf32>) -> tensor<64x512xf32>
    %68 = arith.mulf %66, %67 : tensor<64x512xf32>
    %69 = tt.splat %arg5 : (!tt.ptr<bf16, 1>) -> tensor<64x512x!tt.ptr<bf16, 1>>
    %70 = tt.addptr %69, %23 : tensor<64x512x!tt.ptr<bf16, 1>>, tensor<64x512xi32>
    %71 = arith.truncf %68 : tensor<64x512xf32> to tensor<64x512xbf16>
    tt.store %70, %71, %19 {cache = 1 : i32, evict = 1 : i32} : tensor<64x512xbf16>
    tt.return
  }
}

Running triton-shared-opt --triton-to-linalg on the IR throws this error:

triton-shared-opt: /home/lanwo/code/triton-dlc/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:429: static void mlir::triton::PtrAnalysis::visitOperandRem(mlir::arith::RemSIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, const llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState>&): Assertion `state.getRank() == 1 && !state.modulos.back().has_value() && "No support for multiple modulos within an expression"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: triton-shared-opt --triton-to-linalg test_ttir.py.ttir
 #0 0x00005596e8a24467 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/home/lanwo/code/triton-dlc/triton/python/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x42fd467)
 #1 0x00005596e8a21f8e llvm::sys::RunSignalHandlers() (/home/lanwo/code/triton-dlc/triton/python/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x42faf8e)
 #2 0x00005596e8a24b1f SignalHandler(int) Signals.cpp:0:0
 #3 0x00007ff1dc020520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x00007ff1dc0749fc __pthread_kill_implementation ./nptl/./nptl/pthread_kill.c:44:76
 #5 0x00007ff1dc0749fc __pthread_kill_internal ./nptl/./nptl/pthread_kill.c:78:10
 #6 0x00007ff1dc0749fc pthread_kill ./nptl/./nptl/pthread_kill.c:89:10
 #7 0x00007ff1dc020476 gsignal ./signal/../sysdeps/posix/raise.c:27:6
 #8 0x00007ff1dc0067f3 abort ./stdlib/./stdlib/abort.c:81:7
 #9 0x00007ff1dc00671b _nl_load_domain ./intl/./intl/loadmsgcat.c:1177:9
#10 0x00007ff1dc017e96 (/lib/x86_64-linux-gnu/libc.so.6+0x39e96)
#11 0x00005596e6702e79 mlir::triton::PtrAnalysis::visitOperandRem(mlir::arith::RemSIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState> > const&) /home/lanwo/code/triton-dlc/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:434:3
#12 0x00005596e6702216 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState> > const&) /home/lanwo/code/triton-dlc/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:698:20

@nhat-nguyen (Collaborator)

Thank you @paretoth for the original triton python code.

This looks like it was generated by torch-inductor, which may use coding patterns that we have not yet covered in our current implementation. The implementation is still at an early stage, and it will take quite some time to reach a point where we can feed in arbitrary triton programs, especially ones generated by another compiler.

I will do an initial investigation and open issues that need to be addressed in order to lower this program. Keep in mind that you will need to do additional lowering regardless (e.g., `tt.extern_elementwise` will need to be lowered to ops that suit your needs).

@paretoth (Author)

Thanks for your reply. It was indeed generated by torch-inductor.

@nhat-nguyen nhat-nguyen self-assigned this Nov 27, 2023
nhat-nguyen added a commit that referenced this issue Dec 1, 2023
This PR addresses several issues:

1. Support the modulo pattern generated by torch-inductor
We now support `tl.arange(0, size)[:, None] % mod` -- expanding the shape before applying the modulo. Fixes #14 #48.

2. Add more modulo tests running on the CPU backend
I found out that our current usage of `memref.reinterpret_cast` to support modulo ops in a loop is incorrect. Previously we inserted two "no-op" `memref.reinterpret_cast` ops for the two blocks of the mod pointers so that `LoadConverter` could determine the sizes of the blocks to copy to the local buffers. However, when lowering all the way to LLVM, this meant we were resetting the offset of the blocks being yielded in each loop iteration. To solve this, I have replaced the casts with the proper `memref.dim` ops to get the correct sizes.

3. Fix type mismatches for individual modulo blocks in a loop
Previously, the type of each individual modulo block could have static strides, while the corresponding loop yield values have dynamic strides, causing a type mismatch. I have instead made the strides dynamic to begin with.

4. Support lowering to CPU for more cases
Lowering to memref can produce additional affine ops after the affine lowering passes have already run in the current pass ordering. I have added two passes to the pass list to fix this issue.

5. Add softmax tutorial test for CPU backend
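
The pattern from point 1 can be illustrated with NumPy standing in for Triton (sizes here are made up; only the shape/modulo ordering matters):

```python
import numpy as np

# torch-inductor emits the modulo *after* the dim expansion:
size, mod = 8, 3
expanded_then_mod = np.arange(size)[:, None] % mod   # like tl.arange(0, size)[:, None] % mod
# ...which is elementwise-equal to taking the modulo first and expanding after,
# the rank-1 form that PtrAnalysis already understood:
mod_then_expanded = (np.arange(size) % mod)[:, None]
print(np.array_equal(expanded_then_mod, mod_then_expanded))  # True
```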
@nhat-nguyen (Collaborator)

Fixed by #68
