Investigate correct uses of linalg out params #196
Comments
While converting a matmul kernel whose tt.dot carries an accumulator, the lowering to linalg produces the following IR:

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
%c8_i32 = arith.constant 8 : i32
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c16_i32 = arith.constant 16 : i32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c31_i32 = arith.constant 31 : i32
%c63_i32 = arith.constant 63 : i32
%c15_i32 = arith.constant 15 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%cst = arith.constant 0.000000e+00 : f32
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%0 = tensor.empty() : tensor<32x64xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x64xf32>) -> tensor<32x64xf32>
%2 = arith.addi %arg3, %c31_i32 : i32
%3 = arith.divsi %2, %c32_i32 : i32
%4 = arith.addi %arg4, %c63_i32 : i32
%5 = arith.divsi %4, %c64_i32 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.divsi %arg12, %6 : i32
%8 = arith.muli %7, %c8_i32 : i32
%9 = arith.subi %3, %8 : i32
%10 = arith.minsi %9, %c8_i32 : i32
%11 = arith.remsi %arg12, %10 : i32
%12 = arith.addi %8, %11 : i32
%13 = arith.remsi %arg12, %6 : i32
%14 = arith.divsi %13, %10 : i32
%15 = arith.muli %12, %c32_i32 : i32
%16 = arith.index_cast %15 : i32 to index
%17 = arith.muli %14, %c64_i32 : i32
%18 = arith.index_cast %17 : i32 to index
%19 = arith.index_cast %arg3 : i32 to index
%20 = arith.index_cast %arg6 : i32 to index
%21 = arith.muli %16, %20 : index
%22 = arith.muli %19, %20 : index
%23 = arith.index_cast %arg7 : i32 to index
%24 = arith.index_cast %arg4 : i32 to index
%25 = arith.addi %arg5, %c15_i32 : i32
%26 = arith.divsi %25, %c16_i32 : i32
%27 = arith.muli %arg7, %c16_i32 : i32
%28 = arith.index_cast %27 : i32 to index
%29:3 = scf.for %arg15 = %c0_i32 to %26 step %c1_i32 iter_args(%arg16 = %1, %arg17 = %21, %arg18 = %c0) -> (tensor<32x64xf32>, index, index) : i32 {
%43 = arith.addi %arg18, %18 : index
%44 = arith.remsi %43, %24 : index
%45 = arith.subi %43, %44 : index
%46 = arith.addi %44, %c64 : index
%47 = arith.minsi %46, %24 : index
%48 = arith.subi %47, %44 : index
%reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %48], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%49 = arith.subi %c64, %48 : index
%reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [%45], sizes: [%c16, %49], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%50 = arith.remsi %arg17, %20 : index
%51 = arith.addi %22, %50 : index
%52 = arith.subi %51, %arg17 : index
%53 = arith.divsi %52, %20 : index
%reinterpret_cast_2 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%53, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%54 = arith.subi %c32, %53 : index
%reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%50], sizes: [%54, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%55 = arith.muli %arg15, %c16_i32 : i32
%56 = arith.subi %arg5, %55 : i32
%57 = arith.index_cast %56 : i32 to index
%58 = arith.minsi %57, %c16 : index
%59 = arith.maxsi %58, %c0 : index
%alloc = memref.alloc() : memref<32x16xf32>
%60 = arith.cmpi slt, %59, %c16 : index
scf.if %60 {
linalg.fill ins(%cst : f32) outs(%alloc : memref<32x16xf32>)
}
%61 = arith.minsi %53, %c32 : index
%62 = arith.subi %c32, %61 : index
%subview_4 = memref.subview %reinterpret_cast_2[0, 0] [%61, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_5 = memref.subview %reinterpret_cast_3[0, 0] [%62, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_6 = memref.subview %alloc[0, 0] [%61, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
%subview_7 = memref.subview %alloc[%61, 0] [%62, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
memref.copy %subview_4, %subview_6 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
memref.copy %subview_5, %subview_7 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
%63 = bufferization.to_tensor %alloc restrict writable : memref<32x16xf32>
%alloc_8 = memref.alloc() : memref<16x64xf32>
scf.if %60 {
linalg.fill ins(%cst : f32) outs(%alloc_8 : memref<16x64xf32>)
}
%64 = arith.minsi %48, %c64 : index
%65 = arith.subi %c64, %64 : index
%subview_9 = memref.subview %reinterpret_cast_0[0, 0] [%59, %64] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_10 = memref.subview %reinterpret_cast_1[0, 0] [%59, %65] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_11 = memref.subview %alloc_8[0, 0] [%59, %64] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
%subview_12 = memref.subview %alloc_8[0, %64] [%59, %65] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
memref.copy %subview_9, %subview_11 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
memref.copy %subview_10, %subview_12 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
%66 = bufferization.to_tensor %alloc_8 restrict writable : memref<16x64xf32>
%67 = linalg.matmul ins(%63, %66 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%1 : tensor<32x64xf32>) -> tensor<32x64xf32>
%68 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%67, %arg16 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%67 : tensor<32x64xf32>) {
^bb0(%in: f32, %in_13: f32, %out: f32):
%71 = arith.addf %in, %in_13 : f32
linalg.yield %71 : f32
} -> tensor<32x64xf32>
%69 = arith.addi %arg17, %c16 : index
%70 = arith.addi %arg18, %28 : index
scf.yield %68, %69, %70 : tensor<32x64xf32>, index, index
}
%30 = arith.index_cast %arg8 : i32 to index
%31 = arith.muli %16, %30 : index
%32 = arith.addi %31, %18 : index
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%32], sizes: [32, 64], strides: [%30, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
%33 = arith.addi %16, %c32 : index
%34 = arith.minsi %33, %19 : index
%35 = arith.maxsi %34, %16 : index
%36 = arith.subi %35, %16 : index
%37 = arith.addi %18, %c64 : index
%38 = arith.minsi %37, %24 : index
%39 = arith.maxsi %38, %18 : index
%40 = arith.subi %39, %18 : index
%41 = arith.minsi %36, %c32 : index
%42 = arith.minsi %40, %c64 : index
%extracted_slice = tensor.extract_slice %29#0[0, 0] [%41, %42] [1, 1] : tensor<32x64xf32> to tensor<?x?xf32>
%subview = memref.subview %reinterpret_cast[0, 0] [%41, %42] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor<?x?xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>) -> ()
return
}
}

The generic op takes the matmul result %67 as its DPS init. After bufferization, this becomes:

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
%c8_i32 = arith.constant 8 : i32
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c16_i32 = arith.constant 16 : i32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c31_i32 = arith.constant 31 : i32
%c63_i32 = arith.constant 63 : i32
%c15_i32 = arith.constant 15 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%cst = arith.constant 0.000000e+00 : f32
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
linalg.fill ins(%cst : f32) outs(%alloc : memref<32x64xf32>)
%0 = arith.addi %arg3, %c31_i32 : i32
%1 = arith.divsi %0, %c32_i32 : i32
%2 = arith.addi %arg4, %c63_i32 : i32
%3 = arith.divsi %2, %c64_i32 : i32
%4 = arith.muli %3, %c8_i32 : i32
%5 = arith.divsi %arg12, %4 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.subi %1, %6 : i32
%8 = arith.minsi %7, %c8_i32 : i32
%9 = arith.remsi %arg12, %8 : i32
%10 = arith.addi %6, %9 : i32
%11 = arith.remsi %arg12, %4 : i32
%12 = arith.divsi %11, %8 : i32
%13 = arith.muli %10, %c32_i32 : i32
%14 = arith.index_cast %13 : i32 to index
%15 = arith.muli %12, %c64_i32 : i32
%16 = arith.index_cast %15 : i32 to index
%17 = arith.index_cast %arg3 : i32 to index
%18 = arith.index_cast %arg6 : i32 to index
%19 = arith.muli %14, %18 : index
%20 = arith.muli %17, %18 : index
%21 = arith.index_cast %arg7 : i32 to index
%22 = arith.index_cast %arg4 : i32 to index
%23 = arith.addi %arg5, %c15_i32 : i32
%24 = arith.divsi %23, %c16_i32 : i32
%25 = arith.muli %arg7, %c16_i32 : i32
%26 = arith.index_cast %25 : i32 to index
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_0 : memref<32x64xf32> to memref<32x64xf32>
%27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_0, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index) : i32 {
%41 = arith.addi %arg18, %16 : index
%42 = arith.remsi %41, %22 : index
%43 = arith.subi %41, %42 : index
%44 = arith.addi %42, %c64 : index
%45 = arith.minsi %44, %22 : index
%46 = arith.subi %45, %42 : index
%reinterpret_cast_2 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %46], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%47 = arith.subi %c64, %46 : index
%reinterpret_cast_3 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %47], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%48 = arith.remsi %arg17, %18 : index
%49 = arith.addi %20, %48 : index
%50 = arith.subi %49, %arg17 : index
%51 = arith.divsi %50, %18 : index
%reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%51, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%52 = arith.subi %c32, %51 : index
%reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [%48], sizes: [%52, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%53 = arith.muli %arg15, %c16_i32 : i32
%54 = arith.subi %arg5, %53 : i32
%55 = arith.index_cast %54 : i32 to index
%56 = arith.minsi %55, %c16 : index
%57 = arith.maxsi %56, %c0 : index
%alloc_6 = memref.alloc() : memref<32x16xf32>
%58 = arith.cmpi slt, %57, %c16 : index
scf.if %58 {
linalg.fill ins(%cst : f32) outs(%alloc_6 : memref<32x16xf32>)
}
%59 = arith.minsi %51, %c32 : index
%60 = arith.subi %c32, %59 : index
%subview_7 = memref.subview %reinterpret_cast_4[0, 0] [%59, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_8 = memref.subview %reinterpret_cast_5[0, 0] [%60, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_9 = memref.subview %alloc_6[0, 0] [%59, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
%subview_10 = memref.subview %alloc_6[%59, 0] [%60, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
memref.copy %subview_7, %subview_9 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
memref.copy %subview_8, %subview_10 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
%alloc_11 = memref.alloc() : memref<16x64xf32>
scf.if %58 {
linalg.fill ins(%cst : f32) outs(%alloc_11 : memref<16x64xf32>)
}
%61 = arith.minsi %46, %c64 : index
%62 = arith.subi %c64, %61 : index
%subview_12 = memref.subview %reinterpret_cast_2[0, 0] [%57, %61] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_13 = memref.subview %reinterpret_cast_3[0, 0] [%57, %62] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_14 = memref.subview %alloc_11[0, 0] [%57, %61] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
%subview_15 = memref.subview %alloc_11[0, %61] [%57, %62] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
memref.copy %subview_12, %subview_14 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
memref.copy %subview_13, %subview_15 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
%alloc_16 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_16 : memref<32x64xf32> to memref<32x64xf32>
linalg.matmul ins(%alloc_6, %alloc_11 : memref<32x16xf32>, memref<16x64xf32>) outs(%alloc_16 : memref<32x64xf32>)
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_16, %arg16 : memref<32x64xf32>, memref<32x64xf32>) outs(%alloc_16: memref<32x64xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%65 = arith.addf %in, %in_17 : f32
linalg.yield %65 : f32
}
%63 = arith.addi %arg17, %c16 : index
%64 = arith.addi %arg18, %26 : index
scf.yield %alloc_16, %63, %64 : memref<32x64xf32>, index, index
}
%28 = arith.index_cast %arg8 : i32 to index
%29 = arith.muli %14, %28 : index
%30 = arith.addi %29, %16 : index
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
%31 = arith.addi %14, %c32 : index
%32 = arith.minsi %31, %17 : index
%33 = arith.maxsi %32, %14 : index
%34 = arith.subi %33, %14 : index
%35 = arith.addi %16, %c64 : index
%36 = arith.minsi %35, %22 : index
%37 = arith.maxsi %36, %16 : index
%38 = arith.subi %37, %16 : index
%39 = arith.minsi %34, %c32 : index
%40 = arith.minsi %38, %c64 : index
%subview = memref.subview %27#0[0, 0] [%39, %40] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1]>>
%subview_1 = memref.subview %reinterpret_cast[0, 0] [%39, %40] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
memref.copy %subview, %subview_1 : memref<?x?xf32, strided<[64, 1]>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
return
}
}

This result may cause confusion about whether the partial sum is correctly added into the accumulator, especially after converting scf to cf. So I think a better way is to use a new empty tensor as the DPS init. Here is the IR after that change:

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
%c8_i32 = arith.constant 8 : i32
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c16_i32 = arith.constant 16 : i32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c31_i32 = arith.constant 31 : i32
%c63_i32 = arith.constant 63 : i32
%c15_i32 = arith.constant 15 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%cst = arith.constant 0.000000e+00 : f32
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%0 = tensor.empty() : tensor<32x64xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x64xf32>) -> tensor<32x64xf32>
%2 = arith.addi %arg3, %c31_i32 : i32
%3 = arith.divsi %2, %c32_i32 : i32
%4 = arith.addi %arg4, %c63_i32 : i32
%5 = arith.divsi %4, %c64_i32 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.divsi %arg12, %6 : i32
%8 = arith.muli %7, %c8_i32 : i32
%9 = arith.subi %3, %8 : i32
%10 = arith.minsi %9, %c8_i32 : i32
%11 = arith.remsi %arg12, %10 : i32
%12 = arith.addi %8, %11 : i32
%13 = arith.remsi %arg12, %6 : i32
%14 = arith.divsi %13, %10 : i32
%15 = arith.muli %12, %c32_i32 : i32
%16 = arith.index_cast %15 : i32 to index
%17 = arith.muli %14, %c64_i32 : i32
%18 = arith.index_cast %17 : i32 to index
%19 = arith.index_cast %arg3 : i32 to index
%20 = arith.index_cast %arg6 : i32 to index
%21 = arith.muli %16, %20 : index
%22 = arith.muli %19, %20 : index
%23 = arith.index_cast %arg7 : i32 to index
%24 = arith.index_cast %arg4 : i32 to index
%25 = arith.addi %arg5, %c15_i32 : i32
%26 = arith.divsi %25, %c16_i32 : i32
%27 = arith.muli %arg7, %c16_i32 : i32
%28 = arith.index_cast %27 : i32 to index
%29:3 = scf.for %arg15 = %c0_i32 to %26 step %c1_i32 iter_args(%arg16 = %1, %arg17 = %21, %arg18 = %c0) -> (tensor<32x64xf32>, index, index) : i32 {
%43 = arith.addi %arg18, %18 : index
%44 = arith.remsi %43, %24 : index
%45 = arith.subi %43, %44 : index
%46 = arith.addi %44, %c64 : index
%47 = arith.minsi %46, %24 : index
%48 = arith.subi %47, %44 : index
%reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %48], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%49 = arith.subi %c64, %48 : index
%reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [%45], sizes: [%c16, %49], strides: [%23, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%50 = arith.remsi %arg17, %20 : index
%51 = arith.addi %22, %50 : index
%52 = arith.subi %51, %arg17 : index
%53 = arith.divsi %52, %20 : index
%reinterpret_cast_2 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%53, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%54 = arith.subi %c32, %53 : index
%reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%50], sizes: [%54, %c16], strides: [%20, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%55 = arith.muli %arg15, %c16_i32 : i32
%56 = arith.subi %arg5, %55 : i32
%57 = arith.index_cast %56 : i32 to index
%58 = arith.minsi %57, %c16 : index
%59 = arith.maxsi %58, %c0 : index
%alloc = memref.alloc() : memref<32x16xf32>
%60 = arith.cmpi slt, %59, %c16 : index
scf.if %60 {
linalg.fill ins(%cst : f32) outs(%alloc : memref<32x16xf32>)
}
%61 = arith.minsi %53, %c32 : index
%62 = arith.subi %c32, %61 : index
%subview_4 = memref.subview %reinterpret_cast_2[0, 0] [%61, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_5 = memref.subview %reinterpret_cast_3[0, 0] [%62, %59] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_6 = memref.subview %alloc[0, 0] [%61, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
%subview_7 = memref.subview %alloc[%61, 0] [%62, %59] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
memref.copy %subview_4, %subview_6 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
memref.copy %subview_5, %subview_7 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
%63 = bufferization.to_tensor %alloc restrict writable : memref<32x16xf32>
%alloc_8 = memref.alloc() : memref<16x64xf32>
scf.if %60 {
linalg.fill ins(%cst : f32) outs(%alloc_8 : memref<16x64xf32>)
}
%64 = arith.minsi %48, %c64 : index
%65 = arith.subi %c64, %64 : index
%subview_9 = memref.subview %reinterpret_cast_0[0, 0] [%59, %64] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_10 = memref.subview %reinterpret_cast_1[0, 0] [%59, %65] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_11 = memref.subview %alloc_8[0, 0] [%59, %64] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
%subview_12 = memref.subview %alloc_8[0, %64] [%59, %65] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
memref.copy %subview_9, %subview_11 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
memref.copy %subview_10, %subview_12 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
%66 = bufferization.to_tensor %alloc_8 restrict writable : memref<16x64xf32>
%67 = linalg.matmul ins(%63, %66 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%1 : tensor<32x64xf32>) -> tensor<32x64xf32>
%manual_empty = tensor.empty() : tensor<32x64xf32>
%68 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %67 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%manual_empty : tensor<32x64xf32>) {
^bb0(%in: f32, %in_13: f32, %out: f32):
%71 = arith.addf %in, %in_13 : f32
linalg.yield %71 : f32
} -> tensor<32x64xf32>
%69 = arith.addi %arg17, %c16 : index
%70 = arith.addi %arg18, %28 : index
scf.yield %68, %69, %70 : tensor<32x64xf32>, index, index
}
%30 = arith.index_cast %arg8 : i32 to index
%31 = arith.muli %16, %30 : index
%32 = arith.addi %31, %18 : index
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%32], sizes: [32, 64], strides: [%30, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
%33 = arith.addi %16, %c32 : index
%34 = arith.minsi %33, %19 : index
%35 = arith.maxsi %34, %16 : index
%36 = arith.subi %35, %16 : index
%37 = arith.addi %18, %c64 : index
%38 = arith.minsi %37, %24 : index
%39 = arith.maxsi %38, %18 : index
%40 = arith.subi %39, %18 : index
%41 = arith.minsi %36, %c32 : index
%42 = arith.minsi %40, %c64 : index
%extracted_slice = tensor.extract_slice %29#0[0, 0] [%41, %42] [1, 1] : tensor<32x64xf32> to tensor<?x?xf32>
%subview = memref.subview %reinterpret_cast[0, 0] [%41, %42] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
bufferization.materialize_in_destination %extracted_slice in writable %subview : (tensor<?x?xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>) -> ()
return
}
}

After bufferization:

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @matmul_kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
%c8_i32 = arith.constant 8 : i32
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c16_i32 = arith.constant 16 : i32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c31_i32 = arith.constant 31 : i32
%c63_i32 = arith.constant 63 : i32
%c15_i32 = arith.constant 15 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%cst = arith.constant 0.000000e+00 : f32
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
linalg.fill ins(%cst : f32) outs(%alloc : memref<32x64xf32>)
%0 = arith.addi %arg3, %c31_i32 : i32
%1 = arith.divsi %0, %c32_i32 : i32
%2 = arith.addi %arg4, %c63_i32 : i32
%3 = arith.divsi %2, %c64_i32 : i32
%4 = arith.muli %3, %c8_i32 : i32
%5 = arith.divsi %arg12, %4 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.subi %1, %6 : i32
%8 = arith.minsi %7, %c8_i32 : i32
%9 = arith.remsi %arg12, %8 : i32
%10 = arith.addi %6, %9 : i32
%11 = arith.remsi %arg12, %4 : i32
%12 = arith.divsi %11, %8 : i32
%13 = arith.muli %10, %c32_i32 : i32
%14 = arith.index_cast %13 : i32 to index
%15 = arith.muli %12, %c64_i32 : i32
%16 = arith.index_cast %15 : i32 to index
%17 = arith.index_cast %arg3 : i32 to index
%18 = arith.index_cast %arg6 : i32 to index
%19 = arith.muli %14, %18 : index
%20 = arith.muli %17, %18 : index
%21 = arith.index_cast %arg7 : i32 to index
%22 = arith.index_cast %arg4 : i32 to index
%23 = arith.addi %arg5, %c15_i32 : i32
%24 = arith.divsi %23, %c16_i32 : i32
%25 = arith.muli %arg7, %c16_i32 : i32
%26 = arith.index_cast %25 : i32 to index
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_0 : memref<32x64xf32> to memref<32x64xf32>
%27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_0, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index) : i32 {
%41 = arith.addi %arg18, %16 : index
%42 = arith.remsi %41, %22 : index
%43 = arith.subi %41, %42 : index
%44 = arith.addi %42, %c64 : index
%45 = arith.minsi %44, %22 : index
%46 = arith.subi %45, %42 : index
%reinterpret_cast_2 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %46], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%47 = arith.subi %c64, %46 : index
%reinterpret_cast_3 = memref.reinterpret_cast %arg1 to offset: [%43], sizes: [%c16, %47], strides: [%21, %c1] : memref<*xf32> to memref<16x?xf32, strided<[?, ?], offset: ?>>
%48 = arith.remsi %arg17, %18 : index
%49 = arith.addi %20, %48 : index
%50 = arith.subi %49, %arg17 : index
%51 = arith.divsi %50, %18 : index
%reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%51, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%52 = arith.subi %c32, %51 : index
%reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [%48], sizes: [%52, %c16], strides: [%18, %c1] : memref<*xf32> to memref<?x16xf32, strided<[?, ?], offset: ?>>
%53 = arith.muli %arg15, %c16_i32 : i32
%54 = arith.subi %arg5, %53 : i32
%55 = arith.index_cast %54 : i32 to index
%56 = arith.minsi %55, %c16 : index
%57 = arith.maxsi %56, %c0 : index
%alloc_6 = memref.alloc() : memref<32x16xf32>
%58 = arith.cmpi slt, %57, %c16 : index
scf.if %58 {
linalg.fill ins(%cst : f32) outs(%alloc_6 : memref<32x16xf32>)
}
%59 = arith.minsi %51, %c32 : index
%60 = arith.subi %c32, %59 : index
%subview_7 = memref.subview %reinterpret_cast_4[0, 0] [%59, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_8 = memref.subview %reinterpret_cast_5[0, 0] [%60, %57] [1, 1] : memref<?x16xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_9 = memref.subview %alloc_6[0, 0] [%59, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1]>>
%subview_10 = memref.subview %alloc_6[%59, 0] [%60, %57] [1, 1] : memref<32x16xf32> to memref<?x?xf32, strided<[16, 1], offset: ?>>
memref.copy %subview_7, %subview_9 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1]>>
memref.copy %subview_8, %subview_10 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[16, 1], offset: ?>>
%alloc_11 = memref.alloc() : memref<16x64xf32>
scf.if %58 {
linalg.fill ins(%cst : f32) outs(%alloc_11 : memref<16x64xf32>)
}
%61 = arith.minsi %46, %c64 : index
%62 = arith.subi %c64, %61 : index
%subview_12 = memref.subview %reinterpret_cast_2[0, 0] [%57, %61] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_13 = memref.subview %reinterpret_cast_3[0, 0] [%57, %62] [1, 1] : memref<16x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%subview_14 = memref.subview %alloc_11[0, 0] [%57, %61] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1]>>
%subview_15 = memref.subview %alloc_11[0, %61] [%57, %62] [1, 1] : memref<16x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
memref.copy %subview_12, %subview_14 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1]>>
memref.copy %subview_13, %subview_15 : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
%alloc_16 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_16 : memref<32x64xf32> to memref<32x64xf32>
linalg.matmul ins(%alloc_6, %alloc_11 : memref<32x16xf32>, memref<16x64xf32>) outs(%alloc_16 : memref<32x64xf32>)
%alloc_17 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %alloc_16 : memref<32x64xf32>, memref<32x64xf32>) outs(%alloc_17 : memref<32x64xf32>) {
^bb0(%in: f32, %in_18: f32, %out: f32):
%65 = arith.addf %in, %in_18 : f32
linalg.yield %65 : f32
}
%63 = arith.addi %arg17, %c16 : index
%64 = arith.addi %arg18, %26 : index
scf.yield %alloc_17, %63, %64 : memref<32x64xf32>, index, index
}
%28 = arith.index_cast %arg8 : i32 to index
%29 = arith.muli %14, %28 : index
%30 = arith.addi %29, %16 : index
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf32> to memref<32x64xf32, strided<[?, 1], offset: ?>>
%31 = arith.addi %14, %c32 : index
%32 = arith.minsi %31, %17 : index
%33 = arith.maxsi %32, %14 : index
%34 = arith.subi %33, %14 : index
%35 = arith.addi %16, %c64 : index
%36 = arith.minsi %35, %22 : index
%37 = arith.maxsi %36, %16 : index
%38 = arith.subi %37, %16 : index
%39 = arith.minsi %34, %c32 : index
%40 = arith.minsi %38, %c64 : index
%subview = memref.subview %27#0[0, 0] [%39, %40] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1]>>
%subview_1 = memref.subview %reinterpret_cast[0, 0] [%39, %40] [1, 1] : memref<32x64xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
memref.copy %subview, %subview_1 : memref<?x?xf32, strided<[64, 1]>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
return
}
}
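For readability, here is a condensed excerpt of the loop body taken from the listings above; only the operands of the linalg.generic change (the elided generic regions are identical addf bodies):

// Original: the matmul result %67 is reused as the generic's DPS init.
%67 = linalg.matmul ins(%63, %66 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%1 : tensor<32x64xf32>) -> tensor<32x64xf32>
%68 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%67, %arg16 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%67 : tensor<32x64xf32>) { /* addf body elided */ } -> tensor<32x64xf32>

// Proposed: a fresh tensor.empty is the DPS init, so neither input aliases the output.
%67 = linalg.matmul ins(%63, %66 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%1 : tensor<32x64xf32>) -> tensor<32x64xf32>
%manual_empty = tensor.empty() : tensor<32x64xf32>
%68 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg16, %67 : tensor<32x64xf32>, tensor<32x64xf32>) outs(%manual_empty : tensor<32x64xf32>) { /* addf body elided */ } -> tensor<32x64xf32>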
The tt.dot with an accumulator lowers to linalg.matmul plus arith.add, and the arith.add further lowers to linalg.generic; the generic takes the lhs of the add as the DPS init, so the lhs should be the matmul accumulator. This is a temporary fix for issue microsoft#196.
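A rough sketch of the pattern that fix aims for, with illustrative value names (%acc, %a_tile, %b_tile, %zero_init are placeholders, not taken from the actual lowering output):

%dot = linalg.matmul ins(%a_tile, %b_tile : tensor<32x16xf32>, tensor<16x64xf32>) outs(%zero_init : tensor<32x64xf32>) -> tensor<32x64xf32>
// %acc is the loop-carried accumulator; as the lhs of the add it becomes the DPS init.
%next_acc = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%acc, %dot : tensor<32x64xf32>, tensor<32x64xf32>) outs(%acc : tensor<32x64xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
  %s = arith.addf %in, %in_0 : f32
  linalg.yield %s : f32
} -> tensor<32x64xf32>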
#191 reported a bug where the bufferization result is incorrect when reusing tensor values in linalg's init param. Since we rely on the upstream arith to linalg pass, we do not have control over which value is being used as the init param. This is likely creating a lot of codegen bugs at the moment.
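A reduced, hypothetical illustration of that pattern (not taken from the repository), where an existing tensor value is reused as the init instead of a fresh tensor.empty():

#map1 = affine_map<(d0) -> (d0)>
func.func @reuse_init(%a: tensor<8xf32>, %b: tensor<8xf32>) -> tensor<8xf32> {
  // %a feeds both ins and outs; #191 reports that bufferizing this kind of reuse can produce wrong results.
  %0 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel"]} ins(%a, %b : tensor<8xf32>, tensor<8xf32>) outs(%a : tensor<8xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.addf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<8xf32>
  return %0 : tensor<8xf32>
}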