Errors: No support for multiple modulos within an expression #48
Comments
@paretoth Thanks for reporting this. Would you mind also sharing the Triton Python code that produces this IR?
@nhat-nguyen Is there any way we can change this?
@nhat-nguyen Thanks for your reply. The Triton Python code and its corresponding IR are as follows:

import triton
import triton.language as tl
from torch._inductor import triton_helpers
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 16384
    rnumel = 384
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x3), None, eviction_policy='evict_last')
    x0 = xindex % 256
    tmp9_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp9_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp9_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp3 = tl.load(in_ptr2 + (r2 + (384*x0)), rmask, eviction_policy='evict_last', other=0)
        tmp5 = tl.load(in_ptr3 + (r2 + (384*x3)), rmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp1 = tl.where(tmp0 < 0, tmp0 + 65, tmp0)
        tl.device_assert((0 <= tmp1) & (tmp1 < 65), "index out of bounds: 0 <= tmp1 < 65")
        tmp2 = tl.load(in_ptr1 + (r2 + (384*tmp1)), rmask, eviction_policy='evict_last', other=0)
        tmp4 = tmp2 + tmp3
        tmp6 = tmp5.to(tl.float32)
        tmp7 = tmp4 + tmp6
        tmp8 = tl.broadcast_to(tmp7, [XBLOCK, RBLOCK])
        tmp9_mean_next, tmp9_m2_next, tmp9_weight_next = triton_helpers.welford_reduce(
            tmp8, tmp9_mean, tmp9_m2, tmp9_weight,
        )
        tmp9_mean = tl.where(rmask, tmp9_mean_next, tmp9_mean)
        tmp9_m2 = tl.where(rmask, tmp9_m2_next, tmp9_m2)
        tmp9_weight = tl.where(rmask, tmp9_weight_next, tmp9_weight)
    tmp9_tmp, tmp10_tmp, tmp11_tmp = triton_helpers.welford(
        tmp9_mean, tmp9_m2, tmp9_weight, 1
    )
    tmp9 = tmp9_tmp[:, None]
    tmp10 = tmp10_tmp[:, None]
    tmp11 = tmp11_tmp[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp14 = tl.load(in_ptr2 + (r2 + (384*x0)), rmask, eviction_policy='evict_last', other=0)
        tmp16 = tl.load(in_ptr3 + (r2 + (384*x3)), rmask, other=0).to(tl.float32)
        tmp26 = tl.load(in_ptr4 + (r2), rmask, eviction_policy='evict_last', other=0)
        tmp12 = tl.where(tmp0 < 0, tmp0 + 65, tmp0)
        tl.device_assert((0 <= tmp12) & (tmp12 < 65), "index out of bounds: 0 <= tmp12 < 65")
        tmp13 = tl.load(in_ptr1 + (r2 + (384*tmp12)), rmask, other=0)
        tmp15 = tmp13 + tmp14
        tmp17 = tmp16.to(tl.float32)
        tmp18 = tmp15 + tmp17
        tmp19 = tmp18 - tmp9
        tmp20 = 384.0
        tmp21 = tmp10 / tmp20
        tmp22 = 1e-05
        tmp23 = tmp21 + tmp22
        tmp24 = tl.math.rsqrt(tmp23)
        tmp25 = tmp19 * tmp24
        tmp27 = tmp25 * tmp26
        tmp28 = tmp27.to(tl.float32)
        tl.store(out_ptr2 + (r2 + (384*x3)), tmp28, rmask)

# 'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: '*bf16', 4: '*fp32', 5: '*bf16', 6: 'i32', 7: 'i32'}
ret = triton.compile(triton_, signature="*i64,*fp32,*fp32,*bf16,*fp32,*bf16,i32,i32", constants={"XBLOCK": 64, "RBLOCK": 512})
output_filename = __file__ + ".ttir"
with open(output_filename, "w") as f:
    f.write(ret.asm["ttir"])
print(f"triton_ saved to {output_filename}") Running triton-shared-opt --triton-to-linalg, IR: module {
tt.func public @triton__01234567(%arg0: !tt.ptr<i64, 1>, %arg1: !tt.ptr<f32, 1>, %arg2: !tt.ptr<f32, 1>, %arg3: !tt.ptr<bf16, 1>, %arg4: !tt.ptr<f32, 1>, %arg5: !tt.ptr<bf16, 1>, %arg6: i32, %arg7: i32) attributes {noinline = false} {
%cst = arith.constant dense<1.000000e+00> : tensor<1x512xf32>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<1x512xf32>
%cst_1 = arith.constant 0.000000e+00 : f32
%cst_2 = arith.constant dense<0.000000e+00> : tensor<64x512xf32>
%cst_3 = arith.constant dense<384> : tensor<64x1xi64>
%cst_4 = arith.constant dense<65> : tensor<64x1xi64>
%cst_5 = arith.constant dense<0> : tensor<64x1xi64>
%cst_6 = arith.constant dense<0.000000e+00> : tensor<64x512xbf16>
%cst_7 = arith.constant dense<9.99999974E-6> : tensor<64x1xf32>
%cst_8 = arith.constant dense<3.840000e+02> : tensor<64x1xf32>
%cst_9 = arith.constant dense<384> : tensor<64x1xi32>
%cst_10 = arith.constant dense<384> : tensor<1x512xi32>
%cst_11 = arith.constant dense<256> : tensor<64x1xi32>
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c64_i32 : i32
%2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
%4 = tt.splat %1 : (i32) -> tensor<64x1xi32>
%5 = arith.addi %4, %3 : tensor<64x1xi32>
%6 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
%7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<512xi32>) -> tensor<1x512xi32>
%8 = tt.splat %arg0 : (!tt.ptr<i64, 1>) -> tensor<64x1x!tt.ptr<i64, 1>>
%9 = tt.addptr %8, %5 : tensor<64x1x!tt.ptr<i64, 1>>, tensor<64x1xi32>
%10 = tt.load %9 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<64x1xi64>
%11 = arith.remsi %5, %cst_11 : tensor<64x1xi32>
%12 = arith.cmpi slt, %7, %cst_10 : tensor<1x512xi32>
%13 = arith.muli %11, %cst_9 : tensor<64x1xi32>
%14 = tt.broadcast %7 : (tensor<1x512xi32>) -> tensor<64x512xi32>
%15 = tt.broadcast %13 : (tensor<64x1xi32>) -> tensor<64x512xi32>
%16 = arith.addi %14, %15 : tensor<64x512xi32>
%17 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<64x512x!tt.ptr<f32, 1>>
%18 = tt.addptr %17, %16 : tensor<64x512x!tt.ptr<f32, 1>>, tensor<64x512xi32>
%19 = tt.broadcast %12 : (tensor<1x512xi1>) -> tensor<64x512xi1>
%20 = tt.load %18, %19, %cst_2 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<64x512xf32>
%21 = arith.muli %5, %cst_9 : tensor<64x1xi32>
%22 = tt.broadcast %21 : (tensor<64x1xi32>) -> tensor<64x512xi32>
%23 = arith.addi %14, %22 : tensor<64x512xi32>
%24 = tt.splat %arg3 : (!tt.ptr<bf16, 1>) -> tensor<64x512x!tt.ptr<bf16, 1>>
%25 = tt.addptr %24, %23 : tensor<64x512x!tt.ptr<bf16, 1>>, tensor<64x512xi32>
%26 = tt.load %25, %19, %cst_6 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<64x512xbf16>
%27 = arith.extf %26 : tensor<64x512xbf16> to tensor<64x512xf32>
%28 = arith.cmpi slt, %10, %cst_5 : tensor<64x1xi64>
%29 = arith.addi %10, %cst_4 : tensor<64x1xi64>
%30 = arith.select %28, %29, %10 : tensor<64x1xi1>, tensor<64x1xi64>
%31 = arith.muli %30, %cst_3 : tensor<64x1xi64>
%32 = tt.broadcast %31 : (tensor<64x1xi64>) -> tensor<64x512xi64>
%33 = arith.extsi %7 : tensor<1x512xi32> to tensor<1x512xi64>
%34 = tt.broadcast %33 : (tensor<1x512xi64>) -> tensor<64x512xi64>
%35 = arith.addi %34, %32 : tensor<64x512xi64>
%36 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<64x512x!tt.ptr<f32, 1>>
%37 = tt.addptr %36, %35 : tensor<64x512x!tt.ptr<f32, 1>>, tensor<64x512xi64>
%38 = tt.load %37, %19, %cst_2 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<64x512xf32>
%39 = arith.addf %38, %20 : tensor<64x512xf32>
%40 = arith.addf %39, %27 : tensor<64x512xf32>
%41 = arith.addf %40, %cst_2 : tensor<64x512xf32>
%42 = arith.subf %40, %41 : tensor<64x512xf32>
%43 = arith.mulf %40, %42 : tensor<64x512xf32>
%44 = arith.addf %43, %cst_2 : tensor<64x512xf32>
%45 = arith.select %19, %41, %cst_2 : tensor<64x512xi1>, tensor<64x512xf32>
%46 = arith.select %19, %44, %cst_2 : tensor<64x512xi1>, tensor<64x512xf32>
%47 = arith.select %12, %cst, %cst_0 : tensor<1x512xi1>, tensor<1x512xf32>
%48 = tt.broadcast %47 : (tensor<1x512xf32>) -> tensor<64x512xf32>
%49:3 = "tt.reduce"(%45, %46, %48) <{axis = 1 : i32}> ({
^bb0(%arg8: f32, %arg9: f32, %arg10: f32, %arg11: f32, %arg12: f32, %arg13: f32):
%72 = arith.subf %arg11, %arg8 : f32
%73 = arith.addf %arg10, %arg13 : f32
%74 = arith.cmpf oeq, %73, %cst_1 : f32
%75 = arith.divf %arg13, %73 : f32
%76 = arith.select %74, %cst_1, %75 : f32
%77 = arith.mulf %72, %76 : f32
%78 = arith.addf %arg8, %77 : f32
%79 = arith.addf %arg9, %arg12 : f32
%80 = arith.mulf %72, %72 : f32
%81 = arith.mulf %80, %arg10 : f32
%82 = arith.mulf %81, %76 : f32
%83 = arith.addf %79, %82 : f32
tt.reduce.return %78, %83, %73 : f32, f32, f32
}) : (tensor<64x512xf32>, tensor<64x512xf32>, tensor<64x512xf32>) -> (tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
%50 = tt.expand_dims %49#0 {axis = 1 : i32} : (tensor<64xf32>) -> tensor<64x1xf32>
%51 = tt.expand_dims %49#1 {axis = 1 : i32} : (tensor<64xf32>) -> tensor<64x1xf32>
%52 = tt.load %25, %19, %cst_6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x512xbf16>
%53 = arith.extf %52 : tensor<64x512xbf16> to tensor<64x512xf32>
%54 = tt.splat %arg4 : (!tt.ptr<f32, 1>) -> tensor<1x512x!tt.ptr<f32, 1>>
%55 = tt.addptr %54, %7 : tensor<1x512x!tt.ptr<f32, 1>>, tensor<1x512xi32>
%56 = tt.load %55, %12, %cst_0 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x512xf32>
%57 = tt.load %37, %19, %cst_2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x512xf32>
%58 = arith.addf %57, %20 : tensor<64x512xf32>
%59 = arith.addf %58, %53 : tensor<64x512xf32>
%60 = tt.broadcast %50 : (tensor<64x1xf32>) -> tensor<64x512xf32>
%61 = arith.subf %59, %60 : tensor<64x512xf32>
%62 = arith.divf %51, %cst_8 : tensor<64x1xf32>
%63 = arith.addf %62, %cst_7 : tensor<64x1xf32>
%64 = tt.extern_elementwise %63 {libname = "libdevice", libpath = "/home/lanwo/code/triton-dlc/triton/python/triton/language/../third_party/cuda/lib/libdevice.10.bc", pure = true, symbol = "__nv_rsqrtf"} : (tensor<64x1xf32>) -> tensor<64x1xf32>
%65 = tt.broadcast %64 : (tensor<64x1xf32>) -> tensor<64x512xf32>
%66 = arith.mulf %61, %65 : tensor<64x512xf32>
%67 = tt.broadcast %56 : (tensor<1x512xf32>) -> tensor<64x512xf32>
%68 = arith.mulf %66, %67 : tensor<64x512xf32>
%69 = tt.splat %arg5 : (!tt.ptr<bf16, 1>) -> tensor<64x512x!tt.ptr<bf16, 1>>
%70 = tt.addptr %69, %23 : tensor<64x512x!tt.ptr<bf16, 1>>, tensor<64x512xi32>
%71 = arith.truncf %68 : tensor<64x512xf32> to tensor<64x512xbf16>
tt.store %70, %71, %19 {cache = 1 : i32, evict = 1 : i32} : tensor<64x512xbf16>
tt.return
}
}
Running triton-shared-opt --triton-to-linalg on the IR throws the error in the issue title: no support for multiple modulos within an expression.
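For reference, the failing invocation would look roughly like the following, assuming the dump script above was saved as repro.py (a hypothetical filename):

triton-shared-opt --triton-to-linalg repro.py.ttir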
Thank you @paretoth for the original Triton Python code. This looks like it was generated by torch-inductor. I will do an initial investigation and open up issues that might need to be addressed in order to lower this program. Keep in mind that you will need to do additional lowering regardless (e.g. …).
Thanks for your reply. It was indeed generated by torch-inductor.
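The kernel's structure (a 65x384 lookup table indexed by int64 values, two adds, then rsqrt(var/384 + 1e-5) scaled by a 384-wide weight) suggests a fused embedding lookup + residual add + LayerNorm. As a hedged illustration only, the module below is a guess at PyTorch code that torch-inductor might fuse into a kernel of this shape; every name and shape is inferred from the kernel above, not confirmed by the reporter:

import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 65-row, 384-wide table, matching the bounds check and stride in the kernel
        self.emb = torch.nn.Embedding(65, 384)
        # normalization over the last dim with eps = 1e-5, matching rsqrt(var/384 + 1e-5)
        self.norm = torch.nn.LayerNorm(384, eps=1e-5)

    def forward(self, idx, bias, resid):
        # embedding gather + broadcast add + residual add, then LayerNorm,
        # cast back to bf16 like the kernel's final store
        return self.norm(self.emb(idx) + bias + resid).to(torch.bfloat16)

compiled = torch.compile(Block())  # torch-inductor backend
out = compiled(
    torch.randint(0, 65, (64, 256)),                 # int64 indices (xnumel = 64 * 256)
    torch.randn(256, 384),                           # term indexed by xindex % 256 in the kernel
    torch.randn(64, 256, 384, dtype=torch.bfloat16), # bf16 residual
)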
This PR addresses several issues:

1. Support the modulo pattern generated by torch-inductor. We now support `tl.arange(0, size)[:, None] % mod` -- expanding the shape before applying the modulo. Fixes #14 #48. (A minimal sketch of the pattern follows this list.)

2. Add more modulo tests and fix incorrect `memref.reinterpret_cast` usage in loops. While running more modulo tests on the CPU backend, I found that our current use of `memref.reinterpret_cast` to support modulo ops in a loop is incorrect. Previously we inserted two "no-op" `memref.reinterpret_cast` ops for the two blocks of the mod pointers so that `LoadConverter` could determine the sizes of the blocks to copy to the local buffers. However, when lowering all the way to LLVM, this meant we were resetting the offset of the blocks being yielded in each loop iteration. To solve this, I have replaced the casts with proper `memref.dim` ops to get the correct sizes.

3. Fix type mismatches between individual modulo blocks in a loop. Previously, the type of an individual modulo block could have static strides, while the corresponding loop's yielded values have dynamic strides, causing a type mismatch. I now make the strides dynamic to begin with.

4. Support lowering to CPU for more cases. Lowering to memref can produce more affine ops after the affine passes have already run in the current pass ordering. I have added two additional passes to the pass list to fix this.

5. Add a softmax tutorial test for the CPU backend.
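For illustration, here is a minimal sketch of that modulo pattern in isolation. The kernel below is hypothetical (the kernel name, pointer names, and missing launch setup are not from the PR); the 256 modulus and the 384 stride mirror the reporter's kernel above:

import triton
import triton.language as tl

@triton.jit
def mod_pattern_kernel(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
    # Expand the row index to shape [XBLOCK, 1] first, then apply the modulo:
    # the `tl.arange(0, size)[:, None] % mod` pattern from item 1 above.
    xindex = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    x0 = xindex % 256                        # wrapped row index, shape [XBLOCK, 1]
    rindex = tl.arange(0, RBLOCK)[None, :]   # column index, shape [1, RBLOCK]
    # Strided load through the modulo'd index -- the addressing the PR adds
    # support for in the triton-to-linalg conversion.
    vals = tl.load(in_ptr + (rindex + 384 * x0))
    tl.store(out_ptr + (rindex + 384 * xindex), vals)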
Fixed by #68