
[TMTensor] Cast i1 to i32 by extsi instead of trunci for aten scatter_add #3947

Merged
1 commit merged into llvm:main on Jan 10, 2025

Conversation

@AmosLewis (Collaborator) commented on Jan 8, 2025

To fix nod-ai/SHARK-ModelDev#898, which arises from using arith::trunci to cast i1 to i64. arith.extui should be used instead.

  • arith::trunci: The integer truncation operation takes an integer input of width M and an integer destination type of width N. The destination bit-width must be smaller than the input bit-width (N < M). The top-most (M - N) bits of the input are discarded.
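As a minimal sketch (using a hypothetical i1 value %flag, not taken from the patch): widening an i1 to i32 requires an extension op, since arith.trunci is only valid when the result is narrower than the operand. arith.extui zero-extends, while arith.extsi would sign-extend true to -1.

// Rejected by the verifier: i1 is narrower than i32, so truncation is ill-formed.
// %bad = arith.trunci %flag : i1 to i32
// Zero-extension maps false -> 0 and true -> 1.
%ok = arith.extui %flag : i1 to i32
// Sign-extension would map true -> -1 instead.
// %neg = arith.extsi %flag : i1 to i32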

Also adds a dynamic e2e test for the aten.scatter_add op in passing.

Follow-up of the onnx.Compress op issue nod-ai/SHARK-ModelDev#893 (comment).

torch-mlir-opt -pass-pipeline='builtin.module(torch-onnx-to-torch-backend-pipeline{backend-legal-ops=aten.flatten.using_ints,aten.unflatten.int})' ./compress.onnx.mlir > compress.torch.mlir
torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)' compress.torch.mlir > compress.linalg.mlir
compress.torch.mlir:24:11: error: 'arith.trunci' op operand type 'i1' and result type 'i32' are cast incompatible
    %14 = torch.aten.scatter_add %13, %int0, %6, %9 : !torch.vtensor<[?],i1>, !torch.int, !torch.vtensor<[?],i1>, !torch.vtensor<[?],i1> -> !torch.vtensor<[?],i1>
          ^
compress.torch.mlir:24:11: note: see current operation: %109 = "arith.trunci"(%105) : (i1) -> i32

compress.onnx.mlir

module {
  func.func @CNTKGraph(%311:!torch.vtensor<[1,?],f32>, %312:!torch.vtensor<[?],i1> ) -> (!torch.vtensor<[1],f32>)  attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.ml = 1 : si64}, torch.onnx_meta.producer_name = "CNTK", torch.onnx_meta.producer_version = "2.7"} {
    %313 = torch.operator "onnx.Compress"(%311, %312) : (!torch.vtensor<[1,?],f32>, !torch.vtensor<[?],i1>) -> !torch.vtensor<[1],f32> 
    return %313: !torch.vtensor<[1],f32> 
  }
}

scatter_add.torch.mlir

module {
  func.func @CNTKGraph(%13: !torch.vtensor<[?],i1>, %6: !torch.vtensor<[?],i1>, %9:!torch.vtensor<[?],i1>) -> !torch.vtensor<[?],i1> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.ml = 1 : si64}, torch.onnx_meta.producer_name = "CNTK", torch.onnx_meta.producer_version = "2.7"} {
    %int0 = torch.constant.int 0    
    %14 = torch.aten.scatter_add %13, %int0, %6, %9 : !torch.vtensor<[?],i1>, !torch.int, !torch.vtensor<[?],i1>, !torch.vtensor<[?],i1> -> !torch.vtensor<[?],i1>
    return %14 : !torch.vtensor<[?],i1>
  }
}

scatter_add.tm.mlir

#map = affine_map<(d0) -> (d0, 0)>
#map1 = affine_map<(d0) -> (d0)>
module {
  func.func @CNTKGraph(%arg0: !torch.vtensor<[?],i1>, %arg1: !torch.vtensor<[?],i1>, %arg2: !torch.vtensor<[?],i1>) -> !torch.vtensor<[?],i1> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.ml = 1 : si64}, torch.onnx_meta.producer_name = "CNTK", torch.onnx_meta.producer_version = "2.7"} {
    %0 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[?],i1> -> tensor<?xi1>
    %1 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?],i1> -> tensor<?xi1>
    %2 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?],i1> -> tensor<?xi1>
    %int0 = torch.constant.int 0
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %1, %c0 : tensor<?xi1>
    %c1 = arith.constant 1 : index
    %3 = arith.muli %c1, %dim : index
    %4 = arith.index_cast %3 : index to i64
    %5 = arith.index_cast %4 : i64 to index
    %c0_0 = arith.constant 0 : index
    %dim_1 = tensor.dim %1, %c0_0 : tensor<?xi1>
    %c1_2 = arith.constant 1 : index
    %6 = tensor.empty(%5) : tensor<?x1xi32>
    %c0_i32 = arith.constant 0 : i32
    %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<?x1xi32>) -> tensor<?x1xi32>
    %8 = tensor.empty(%5) : tensor<?xi1>
    %false = arith.constant false
    %9 = linalg.fill ins(%false : i1) outs(%8 : tensor<?xi1>) -> tensor<?xi1>
    %10:2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} outs(%7, %9 : tensor<?x1xi32>, tensor<?xi1>) {
    ^bb0(%out: i32, %out_14: i1):
      %15 = linalg.index 0 : index
      %16 = arith.remsi %15, %dim_1 : index
      %17 = arith.divsi %15, %dim_1 : index
      %extracted = tensor.extract %1[%16] : tensor<?xi1>
      %extracted_15 = tensor.extract %0[%16] : tensor<?xi1>
      %18 = arith.index_cast %16 : index to i64
      %19 = arith.trunci %18 : i64 to i32
      %20 = arith.extui %extracted : i1 to i32
      linalg.yield %20, %extracted_15 : i32, i1
    } -> (tensor<?x1xi32>, tensor<?xi1>)
    %c0_3 = arith.constant 0 : index
    %c0_4 = arith.constant 0 : index
    %c1_5 = arith.constant 1 : index
    %c1_6 = arith.constant 1 : index
    %c1_7 = arith.constant 1 : index
    %11 = tensor.empty(%5) : tensor<?x1xi32>
    %c0_i32_8 = arith.constant 0 : i32
    %12 = linalg.fill ins(%c0_i32_8 : i32) outs(%11 : tensor<?x1xi32>) -> tensor<?x1xi32>
    %c0_9 = arith.constant 0 : index
    %dim_10 = tensor.dim %10#0, %c0_9 : tensor<?x1xi32>
    %c1_11 = arith.constant 1 : index
    %c1_12 = arith.constant 1 : index
    %inserted_slice = tensor.insert_slice %10#0 into %12[0, 0] [%dim_10, 1] [1, 1] : tensor<?x1xi32> into tensor<?x1xi32>
    %c1_13 = arith.constant 1 : index
    %13 = tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(false) ins(%10#1, %inserted_slice : tensor<?xi1>, tensor<?x1xi32>) outs(%2 : tensor<?xi1>) {
    ^bb0(%arg3: i1, %arg4: i1):
      %15 = arith.addi %arg4, %arg3 : i1
      tm_tensor.yield %15 : i1
    } -> tensor<?xi1>
    %cast = tensor.cast %13 : tensor<?xi1> to tensor<?xi1>
    %14 = torch_c.from_builtin_tensor %cast : tensor<?xi1> -> !torch.vtensor<[?],i1>
    return %14 : !torch.vtensor<[?],i1>
  }
}
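The fix shows up inside the linalg.generic body of the dump above: where the previous lowering produced the ill-formed truncation reported by the verifier (%109 = "arith.trunci"(%105) : (i1) -> i32), the lowering now zero-extends the extracted i1 element. A sketch of the before/after at that spot (SSA names follow the dump; the "before" line is inferred from the error message, not copied from the old code):

// Before (rejected by the verifier): %20 = arith.trunci %extracted : i1 to i32
// After (this patch): zero-extend the extracted i1 element to i32.
%20 = arith.extui %extracted : i1 to i32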

…_add

Add dynamic e2e test for aten.scatter_add op
@AmosLewis AmosLewis changed the title [TMTensor] Cast i1 to i64 by extsi instead of turnci for aten scatter_add [TMTensor] Cast i1 to i64 by extsi instead of trunci for aten scatter_add Jan 8, 2025
@AmosLewis AmosLewis changed the title [TMTensor] Cast i1 to i64 by extsi instead of trunci for aten scatter_add [TMTensor] Cast i1 to i32by extsi instead of trunci for aten scatter_add Jan 8, 2025
@AmosLewis AmosLewis requested a review from jinchen62 January 8, 2025 17:49
@jinchen62 (Collaborator) commented:

What's the error of ScatterAddDynamicModule_basic?

@AmosLewis (Collaborator, Author) replied:

No error for it for now; I just found that scatter_add only had a static e2e test, so I added the dynamic version in passing.

@jinchen62 (Collaborator) left a comment:


I think we need to figure out the e2e test failure.

@jinchen62 (Collaborator) left a comment:


LGTM

@AmosLewis merged commit a45356e into llvm:main on Jan 10, 2025
3 checks passed
@AmosLewis deleted the scatteradd branch on January 10, 2025 at 01:41
Successfully merging this pull request may close these issues.

onnx.Compress fail at 'arith.trunci' op operand type 'i1' and result type 'i32' are cast incompatible