Investigate correct uses of linalg out params #196

Open

nhat-nguyen opened this issue Nov 29, 2024 · 1 comment

Comments

@nhat-nguyen
Collaborator

#191 reported a bug where the bufferization result is incorrect when a tensor value is reused as a linalg init param. Since we rely on the upstream arith-to-linalg pass, we do not have control over which value is used as the init param. This is likely causing a number of codegen bugs at the moment.

@MercuryChen
Contributor

When converting an arith.add on tensors to linalg.generic, the upstream MLIR converter prefers to take the lhs operand as the DPS init of the generic op (see the arith-to-linalg code). Under some circumstances this behavior is not safe; the following case, lowered from a matmul with split-K, illustrates why.
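To make the behavior concrete, here is a minimal sketch of that conversion with illustrative names (%dot for the partial product, %acc for the accumulator); it is not taken from the kernel itself, but it mirrors the %68 = linalg.generic inside the loop below, where ins is (%67, %arg16) and outs is %67, i.e. the lhs doubles as the DPS init:

#map = affine_map<(d0, d1) -> (d0, d1)>

// Before the arith-to-linalg conversion: a plain elementwise add on tensors.
func.func @add_before(%dot: tensor<32x64xf32>, %acc: tensor<32x64xf32>) -> tensor<32x64xf32> {
  %sum = arith.addf %dot, %acc : tensor<32x64xf32>
  return %sum : tensor<32x64xf32>
}

// After the conversion: the lhs %dot is reused as the outs/DPS init of the generic.
func.func @add_after(%dot: tensor<32x64xf32>, %acc: tensor<32x64xf32>) -> tensor<32x64xf32> {
  %sum = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
      ins(%dot, %acc : tensor<32x64xf32>, tensor<32x64xf32>)
      outs(%dot : tensor<32x64xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %r = arith.addf %in, %in_0 : f32
    linalg.yield %r : f32
  } -> tensor<32x64xf32>
  return %sum : tensor<32x64xf32>
}

The full kernel before bufferization is: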

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst = arith.constant 0.000000e+00 : f32
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %0 = tensor.empty() : tensor<32x64xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x64xf32>) -> tensor<32x64xf32>
    %2 = arith.addi %arg3, %c31_i32 : i32
    %3 = arith.divsi %2, %c32_i32 : i32
    %4 = arith.addi %arg4, %c63_i32 : i32
    %5 = arith.divsi %4, %c64_i32 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.divsi %arg12, %6 : i32
    %8 = arith.muli %7, %c8_i32 : i32
    %9 = arith.subi %3, %8 : i32
    %10 = arith.minsi %9, %c8_i32 : i32
    %11 = arith.remsi %arg12, %10 : i32
    %12 = arith.addi %8, %11 : i32
    %13 = arith.remsi %arg12, %6 : i32
    %14 = arith.divsi %13, %10 : i32
    %15 = arith.muli %12, %c32_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.muli %14, %c64_i32 : i32
    %18 = arith.index_cast %17 : i32 to index
    %19 = arith.index_cast %arg3 : i32 to index
    %20 = arith.index_cast %arg6 : i32 to index
    %21 = arith.muli %16, %20 : index
    %22 = arith.muli %19, %20 : index
    %23 = arith.index_cast %arg7 : i32 to index
    %24 = arith.index_cast %arg4 : i32 to index
    %25 = arith.addi %arg5, %c15_i32 : i32
    %26 = arith.divsi %25, %c16_i32 : i32
    %27 = arith.muli %arg7, %c16_i32 : i32
    %28 = arith.index_cast %27 : i32 to index
    %29:3 = scf.for %arg15 = %c0_i32 to %26 step %c1_i32 iter_args(%arg16 = %1, %arg17 = %21, %arg18 = %c0) -> (tensor<32x64xf32>, index, index)  : i32 {
      %43 = arith.addi %arg18, %18 : index
      %44 = arith.remsi %43, %24 : index
      %45 = arith.subi %43, %44 : index
      %46 = arith.addi %44, %c64 : index
      %47 = arith.minsi %46, %24 : index
      %48 = arith.subi %47, %44 : index
      %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %48], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %49 = arith.subi %c64, %48 : index
      %reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [%45], sizes: [%c16, %49], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %50 = arith.remsi %arg17, %20 : index
      %51 = arith.addi %22, %50 : index
      %52 = arith.subi %51, %arg17 : index
      %53 = arith.divsi %52, %20 : index
      %reinterpret_cast_2 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%53, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %54 = arith.subi %c32, %53 : index
      %reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%50], sizes: [%54, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %55 = arith.muli %arg15, %c16_i32 : i32
      %56 = arith.subi %arg5, %55 : i32
      %57 = arith.index_cast %56 : i32 to index
      %58 = arith.minsi %57, %c16 : index
      %59 = arith.maxsi %58, %c0 : index
      %alloc = memref.alloc() : memref<32x16xf32>
      %60 = arith.cmpi slt, %59, %c16 : index
      scf.if %60 {
        linalg.fill ins(%cst : f32) outs(%alloc : memref<32x16xf32>)
      }
      %61 = arith.minsi %53, %c32 : index
      %62 = arith.subi %c32, %61 : index
      %subview_4 = memref.subview %reinterpret_cast_2[0, 0] [%61, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_5 = memref.subview %reinterpret_cast_3[0, 0] [%62, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_6 = memref.subview %alloc[0, 0] [%61, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
      %subview_7 = memref.subview %alloc[%61, 0] [%62, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      memref.copy %subview_4, %subview_6 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
      memref.copy %subview_5, %subview_7 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      %63 = bufferization.to_tensor %alloc restrict writable : memref<32x16xf32>
      %alloc_8 = memref.alloc() : memref<16x64xf32>
      scf.if %60 {
        linalg.fill ins(%cst : f32) outs(%alloc_8 : memref<16x64xf32>)
      }
      %64 = arith.minsi %48, %c64 : index
      %65 = arith.subi %c64, %64 : index
      %subview_9 = memref.subview %reinterpret_cast_0[0, 0] [%59, %64] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_10 = memref.subview %reinterpret_cast_1[0, 0] [%59, %65] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_11 = memref.subview %alloc_8[0, 0] [%59, %64] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
      %subview_12 = memref.subview %alloc_8[0, %64] [%59, %65] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      memref.copy %subview_9, %subview_11 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
      memref.copy %subview_10, %subview_12 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      %66 = bufferization.to_tensor %alloc_8 restrict writable : memref<16x64xf32>
      %67 = linalg.matmul ins(%63, %66 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%1 : tensor<32x64xf32>) -> tensor<32x64xf32>
      %68 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%67, %arg16 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%67 : tensor<32x64xf32>) {
      ^bb0(%in: f32, %in_13: f32, %out: f32):
        %71 = arith.addf %in, %in_13 : f32
        linalg.yield %71 : f32
      } -> tensor<32x64xf32>
      %69 = arith.addi %arg17, %c16 : index
      %70 = arith.addi %arg18, %28 : index
      scf.yield %68, %69, %70 : tensor<32x64xf32>, index, index
    }
    %30 = arith.index_cast %arg8 : i32 to index
    %31 = arith.muli %16, %30 : index
    %32 = arith.addi %31, %18 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%32], sizes: [32, 64], strides: [%30, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
    %33 = arith.addi %16, %c32 : index
    %34 = arith.minsi %33, %19 : index
    %35 = arith.maxsi %34, %16 : index
    %36 = arith.subi %35, %16 : index
    %37 = arith.addi %18, %c64 : index
    %38 = arith.minsi %37, %24 : index
    %39 = arith.maxsi %38, %18 : index
    %40 = arith.subi %39, %18 : index
    %41 = arith.minsi %36, %c32 : index
    %42 = arith.minsi %40, %c64 : index
    %extracted_slice = tensor.extract_slice %29#0[0, 0] [%41, %42] [1, 1] : tensor<32x64xf32> to tensor<?x?xf32>
    %subview = memref.subview %reinterpret_cast[0, 0] [%41, %42] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
    bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor<?x?xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>) -> ()
    return
  }
}

The generic op's outs operand is one of its ins operands. After bufferization, the IR is:

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst = arith.constant 0.000000e+00 : f32
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    linalg.fill ins(%cst : f32) outs(%alloc : memref<32x64xf32>)
    %0 = arith.addi %arg3, %c31_i32 : i32
    %1 = arith.divsi %0, %c32_i32 : i32
    %2 = arith.addi %arg4, %c63_i32 : i32
    %3 = arith.divsi %2, %c64_i32 : i32
    %4 = arith.muli %3, %c8_i32 : i32
    %5 = arith.divsi %arg12, %4 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.subi %1, %6 : i32
    %8 = arith.minsi %7, %c8_i32 : i32
    %9 = arith.remsi %arg12, %8 : i32
    %10 = arith.addi %6, %9 : i32
    %11 = arith.remsi %arg12, %4 : i32
    %12 = arith.divsi %11, %8 : i32
    %13 = arith.muli %10, %c32_i32 : i32
    %14 = arith.index_cast %13 : i32 to index
    %15 = arith.muli %12, %c64_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.index_cast %arg3 : i32 to index
    %18 = arith.index_cast %arg6 : i32 to index
    %19 = arith.muli %14, %18 : index
    %20 = arith.muli %17, %18 : index
    %21 = arith.index_cast %arg7 : i32 to index
    %22 = arith.index_cast %arg4 : i32 to index
    %23 = arith.addi %arg5, %c15_i32 : i32
    %24 = arith.divsi %23, %c16_i32 : i32
    %25 = arith.muli %arg7, %c16_i32 : i32
    %26 = arith.index_cast %25 : i32 to index
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    memref.copy %alloc, %alloc_0 : memref<32x64xf32> to memref<32x64xf32>
    %27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_0, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index)  : i32 {
      %41 = arith.addi %arg18, %16 : index
      %42 = arith.remsi %41, %22 : index
      %43 = arith.subi %41, %42 : index
      %44 = arith.addi %42, %c64 : index
      %45 = arith.minsi %44, %22 : index
      %46 = arith.subi %45, %42 : index
      %reinterpret_cast_2 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %46], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %47 = arith.subi %c64, %46 : index
      %reinterpret_cast_3 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %47], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %48 = arith.remsi %arg17, %18 : index
      %49 = arith.addi %20, %48 : index
      %50 = arith.subi %49, %arg17 : index
      %51 = arith.divsi %50, %18 : index
      %reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%51, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %52 = arith.subi %c32, %51 : index
      %reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [%48], sizes: [%52, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %53 = arith.muli %arg15, %c16_i32 : i32
      %54 = arith.subi %arg5, %53 : i32
      %55 = arith.index_cast %54 : i32 to index
      %56 = arith.minsi %55, %c16 : index
      %57 = arith.maxsi %56, %c0 : index
      %alloc_6 = memref.alloc() : memref<32x16xf32>
      %58 = arith.cmpi slt, %57, %c16 : index
      scf.if %58 {
        linalg.fill ins(%cst : f32) outs(%alloc_6 : memref<32x16xf32>)
      }
      %59 = arith.minsi %51, %c32 : index
      %60 = arith.subi %c32, %59 : index
      %subview_7 = memref.subview %reinterpret_cast_4[0, 0] [%59, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_8 = memref.subview %reinterpret_cast_5[0, 0] [%60, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_9 = memref.subview %alloc_6[0, 0] [%59, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
      %subview_10 = memref.subview %alloc_6[%59, 0] [%60, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      memref.copy %subview_7, %subview_9 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
      memref.copy %subview_8, %subview_10 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      %alloc_11 = memref.alloc() : memref<16x64xf32>
      scf.if %58 {
        linalg.fill ins(%cst : f32) outs(%alloc_11 : memref<16x64xf32>)
      }
      %61 = arith.minsi %46, %c64 : index
      %62 = arith.subi %c64, %61 : index
      %subview_12 = memref.subview %reinterpret_cast_2[0, 0] [%57, %61] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_13 = memref.subview %reinterpret_cast_3[0, 0] [%57, %62] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_14 = memref.subview %alloc_11[0, 0] [%57, %61] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
      %subview_15 = memref.subview %alloc_11[0, %61] [%57, %62] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      memref.copy %subview_12, %subview_14 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
      memref.copy %subview_13, %subview_15 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      %alloc_16 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
      memref.copy %alloc, %alloc_16 : memref<32x64xf32> to memref<32x64xf32>
      linalg.matmul ins(%alloc_6, %alloc_11 : memref<32x16xf32>, memref<16x64xf32>) outs(%alloc_16 : memref<32x64xf32>)
      linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_16, %arg16 : memref<32x64xf32>, memref<32x64xf32>) outs(%alloc_16: memref<32x64xf32>) {
      ^bb0(%in: f32, %in_17: f32, %out: f32):
        %65 = arith.addf %in, %in_17 : f32
        linalg.yield %65 : f32
      }
      %63 = arith.addi %arg17, %c16 : index
      %64 = arith.addi %arg18, %26 : index
      scf.yield %alloc_16, %63, %64 : memref<32x64xf32>, index, index
    }
    %28 = arith.index_cast %arg8 : i32 to index
    %29 = arith.muli %14, %28 : index
    %30 = arith.addi %29, %16 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
    %31 = arith.addi %14, %c32 : index
    %32 = arith.minsi %31, %17 : index
    %33 = arith.maxsi %32, %14 : index
    %34 = arith.subi %33, %14 : index
    %35 = arith.addi %16, %c64 : index
    %36 = arith.minsi %35, %22 : index
    %37 = arith.maxsi %36, %16 : index
    %38 = arith.subi %37, %16 : index
    %39 = arith.minsi %34, %c32 : index
    %40 = arith.minsi %38, %c64 : index
    %subview = memref.subview %27#0[0, 0] [%39, %40] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1]>>
    %subview_1 = memref.subview %reinterpret_cast[0, 0] [%39, %40] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
    memref.copy %subview, %subview_1 : memref<?x?xf32, strided<[64, 1]>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
    return
  }
}

This can lead to the partial sum not being correctly added into the accumulator, especially after converting scf to cf.

So I think a better way is to use a new empty tensor as the DPS init. Here is the IR after this change.
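The key change is inside the loop body of the listing below: the generic now writes into a fresh tensor.empty instead of reusing the matmul result %67 as its init:

      %manual_empty = tensor.empty() : tensor<32x64xf32>
      %68 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %67 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%manual_empty : tensor<32x64xf32>) {
      ^bb0(%in: f32, %in_13: f32, %out: f32):
        %71 = arith.addf %in, %in_13 : f32
        linalg.yield %71 : f32
      } -> tensor<32x64xf32>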
Before bufferization:

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst = arith.constant 0.000000e+00 : f32
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %0 = tensor.empty() : tensor<32x64xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x64xf32>) -> tensor<32x64xf32>
    %2 = arith.addi %arg3, %c31_i32 : i32
    %3 = arith.divsi %2, %c32_i32 : i32
    %4 = arith.addi %arg4, %c63_i32 : i32
    %5 = arith.divsi %4, %c64_i32 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.divsi %arg12, %6 : i32
    %8 = arith.muli %7, %c8_i32 : i32
    %9 = arith.subi %3, %8 : i32
    %10 = arith.minsi %9, %c8_i32 : i32
    %11 = arith.remsi %arg12, %10 : i32
    %12 = arith.addi %8, %11 : i32
    %13 = arith.remsi %arg12, %6 : i32
    %14 = arith.divsi %13, %10 : i32
    %15 = arith.muli %12, %c32_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.muli %14, %c64_i32 : i32
    %18 = arith.index_cast %17 : i32 to index
    %19 = arith.index_cast %arg3 : i32 to index
    %20 = arith.index_cast %arg6 : i32 to index
    %21 = arith.muli %16, %20 : index
    %22 = arith.muli %19, %20 : index
    %23 = arith.index_cast %arg7 : i32 to index
    %24 = arith.index_cast %arg4 : i32 to index
    %25 = arith.addi %arg5, %c15_i32 : i32
    %26 = arith.divsi %25, %c16_i32 : i32
    %27 = arith.muli %arg7, %c16_i32 : i32
    %28 = arith.index_cast %27 : i32 to index
    %29:3 = scf.for %arg15 = %c0_i32 to %26 step %c1_i32 iter_args(%arg16 = %1, %arg17 = %21, %arg18 = %c0) -> (tensor<32x64xf32>, index, index)  : i32 {
      %43 = arith.addi %arg18, %18 : index
      %44 = arith.remsi %43, %24 : index
      %45 = arith.subi %43, %44 : index
      %46 = arith.addi %44, %c64 : index
      %47 = arith.minsi %46, %24 : index
      %48 = arith.subi %47, %44 : index
      %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %48], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %49 = arith.subi %c64, %48 : index
      %reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [%45], sizes: [%c16, %49], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %50 = arith.remsi %arg17, %20 : index
      %51 = arith.addi %22, %50 : index
      %52 = arith.subi %51, %arg17 : index
      %53 = arith.divsi %52, %20 : index
      %reinterpret_cast_2 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%53, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %54 = arith.subi %c32, %53 : index
      %reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%50], sizes: [%54, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %55 = arith.muli %arg15, %c16_i32 : i32
      %56 = arith.subi %arg5, %55 : i32
      %57 = arith.index_cast %56 : i32 to index
      %58 = arith.minsi %57, %c16 : index
      %59 = arith.maxsi %58, %c0 : index
      %alloc = memref.alloc() : memref<32x16xf32>
      %60 = arith.cmpi slt, %59, %c16 : index
      scf.if %60 {
        linalg.fill ins(%cst : f32) outs(%alloc : memref<32x16xf32>)
      }
      %61 = arith.minsi %53, %c32 : index
      %62 = arith.subi %c32, %61 : index
      %subview_4 = memref.subview %reinterpret_cast_2[0, 0] [%61, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_5 = memref.subview %reinterpret_cast_3[0, 0] [%62, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_6 = memref.subview %alloc[0, 0] [%61, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
      %subview_7 = memref.subview %alloc[%61, 0] [%62, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      memref.copy %subview_4, %subview_6 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
      memref.copy %subview_5, %subview_7 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      %63 = bufferization.to_tensor %alloc restrict writable : memref<32x16xf32>
      %alloc_8 = memref.alloc() : memref<16x64xf32>
      scf.if %60 {
        linalg.fill ins(%cst : f32) outs(%alloc_8 : memref<16x64xf32>)
      }
      %64 = arith.minsi %48, %c64 : index
      %65 = arith.subi %c64, %64 : index
      %subview_9 = memref.subview %reinterpret_cast_0[0, 0] [%59, %64] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_10 = memref.subview %reinterpret_cast_1[0, 0] [%59, %65] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_11 = memref.subview %alloc_8[0, 0] [%59, %64] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
      %subview_12 = memref.subview %alloc_8[0, %64] [%59, %65] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      memref.copy %subview_9, %subview_11 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
      memref.copy %subview_10, %subview_12 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      %66 = bufferization.to_tensor %alloc_8 restrict writable : memref<16x64xf32>
      %67 = linalg.matmul ins(%63, %66 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%1 : tensor<32x64xf32>) -> tensor<32x64xf32>
      %manual_empty = tensor.empty() : tensor<32x64xf32>
      %68 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %67 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%manual_empty : tensor<32x64xf32>) {
      ^bb0(%in: f32, %in_13: f32, %out: f32):
        %71 = arith.addf %in, %in_13 : f32
        linalg.yield %71 : f32
      } -> tensor<32x64xf32>
      %69 = arith.addi %arg17, %c16 : index
      %70 = arith.addi %arg18, %28 : index
      scf.yield %68, %69, %70 : tensor<32x64xf32>, index, index
    }
    %30 = arith.index_cast %arg8 : i32 to index
    %31 = arith.muli %16, %30 : index
    %32 = arith.addi %31, %18 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%32], sizes: [32, 64], strides: [%30, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
    %33 = arith.addi %16, %c32 : index
    %34 = arith.minsi %33, %19 : index
    %35 = arith.maxsi %34, %16 : index
    %36 = arith.subi %35, %16 : index
    %37 = arith.addi %18, %c64 : index
    %38 = arith.minsi %37, %24 : index
    %39 = arith.maxsi %38, %18 : index
    %40 = arith.subi %39, %18 : index
    %41 = arith.minsi %36, %c32 : index
    %42 = arith.minsi %40, %c64 : index
    %extracted_slice = tensor.extract_slice %29#0[0, 0] [%41, %42] [1, 1] : tensor<32x64xf32> to tensor<?x?xf32>
    %subview = memref.subview %reinterpret_cast[0, 0] [%41, %42] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
    bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor<?x?xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>) -> ()
    return
  }
}

After bufferization:

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst = arith.constant 0.000000e+00 : f32
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    linalg.fill ins(%cst : f32) outs(%alloc : memref<32x64xf32>)
    %0 = arith.addi %arg3, %c31_i32 : i32
    %1 = arith.divsi %0, %c32_i32 : i32
    %2 = arith.addi %arg4, %c63_i32 : i32
    %3 = arith.divsi %2, %c64_i32 : i32
    %4 = arith.muli %3, %c8_i32 : i32
    %5 = arith.divsi %arg12, %4 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.subi %1, %6 : i32
    %8 = arith.minsi %7, %c8_i32 : i32
    %9 = arith.remsi %arg12, %8 : i32
    %10 = arith.addi %6, %9 : i32
    %11 = arith.remsi %arg12, %4 : i32
    %12 = arith.divsi %11, %8 : i32
    %13 = arith.muli %10, %c32_i32 : i32
    %14 = arith.index_cast %13 : i32 to index
    %15 = arith.muli %12, %c64_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.index_cast %arg3 : i32 to index
    %18 = arith.index_cast %arg6 : i32 to index
    %19 = arith.muli %14, %18 : index
    %20 = arith.muli %17, %18 : index
    %21 = arith.index_cast %arg7 : i32 to index
    %22 = arith.index_cast %arg4 : i32 to index
    %23 = arith.addi %arg5, %c15_i32 : i32
    %24 = arith.divsi %23, %c16_i32 : i32
    %25 = arith.muli %arg7, %c16_i32 : i32
    %26 = arith.index_cast %25 : i32 to index
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    memref.copy %alloc, %alloc_0 : memref<32x64xf32> to memref<32x64xf32>
    %27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_0, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index)  : i32 {
      %41 = arith.addi %arg18, %16 : index
      %42 = arith.remsi %41, %22 : index
      %43 = arith.subi %41, %42 : index
      %44 = arith.addi %42, %c64 : index
      %45 = arith.minsi %44, %22 : index
      %46 = arith.subi %45, %42 : index
      %reinterpret_cast_2 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %46], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %47 = arith.subi %c64, %46 : index
      %reinterpret_cast_3 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %47], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
      %48 = arith.remsi %arg17, %18 : index
      %49 = arith.addi %20, %48 : index
      %50 = arith.subi %49, %arg17 : index
      %51 = arith.divsi %50, %18 : index
      %reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%51, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %52 = arith.subi %c32, %51 : index
      %reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [%48], sizes: [%52, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
      %53 = arith.muli %arg15, %c16_i32 : i32
      %54 = arith.subi %arg5, %53 : i32
      %55 = arith.index_cast %54 : i32 to index
      %56 = arith.minsi %55, %c16 : index
      %57 = arith.maxsi %56, %c0 : index
      %alloc_6 = memref.alloc() : memref<32x16xf32>
      %58 = arith.cmpi slt, %57, %c16 : index
      scf.if %58 {
        linalg.fill ins(%cst : f32) outs(%alloc_6 : memref<32x16xf32>)
      }
      %59 = arith.minsi %51, %c32 : index
      %60 = arith.subi %c32, %59 : index
      %subview_7 = memref.subview %reinterpret_cast_4[0, 0] [%59, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_8 = memref.subview %reinterpret_cast_5[0, 0] [%60, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_9 = memref.subview %alloc_6[0, 0] [%59, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
      %subview_10 = memref.subview %alloc_6[%59, 0] [%60, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      memref.copy %subview_7, %subview_9 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
      memref.copy %subview_8, %subview_10 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
      %alloc_11 = memref.alloc() : memref<16x64xf32>
      scf.if %58 {
        linalg.fill ins(%cst : f32) outs(%alloc_11 : memref<16x64xf32>)
      }
      %61 = arith.minsi %46, %c64 : index
      %62 = arith.subi %c64, %61 : index
      %subview_12 = memref.subview %reinterpret_cast_2[0, 0] [%57, %61] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_13 = memref.subview %reinterpret_cast_3[0, 0] [%57, %62] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
      %subview_14 = memref.subview %alloc_11[0, 0] [%57, %61] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
      %subview_15 = memref.subview %alloc_11[0, %61] [%57, %62] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      memref.copy %subview_12, %subview_14 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
      memref.copy %subview_13, %subview_15 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
      %alloc_16 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
      memref.copy %alloc, %alloc_16 : memref<32x64xf32> to memref<32x64xf32>
      linalg.matmul ins(%alloc_6, %alloc_11 : memref<32x16xf32>, memref<16x64xf32>) outs(%alloc_16 : memref<32x64xf32>)
      %alloc_17 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
      linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %alloc_16 : memref<32x64xf32>, memref<32x64xf32>) outs(%alloc_17 : memref<32x64xf32>) {
      ^bb0(%in: f32, %in_18: f32, %out: f32):
        %65 = arith.addf %in, %in_18 : f32
        linalg.yield %65 : f32
      }
      %63 = arith.addi %arg17, %c16 : index
      %64 = arith.addi %arg18, %26 : index
      scf.yield %alloc_17, %63, %64 : memref<32x64xf32>, index, index
    }
    %28 = arith.index_cast %arg8 : i32 to index
    %29 = arith.muli %14, %28 : index
    %30 = arith.addi %29, %16 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
    %31 = arith.addi %14, %c32 : index
    %32 = arith.minsi %31, %17 : index
    %33 = arith.maxsi %32, %14 : index
    %34 = arith.subi %33, %14 : index
    %35 = arith.addi %16, %c64 : index
    %36 = arith.minsi %35, %22 : index
    %37 = arith.maxsi %36, %16 : index
    %38 = arith.subi %37, %16 : index
    %39 = arith.minsi %34, %c32 : index
    %40 = arith.minsi %38, %c64 : index
    %subview = memref.subview %27#0[0, 0] [%39, %40] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1]>>
    %subview_1 = memref.subview %reinterpret_cast[0, 0] [%39, %40] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
    memref.copy %subview, %subview_1 : memref<?x?xf32, strided<[64, 1]>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
    return
  }
}

MercuryChen added a commit to MercuryChen/triton-shared that referenced this issue Dec 5, 2024
tt.dot with an accumulator lowers to linalg.matmul plus arith.add, and the arith.add further lowers to linalg.generic. The generic takes the lhs of the add as its DPS init, so the lhs should be the matmul accumulator. This is a temporary fix for issue microsoft#196.
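As a sketch of what that fix produces for the loop body in this issue (illustrative names, not code from the commit: %acc stands for the accumulator iter_arg, %dot for the matmul partial product), an add with the accumulator on the lhs,

  %sum = arith.addf %acc, %dot : tensor<32x64xf32>

lowers to a generic whose DPS init is the accumulator:

#map = affine_map<(d0, d1) -> (d0, d1)>

func.func @acc_add(%acc: tensor<32x64xf32>, %dot: tensor<32x64xf32>) -> tensor<32x64xf32> {
  %sum = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
      ins(%acc, %dot : tensor<32x64xf32>, tensor<32x64xf32>)
      // The accumulator is the lhs, so it becomes the outs/DPS init.
      outs(%acc : tensor<32x64xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %r = arith.addf %in, %in_0 : f32
    linalg.yield %r : f32
  } -> tensor<32x64xf32>
  return %sum : tensor<32x64xf32>
}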