[Bug]: failed to legalize operation 'arith.cmpi' marked as erased #64

Open
yuanfz98 opened this issue Nov 27, 2023 · 7 comments

Labels
bug Something isn't working

@yuanfz98 (Contributor)

Triton python code

# Imports and @triton.jit decorator restored for completeness; inductor
# normally emits these (helper module path assumed for the torch 2.1-era
# layout).
import triton
import triton.language as tl
from torch._inductor import triton_helpers

@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr2, xnumel, rnumel):
    xnumel = 16384
    XBLOCK: tl.constexpr = 1
    rnumel = 384
    RBLOCK: tl.constexpr = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (384*x0)), rmask, other=0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0)
    tmp8 = tl.load(in_ptr2 + (r1 + (384*x0)), rmask, other=0)
    tmp9 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp11 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
    tmp18 = tl.load(in_out_ptr0 + (r1 + (384*x0)), rmask, other=0)
    tmp28 = tl.load(in_ptr5 + (r1 + (384*x0)), rmask)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp1 * tmp2
    tmp4 = tl.broadcast_to(tmp3, [RBLOCK])
    tmp6 = tl.where(rmask, tmp4, 0)
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))
    tmp10 = tmp8 - tmp9
    tmp12 = tmp10 * tmp11
    tmp13 = tmp3 * tmp12
    tmp14 = tl.broadcast_to(tmp13, [RBLOCK])
    tmp16 = tl.where(rmask, tmp14, 0)
    tmp17 = triton_helpers.promote_to_tensor(tl.sum(tmp16, 0))
    tmp19 = 384.0
    tmp20 = tmp11 / tmp19
    tmp21 = tmp3 * tmp19
    tmp22 = tmp21 - tmp7
    tmp23 = tmp12 * tmp17
    tmp24 = tmp22 - tmp23
    tmp25 = tmp20 * tmp24
    tmp26 = tmp18 + tmp25
    tmp27 = tmp26.to(tl.float32)
    tmp29 = tmp28.to(tl.float32)
    tmp30 = 1.25
    tmp31 = tmp29 * tmp30
    tmp32 = tmp27 * tmp31
    tl.store(in_out_ptr0 + (r1 + (384*x0)), tmp26, rmask)
    tl.store(out_ptr2 + (r1 + (384*x0)), tmp32, rmask)

Triton IR

module {
  tt.func public @triton__0d1d2d3d4d5d6d7d8de9de(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<bf16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<i1, 1> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<bf16, 1> {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
    %c384_i32 = arith.constant 384 : i32
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<512xf32>
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<512xbf16>
    %cst_2 = arith.constant dense<1.250000e+00> : tensor<512xf32>
    %cst_3 = arith.constant dense<3.840000e+02> : tensor<512xf32>
    %cst_4 = arith.constant dense<3.840000e+02> : tensor<1xf32>
    %cst_5 = arith.constant dense<384> : tensor<512xi32>
    %0 = tt.get_program_id x : i32
    %1 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
    %2 = arith.cmpi slt, %1, %cst_5 : tensor<512xi32>
    %3 = arith.muli %0, %c384_i32 : i32
    %4 = tt.splat %3 : (i32) -> tensor<512xi32>
    %5 = arith.addi %1, %4 : tensor<512xi32>
    %6 = tt.splat %arg1 : (!tt.ptr<bf16, 1>) -> tensor<512x!tt.ptr<bf16, 1>>
    %7 = tt.addptr %6, %5 : tensor<512x!tt.ptr<bf16, 1>>, tensor<512xi32>
    %8 = tt.load %7, %2, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xbf16>
    %9 = arith.extf %8 : tensor<512xbf16> to tensor<512xf32>
    %10 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<512x!tt.ptr<f32, 1>>
    %11 = tt.addptr %10, %1 : tensor<512x!tt.ptr<f32, 1>>, tensor<512xi32>
    %12 = tt.load %11, %2, %cst_0 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<512xf32>
    %13 = tt.splat %arg3 : (!tt.ptr<f32, 1>) -> tensor<512x!tt.ptr<f32, 1>>
    %14 = tt.addptr %13, %5 : tensor<512x!tt.ptr<f32, 1>>, tensor<512xi32>
    %15 = tt.load %14, %2, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xf32>
    %16 = tt.addptr %arg4, %0 : !tt.ptr<f32, 1>, i32
    %17 = tt.splat %16 : (!tt.ptr<f32, 1>) -> tensor<1x!tt.ptr<f32, 1>>
    %18 = tt.load %17 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1xf32>
    %19 = tt.addptr %arg5, %0 : !tt.ptr<f32, 1>, i32
    %20 = tt.splat %19 : (!tt.ptr<f32, 1>) -> tensor<1x!tt.ptr<f32, 1>>
    %21 = tt.load %20 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1xf32>
    %22 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<512x!tt.ptr<f32, 1>>
    %23 = tt.addptr %22, %5 : tensor<512x!tt.ptr<f32, 1>>, tensor<512xi32>
    %24 = tt.load %23, %2, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xf32>
    %25 = tt.splat %arg6 : (!tt.ptr<i1, 1>) -> tensor<512x!tt.ptr<i1, 1>>
    %26 = tt.addptr %25, %5 : tensor<512x!tt.ptr<i1, 1>>, tensor<512xi32>
    %27 = tt.bitcast %26 : tensor<512x!tt.ptr<i1, 1>> -> tensor<512x!tt.ptr<i8, 1>>
    %28 = tt.load %27, %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xi8>
    %29 = arith.mulf %9, %12 : tensor<512xf32>
    %30 = arith.select %2, %29, %cst_0 : tensor<512xi1>, tensor<512xf32>
    %31 = "tt.reduce"(%30) <{axis = 0 : i32}> ({
    ^bb0(%arg10: f32, %arg11: f32):
      %57 = arith.addf %arg10, %arg11 : f32
      tt.reduce.return %57 : f32
    }) : (tensor<512xf32>) -> f32
    %32 = arith.addf %31, %cst : f32
    %33 = tt.broadcast %18 : (tensor<1xf32>) -> tensor<512xf32>
    %34 = arith.subf %15, %33 : tensor<512xf32>
    %35 = tt.broadcast %21 : (tensor<1xf32>) -> tensor<512xf32>
    %36 = arith.mulf %34, %35 : tensor<512xf32>
    %37 = arith.mulf %29, %36 : tensor<512xf32>
    %38 = arith.select %2, %37, %cst_0 : tensor<512xi1>, tensor<512xf32>
    %39 = "tt.reduce"(%38) <{axis = 0 : i32}> ({
    ^bb0(%arg10: f32, %arg11: f32):
      %57 = arith.addf %arg10, %arg11 : f32
      tt.reduce.return %57 : f32
    }) : (tensor<512xf32>) -> f32
    %40 = arith.addf %39, %cst : f32
    %41 = arith.divf %21, %cst_4 : tensor<1xf32>
    %42 = arith.mulf %29, %cst_3 : tensor<512xf32>
    %43 = tt.splat %32 : (f32) -> tensor<512xf32>
    %44 = arith.subf %42, %43 : tensor<512xf32>
    %45 = tt.splat %40 : (f32) -> tensor<512xf32>
    %46 = arith.mulf %36, %45 : tensor<512xf32>
    %47 = arith.subf %44, %46 : tensor<512xf32>
    %48 = tt.broadcast %41 : (tensor<1xf32>) -> tensor<512xf32>
    %49 = arith.mulf %48, %47 : tensor<512xf32>
    %50 = arith.addf %24, %49 : tensor<512xf32>
    %51 = arith.sitofp %28 : tensor<512xi8> to tensor<512xf32>
    %52 = arith.mulf %51, %cst_2 : tensor<512xf32>
    %53 = arith.mulf %50, %52 : tensor<512xf32>
    tt.store %23, %50, %2 {cache = 1 : i32, evict = 1 : i32} : tensor<512xf32>
    %54 = tt.splat %arg7 : (!tt.ptr<bf16, 1>) -> tensor<512x!tt.ptr<bf16, 1>>
    %55 = tt.addptr %54, %5 : tensor<512x!tt.ptr<bf16, 1>>, tensor<512xi32>
    %56 = arith.truncf %53 : tensor<512xf32> to tensor<512xbf16>
    tt.store %55, %56, %2 {cache = 1 : i32, evict = 1 : i32} : tensor<512xbf16>
    tt.return
  }
}

Crash log

/workspace/hongjing/temp/fg66ptts/triton_.ttir:13:10: error: failed to legalize operation 'arith.cmpi' marked as erased
    %2 = arith.cmpi slt, %1, %cst_5 : tensor<512xi32>
         ^
/workspace/hongjing/temp/fg66ptts/triton_.ttir:13:10: note: see current operation: %29 = "arith.cmpi"(%28, %23) <{predicate = 2 : i64}> {MetaUse} : (tensor<512xi32>, tensor<512xi32>) -> tensor<512xi1>
/workspace/hongjing/temp/fg66ptts/triton_.ttir:36:11: note: found live user of result #0: %112 = "tt.load"(%111, %29) <{cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 0>}> : (tensor<512x!tt.ptr<i8, 1>>, tensor<512xi1>) -> tensor<512xi8>
    %25 = tt.load %24, %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xi8>

Additional information

No response

yuanfz98 added the bug (Something isn't working) label on Nov 27, 2023
yuanfz98 (Contributor, Author) commented Nov 27, 2023

@nhat-nguyen
By analogy with #65, we have to ensure that nothing is inserted between the addptr and the load.

    %26 = tt.addptr %25, %5 : tensor<512x!tt.ptr<i1, 1>>, tensor<512xi32>
    %27 = tt.bitcast %26 : tensor<512x!tt.ptr<i1, 1>> -> tensor<512x!tt.ptr<i8, 1>>
    %28 = tt.load %27, %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xi8>

A bitcast of a tt.ptr can be moved before the addptr. In PtrAnalysis I propose adding a "Type srcElemType" field to the PtrState struct to carry type information. We would then propagate that type info during the recursion, hoping that in the end we arrive at a consistent final type for AddPtrConverter.
On the other hand, the memref dialect may not have an op that casts the element type of a memref.
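
For illustration, here is a minimal, untested sketch of such a canonicalization against the MLIR/Triton C++ API. The pattern name and the accessors (getSrc, getPtr, getOffset) are assumptions based on the upstream triton dialect, not code from this repo:

    #include "mlir/IR/PatternMatch.h"
    #include "triton/Dialect/Triton/IR/Dialect.h"

    using namespace mlir;

    // Rewrites bitcast(addptr(base, offset)) into addptr(bitcast(base), offset)
    // so the addptr feeds the load directly and PtrAnalysis can follow the
    // chain. tt.addptr offsets are in units of the pointee type, so this is
    // only safe when the source and destination pointees have the same storage
    // size (i1 and i8 are both one byte in Triton).
    struct HoistBitcastAboveAddPtr : OpRewritePattern<triton::BitcastOp> {
      using OpRewritePattern::OpRewritePattern;

      LogicalResult matchAndRewrite(triton::BitcastOp bitcast,
                                    PatternRewriter &rewriter) const override {
        auto addptr = bitcast.getSrc().getDefiningOp<triton::AddPtrOp>();
        if (!addptr)
          return failure();
        // Bitcast the base pointer to the destination pointer type, then redo
        // the pointer arithmetic on the result.
        Value newBase = rewriter.create<triton::BitcastOp>(
            bitcast.getLoc(), bitcast.getType(), addptr.getPtr());
        rewriter.replaceOpWithNewOp<triton::AddPtrOp>(
            bitcast, bitcast.getType(), newBase, addptr.getOffset());
        return success();
      }
    };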

jingchangshi commented Mar 26, 2024

(Quoting @yuanfz98's proposal above.)

Have you solved this issue?
tt.bitcast would be lowered to a linalg.generic wrapping arith.bitcast; however, arith.bitcast does not accept !tt.ptr<i1, 1>. How did you resolve this?
In TritonGPUToLLVM, tt.bitcast is lowered to llvm.bitcast. Do we need to do the same here?

nhat-nguyen (Collaborator)

@jingchangshi We're in the process of testing a rewrite of the Pointer Analysis pass. Most of the logic is kept the same, so you will probably encounter the same issue. However, the passes are now broken down into modular pieces, which should make some of the suggestions here, such as a canonicalization pass, easier to implement. In the meantime you can test with --triton-to-structured --structured-to-linalg --structured-to-memref. We would appreciate it if you could report any issues with the new passes. I'll have more time to investigate these issues once we know the rewrite is stable.
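
For reference, assuming the TTIR dump above is saved as triton_.ttir, the new pipeline can be exercised with something like:

    triton-shared-opt triton_.ttir --triton-to-structured --structured-to-linalg --structured-to-memref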

jingchangshi commented Apr 1, 2024

(Quoting @nhat-nguyen's comment above.)

I'm still working on the unrefactored version to check whether the fix works; then I'll move to the latest version.
However, I now hit the following error:

error: 'llvm.bitcast' op operand #0 must be LLVM-compatible non-aggregate type, but got 'memref<?xi1>'

Just as @yuanfz98 said,

the memref dialect may not have an op that casts the element type of a memref.

Do you have a solution to this problem? Neither llvm.bitcast nor arith.bitcast can be used here; I resorted to mlir::UnrealizedConversionCastOp just to get the pass to finish.
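
For what it's worth, a minimal sketch of that stop-gap (the helper name is hypothetical, not from this repo): since memref has no element-type bitcast, one can materialize an unrealized_conversion_cast from memref<?xi1> to memref<?xi8>:

    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/IR/BuiltinTypes.h"

    using namespace mlir;

    // Hypothetical helper: "casts" a memref's element type (e.g. memref<?xi1>
    // to memref<?xi8>) by materializing an unrealized_conversion_cast, since
    // neither arith.bitcast nor llvm.bitcast accepts a memref operand.
    static Value castMemrefElementType(OpBuilder &builder, Location loc,
                                       Value src, Type newElemTy) {
      auto srcTy = cast<MemRefType>(src.getType());
      auto dstTy = MemRefType::get(srcTy.getShape(), newElemTy,
                                   srcTy.getLayout(), srcTy.getMemorySpace());
      return builder
          .create<UnrealizedConversionCastOp>(loc, TypeRange{dstTy},
                                              ValueRange{src})
          .getResult(0);
    }

Such a cast is only a placeholder: it must be eliminated by a later pass (e.g. --reconcile-unrealized-casts once matching casts pair up), so it gets the pipeline to complete rather than being a real lowering.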

@fhossein-quic

(Quoting the exchange between @nhat-nguyen and @jingchangshi above.)

@nhat-nguyen, I wanted to follow up on this issue. I'm encountering multiple "failed to legalize operation 'tt.XYZ' marked as erased" errors while lowering Triton kernels generated by torch-inductor to linalg. I tried the --triton-to-linalg-experimental flag, which essentially implements --triton-to-structured --canonicalize --triton-arith-to-linalg --structured-to-memref. It got rid of the original error (though I still see some "remark: PtrAnalysis: Failed to rewrite StoreOp" messages), but now I get "LLVM ERROR: Failed to infer result type(s)" for the related load and store ops.

nhat-nguyen (Collaborator)

@fhossein-quic More torch-inductor support is definitely on our radar, as there has been a lot of interest. Many of these issues arise because we can't analyze the pointer arithmetic sequences generated by torch-inductor. There is some work in torch from the Meta team to simplify this pointer arithmetic (pytorch/pytorch#125077), but in general we need a fallback mode in our pointer analysis pass to support all cases.

In any case, if you could share your step-by-step reproduction along with the torch programs that you want to compile, that would be great. Thank you!

fhossein-quic commented Jul 23, 2024

(Quoting @nhat-nguyen's comment above.)

Thank you @nhat-nguyen for your response.
I've attached two TTIR files for simple conv2d and avg_pool2d kernels with kernel_size = 2 and a random input of size (1, 1, 4, 4).

You can reproduce the issue by running:

    triton-shared-opt <ttir_file_name> --triton-to-linalg-experimental

I've also attached the Triton kernels generated by the inductor, just in case.

TTIR:
avgpool.ttir.txt
conv2d.ttir.txt
Code by Inductor:
avgpool2d_by_inductor.txt
conv2d_by_inductor.txt
