-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introducing Arm SME/SVE2 Optimization pass #109
base: main
Are you sure you want to change the base?
Conversation
backend/compiler.py
Outdated
@@ -32,16 +44,25 @@ def _ttir_to_ttsharedir(mod): | |||
dst_path = os.path.join(tmpdir, "ttshared.mlir") | |||
Path(src_path).write_text(ttir_code) | |||
triton_shared_opt_path = _get_triton_shared_opt_path() | |||
subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-structured", "--canonicalize", "--triton-arith-to-linalg", "--cse", "--structured-to-memref", "-o", dst_path]) | |||
subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg", "-o", dst_path]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we revert this? We're in the process of migrating away from the monolith pass, so any future work should ideally be tested using the new modular passes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could definitely be done, I my GCC compiler didn't like the modular pass format but I think I just have to re-install
Also I know that we just merged the arm-workflow runner but I might want to get rid of it since I have working changes for float16 and bfloat16. For quite sometime, triton-shared has been swapping out bf16/fp16 for f32 and I am working on optional support if the current system supports avx512_bf16(x86) or sve-bf16 (arm) or fp16 instructions. Wondering if the runner could changed at some point in the future |
getting error when trying to pass the IR through mlir-opt, I tried some of the flags used in the SME example but their for the transform interpreter, anybody got any ideas? error: failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal
%114 = vector.mask %113 { vector.transfer_read %extracted_slice_21[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[4]x[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32> I think its got something todo with the flags being passed in def _ttsharedir_to_llir(ttsharedir: str):
with tempfile.TemporaryDirectory() as tmpdir:
ttshared_path = os.path.join(tmpdir, "ttshared.mlir")
llmlir_path = os.path.join(tmpdir, "ll.mlir")
llir_path = os.path.join(tmpdir, "ll.ir")
Path(ttshared_path).write_text(ttsharedir)
mlir_opt_path = _get_llvm_bin_path("mlir-opt")
# TritonShared-MLIR to LLVM-MLIR
subprocess.check_call([
mlir_opt_path,
ttshared_path,
"--convert-linalg-to-affine-loops",
"--eliminate-empty-tensors",
"--arm-sve-legalize-vector-storage",
"--allocate-arm-sme-tiles",
"--empty-tensor-to-alloc-tensor",
"--one-shot-bufferize=allow-return-allocs-from-loops=true",
"--lower-affine",
"--convert-linalg-to-loops",
"--convert-arm-sme-to-scf",
"--convert-scf-to-cf",
"--convert-cf-to-llvm",
"--convert-arith-to-llvm",
"--convert-math-to-llvm",
"--convert-complex-to-llvm",
"--convert-vector-to-arm-sme",
"--convert-arm-sme-to-llvm",
"--convert-index-to-llvm",
"--memref-expand",
"-convert-vector-to-llvm=enable-arm-sve",
"--expand-strided-metadata",
"--finalize-memref-to-llvm",
"--convert-func-to-llvm",
# Lowering memrefs creates more affine.apply ops.
# Lowering these affine ops again creates further arith ops,
# so we have to run these two passes again here.
"--lower-affine",
"--convert-arith-to-llvm",
# Remove all unrealized casts created
"--canonicalize",
"-o",
llmlir_path,
])
# LLVM-MLIR to LLVM-IR
mlir_translate_path = _get_llvm_bin_path("mlir-translate")
subprocess.check_call([mlir_translate_path, llmlir_path,
"--mlir-to-llvmir",
"-o",
llir_path])
return Path(llir_path).read_text() |
You maybe missing lowering masked vector transfers. |
that seemed to change the IR but didn't seem to fix the issue completely this is the IR now being generated #map = affine_map<(d0, d1) -> (-d0 + 32, d1)>
#map1 = affine_map<(d0, d1) -> (-d0 + 64, d1)>
#map2 = affine_map<(d0)[s0] -> (d0 * 16 + s0)>
#map3 = affine_map<(d0)[s0] -> (d0 + s0)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map5 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map6 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map7 = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: memref<*xf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xf16> {tt.divisibility = 16 : i32}, %arg2: memref<*xf16> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
%c4 = arith.constant 4 : index
%c8_i32 = arith.constant 8 : i32
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c16_i32 = arith.constant 16 : i32
%cst = arith.constant 0.000000e+00 : f32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c31_i32 = arith.constant 31 : i32
%c63_i32 = arith.constant 63 : i32
%c15_i32 = arith.constant 15 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%cst_0 = arith.constant 0.000000e+00 : f16
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
linalg.fill ins(%cst : f32) outs(%alloc : memref<32x64xf32>)
%0 = arith.addi %arg3, %c31_i32 : i32
%1 = arith.divsi %0, %c32_i32 : i32
%2 = arith.addi %arg4, %c63_i32 : i32
%3 = arith.divsi %2, %c64_i32 : i32
%4 = arith.muli %3, %c8_i32 : i32
%5 = arith.divsi %arg12, %4 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.subi %1, %6 : i32
%8 = arith.minsi %7, %c8_i32 : i32
%9 = arith.remsi %arg12, %8 : i32
%10 = arith.addi %6, %9 : i32
%11 = arith.remsi %arg12, %4 : i32
%12 = arith.divsi %11, %8 : i32
%13 = arith.muli %10, %c32_i32 : i32
%14 = arith.index_cast %13 : i32 to index
%15 = arith.muli %12, %c64_i32 : i32
%16 = arith.index_cast %15 : i32 to index
%17 = arith.index_cast %arg3 : i32 to index
%18 = arith.index_cast %arg6 : i32 to index
%19 = arith.muli %14, %18 : index
%20 = arith.muli %17, %18 : index
%21 = arith.index_cast %arg7 : i32 to index
%22 = arith.index_cast %arg4 : i32 to index
%23 = arith.addi %arg5, %c15_i32 : i32
%24 = arith.divsi %23, %c16_i32 : i32
%25 = arith.muli %arg7, %c16_i32 : i32
%26 = arith.index_cast %25 : i32 to index
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_1 : memref<32x64xf32> to memref<32x64xf32>
%27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_1, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index) : i32 {
%39 = arith.addi %arg18, %16 : index
%40 = arith.remsi %39, %22 : index
%41 = arith.subi %39, %40 : index
%42 = arith.addi %40, %c64 : index
%43 = arith.minsi %42, %22 : index
%44 = arith.subi %43, %40 : index
%reinterpret_cast_4 = memref.reinterpret_cast %arg1 to offset: [%39], sizes: [%c16, %44], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
%45 = arith.subi %c64, %44 : index
%reinterpret_cast_5 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %45], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
%46 = arith.remsi %arg17, %18 : index
%47 = arith.addi %20, %46 : index
%48 = arith.subi %47, %arg17 : index
%49 = arith.divsi %48, %18 : index
%reinterpret_cast_6 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%49, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
%50 = arith.subi %c32, %49 : index
%reinterpret_cast_7 = memref.reinterpret_cast %arg0 to offset: [%46], sizes: [%50, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
%51 = arith.muli %arg15, %c16_i32 : i32
%52 = arith.subi %arg5, %51 : i32
%53 = arith.index_cast %52 : i32 to index
%54 = arith.minsi %53, %c16 : index
%alloc_8 = memref.alloc() : memref<32x16xf16>
%55 = arith.cmpi slt, %54, %c16 : index
scf.if %55 {
linalg.fill ins(%cst_0 : f16) outs(%alloc_8 : memref<32x16xf16>)
}
%56 = arith.minsi %49, %c32 : index
%57 = arith.subi %c32, %56 : index
%subview_9 = memref.subview %reinterpret_cast_6[0, 0] [%56, %54] [1, 1] : memref<?x16xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
%subview_10 = memref.subview %reinterpret_cast_7[0, 0] [%57, %54] [1, 1] : memref<?x16xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
%subview_11 = memref.subview %alloc_8[0, 0] [%56, %54] [1, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1]>>
%subview_12 = memref.subview %alloc_8[%56, 0] [%57, %54] [1, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1], offset: ?>>
memref.copy %subview_9, %subview_11 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1]>>
memref.copy %subview_10, %subview_12 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1], offset: ?>>
%alloc_13 = memref.alloc() : memref<16x64xf16>
%58 = arith.cmpi slt, %54, %c16 : index
scf.if %58 {
linalg.fill ins(%cst_0 : f16) outs(%alloc_13 : memref<16x64xf16>)
}
%59 = arith.minsi %44, %c64 : index
%60 = arith.subi %c64, %59 : index
%subview_14 = memref.subview %reinterpret_cast_4[0, 0] [%54, %59] [1, 1] : memref<16x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
%subview_15 = memref.subview %reinterpret_cast_5[0, 0] [%54, %60] [1, 1] : memref<16x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
%subview_16 = memref.subview %alloc_13[0, 0] [%54, %59] [1, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1]>>
%subview_17 = memref.subview %alloc_13[0, %59] [%54, %60] [1, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1], offset: ?>>
memref.copy %subview_14, %subview_16 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1]>>
memref.copy %subview_15, %subview_17 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1], offset: ?>>
%61 = vector.vscale
%62 = arith.muli %61, %c4 : index
%63 = arith.muli %61, %c4 : index
%alloc_18 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_18 : memref<32x64xf32> to memref<32x64xf32>
%64 = scf.for %arg19 = %c0 to %c32 step %62 iter_args(%arg20 = %alloc_18) -> (memref<32x64xf32>) {
%67 = scf.for %arg21 = %c0 to %c64 step %63 iter_args(%arg22 = %arg20) -> (memref<32x64xf32>) {
%68 = scf.for %arg23 = %c0 to %c16 step %c1 iter_args(%arg24 = %arg22) -> (memref<32x64xf32>) {
%69 = affine.min #map(%arg19, %62)
%70 = affine.min #map1(%arg21, %63)
%71 = affine.min #map(%arg19, %62)
%72 = affine.min #map1(%arg21, %63)
%subview_19 = memref.subview %alloc_8[%arg19, %arg23] [%69, 1] [1, 1] : memref<32x16xf16> to memref<?x1xf16, strided<[16, 1], offset: ?>>
%subview_20 = memref.subview %alloc_13[%arg23, %arg21] [1, %70] [1, 1] : memref<16x64xf16> to memref<1x?xf16, strided<[64, 1], offset: ?>>
%subview_21 = memref.subview %arg24[%arg19, %arg21] [%71, %72] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
%73 = vector.create_mask %69 : vector<[4]xi1>
%subview_22 = memref.subview %subview_19[0, 0] [%69, 1] [1, 1] : memref<?x1xf16, strided<[16, 1], offset: ?>> to memref<?xf16, #map2>
%74 = vector.transfer_read %subview_22[%c0], %cst_0, %73 {in_bounds = [true]} : memref<?xf16, #map2>, vector<[4]xf16>
%75 = vector.shape_cast %74 : vector<[4]xf16> to vector<[4]x1xf16>
%76 = vector.create_mask %70 : vector<[4]xi1>
%subview_23 = memref.subview %subview_20[0, 0] [1, %70] [1, 1] : memref<1x?xf16, strided<[64, 1], offset: ?>> to memref<?xf16, #map3>
%77 = vector.transfer_read %subview_23[%c0], %cst_0, %76 {in_bounds = [true]} : memref<?xf16, #map3>, vector<[4]xf16>
%78 = vector.shape_cast %77 : vector<[4]xf16> to vector<1x[4]xf16>
%79 = vector.create_mask %69, %70 : vector<[4]x[4]xi1>
%80 = vector.transfer_read %subview_21[%c0, %c0], %cst, %79 {in_bounds = [true, true]} : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<[4]x[4]xf32>
%81 = arith.extf %75 : vector<[4]x1xf16> to vector<[4]x1xf32>
%82 = arith.extf %78 : vector<1x[4]xf16> to vector<1x[4]xf32>
%83 = vector.create_mask %69, %70, %c1 : vector<[4]x[4]x1xi1>
%84 = vector.mask %83 { vector.contract {indexing_maps = [#map4, #map5, #map6], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %81, %82, %80 : vector<[4]x1xf32>, vector<1x[4]xf32> into vector<[4]x[4]xf32> } : vector<[4]x[4]x1xi1> -> vector<[4]x[4]xf32>
vector.transfer_write %84, %subview_21[%c0, %c0], %79 {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32, strided<[64, 1], offset: ?>>
%subview_24 = memref.subview %arg24[%arg19, %arg21] [%71, %72] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
memref.copy %subview_21, %subview_24 : memref<?x?xf32, strided<[64, 1], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
scf.yield %arg24 : memref<32x64xf32>
}
scf.yield %68 : memref<32x64xf32>
}
scf.yield %67 : memref<32x64xf32>
}
linalg.generic {indexing_maps = [#map7, #map7, #map7], iterator_types = ["parallel", "parallel"]} ins(%64, %arg16 : memref<32x64xf32>, memref<32x64xf32>) outs(%64 : memref<32x64xf32>) {
^bb0(%in: f32, %in_19: f32, %out: f32):
%67 = arith.addf %in, %in_19 : f32
linalg.yield %67 : f32
}
%65 = arith.addi %arg17, %c16 : index
%66 = arith.addi %arg18, %26 : index
scf.yield %64, %65, %66 : memref<32x64xf32>, index, index
}
%28 = arith.index_cast %arg8 : i32 to index
%29 = arith.muli %14, %28 : index
%30 = arith.addi %29, %16 : index
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf16> to memref<32x64xf16, strided<[?, 1], offset: ?>>
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf16>
linalg.generic {indexing_maps = [#map7, #map7], iterator_types = ["parallel", "parallel"]} ins(%27#0 : memref<32x64xf32>) outs(%alloc_2 : memref<32x64xf16>) {
^bb0(%in: f32, %out: f16):
%39 = arith.truncf %in : f32 to f16
linalg.yield %39 : f16
}
%31 = arith.addi %14, %c32 : index
%32 = arith.minsi %31, %17 : index
%33 = arith.subi %32, %14 : index
%34 = arith.addi %16, %c64 : index
%35 = arith.minsi %34, %22 : index
%36 = arith.subi %35, %16 : index
%37 = arith.minsi %33, %c32 : index
%38 = arith.minsi %36, %c64 : index
%subview = memref.subview %alloc_2[0, 0] [%37, %38] [1, 1] : memref<32x64xf16> to memref<?x?xf16, strided<[64, 1]>>
%subview_3 = memref.subview %reinterpret_cast[0, 0] [%37, %38] [1, 1] : memref<32x64xf16, strided<[?, 1], offset: ?>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
memref.copy %subview, %subview_3 : memref<?x?xf16, strided<[64, 1]>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
return
}
}
|
There's a masked vector.contract in the IR
Maybe lowering vector mask can help: |
I figured out the issue is that the outerproduct part seems to be having no effect on the MLIR output. I'm trying to figure out why this does nothing struct OuterProductVectorizationPass
: public PassWrapper<OuterProductVectorizationPass,
OperationPass<func::FuncOp>> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<vector::VectorDialect, func::FuncDialect>();
}
void runOnOperation() override {
func::FuncOp funcOp = getOperation();
MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context);
ConversionTarget target(*context);
// Apply patterns for lowering masked transfers
transform::ApplyLowerMaskedTransfersPatternsOp lowerMaskedTransfersPatterns;
lowerMaskedTransfersPatterns.populatePatterns(patterns);
// Apply patterns for transfer permutation
transform::ApplyTransferPermutationPatternsOp transferPermutationPatterns;
transferPermutationPatterns.populatePatterns(patterns);
// Apply patterns for reduction to contract
transform::ApplyVectorReductionToContractPatternsOp reductionToContractPatterns;
reductionToContractPatterns.populatePatterns(patterns);
// Apply patterns for lowering contraction using outer product
transform::ApplyLowerOuterProductPatternsOp lowerOuterProductPatterns;
lowerOuterProductPatterns.populatePatterns(patterns);
// Apply patterns for lowering masks
transform::ApplyLowerMasksPatternsOp lowerMasksPatterns;
lowerMasksPatterns.populatePatterns(patterns);
// Apply patterns for rank-reducing subview
transform::ApplyRankReducingSubviewPatternsOp rankReducingSubviewPatterns;
rankReducingSubviewPatterns.populatePatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
};
|
reductionToContractPatterns.populatePatterns(patterns); | ||
|
||
// Apply patterns for lowering contraction using outer product | ||
transform::ApplyLowerOuterProductPatternsOp lowerOuterProductPatterns; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Noe that these patterns lower vector.outerproduct
rather than vector.contract
to vector.outerproduct
: https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td#L205
// Apply patterns for reduction to contract | ||
transform::ApplyVectorReductionToContractPatternsOp reductionToContractPatterns; | ||
reductionToContractPatterns.populatePatterns(patterns); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are missing patterns from vector.contract
to vector.outerproduct
:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nhat-nguyen turns out that the whole kernel need to be bufferized before we can run this pass. Can we use structured-to-memref todo this? would make things a lot easier
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah you can use a combination of --triton-to-structured --canonicalize --triton-arith-to-linalg --structured-to-memref
, this is the equivalent of --triton-to-linalg
. Although this only converts the loads and stores to memref, you would need to also run the bufferization pass to convert the remaining ops to use memref too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yea i was a little worried about this, looks like all the tensor ops will have to be lowered too
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nhat-nguyen so I think in order to lower to outer product, everything will have to to be lowered to memref, I added in a few off the self bufferization passes and am able to everything expect the bufferization.to_tensor op. Any ideas how to fix this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from what I'm hearing, bufferizing might be necessary before step #2
(1) bufferize here (2) (3)
linalg.matmul -------------> vector.contract ----> vector.outerproduct -----> arm_sme.fmopa
in that case optimizations may be difficult todo optimizations. Might still write a bufferization pass to make the compiler happy but def not optimal
@banach-space if I wanted to make future optizations could I modify the lowering to also accept tensors and write some lowering logic? Been thinking it over and bufferizing too early could be bad for performance and given that SME is a matmul engine I'd imagine being able to accept tensors might be useful
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from what I'm hearing, bufferizing might be necessary before step 2
I don't think so, that's still quite high level and there wouldn't be any SME ops yet :)
Been thinking it over and bufferizing too early could be bad for performance and given that SME is a matmul engine I'd imagine being able to accept tensors might be useful
In my view Tensor is too high level/abstract type for the ArmSME dialect. In particular, the hardware won't know what a tensor is. I suggest the following:
(1) (2) (3) bufferize here
linalg.matmul -----> vector.contract ----> vector.outerproduct -----------------> arm_sme.fmopa
This would be consistent with what we currently do in MLIR:
As you can see, bufferization happens after lowering vector.contract
to vector.outerproduct
.
Btw, I deliberately made the distinction into vector.contract
and vector.outerproduct
as that's representative of how the Vector Dialect splits Ops into "high level" (e.g. vector.contract
) and "low level" (e.g. vector.outerproduct
) vector ops. Here's an overview of the vectoriser n the context of SME:
Thanks again for working on this - hope this helps :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The video makes sense but a just want to clarify how the passes would be handled
This is what should happen
PassPipelineRegistration<> smeConversionPipeline(
"sme-conversion",
"Converts linalg.matmul to a more optimized form using SME",
[](OpPassManager &pm) {
pm.addPass(createMatmulTileConversionPass(true)); //tiling and vectorizing linalg.matmul
pm.addPass(createOuterProductVectorizationPass()); // lowering vector.contract to vector.outerproduct
pm.addPass(createLinalgBufferizePass()); //bufferization happens here
});
let me know if this is accurate
and ofc happy to work on this :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes sense to me, but I haven't checked the names of the actual passes. Also, you may need to configure some of these passes to do what you want.
I suggest implementing this step-by-step. At every step you can compare against the output that the Transform Dialect sequence generates - if there's divergence then that might mean trouble :) (but not necessarily)
Note that in your list of passes there's nothing SVE/SME specific, so you won't be lowering to SME just yet. That's fine - baby steps ;-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just updated, it looks mostly bufferized
the sme-opt seems to work now, producing a valid output for sme.outerproduct however I encounter an odd issue when lowering to llvm, perhaps "--convert-cf-to-llvm" fails? If you look at the snippet, you can see that cf.br is still present. @banach-space could it some complication with ll.mlir:75:5: error: Dialect `cf' not found for custom op 'cf.br'
cf.br ^bb1(%37 : index)
^
.../ll.mlir:75:5: note: Registered dialects: acc, amx, arm_neon, arm_sme, arm_sve, builtin, dlti, func, gpu, llvm, nvvm, omp, rocdl, spirv, x86vector ; for more info on dialect registration see https://mlir.llvm.org/getting_started/Faq... I ran with and without and as you can see in that cf.br is still present when it should have been lowered current output from sme-opt: #map = affine_map<()[s0] -> (s0 * 16)>
#map1 = affine_map<(d0, d1) -> (-d0 + 32, d1)>
#map2 = affine_map<(d0, d1) -> (-d0 + 64, d1)>
#map3 = affine_map<()[s0, s1] -> (s0 * 16 + s1)>
#map4 = affine_map<()[s0, s1] -> (s0 * 64 + s1)>
#map5 = affine_map<(d0)[s0] -> (d0 * 16 + s0)>
#map6 = affine_map<(d0)[s0] -> (d0 + s0)>
module {
func.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: memref<*xf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xf16> {tt.divisibility = 16 : i32}, %arg2: memref<*xf16> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) attributes {arm_sme.tiles_in_use = 34952 : i32} {
%cst = arith.constant dense<0.000000e+00> : vector<[4]xf16>
%c4 = arith.constant 4 : index
%c8_i32 = arith.constant 8 : i32
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c16_i32 = arith.constant 16 : i32
%cst_0 = arith.constant 0.000000e+00 : f32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c31_i32 = arith.constant 31 : i32
%c63_i32 = arith.constant 63 : i32
%c15_i32 = arith.constant 15 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%cst_1 = arith.constant 0.000000e+00 : f16
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
scf.for %arg15 = %c0 to %c32 step %c1 {
scf.for %arg16 = %c0 to %c64 step %c1 {
memref.store %cst_0, %alloc[%arg15, %arg16] : memref<32x64xf32>
}
}
%0 = arith.addi %arg3, %c31_i32 : i32
%1 = arith.divsi %0, %c32_i32 : i32
%2 = arith.addi %arg4, %c63_i32 : i32
%3 = arith.divsi %2, %c64_i32 : i32
%4 = arith.muli %3, %c8_i32 : i32
%5 = arith.divsi %arg12, %4 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.subi %1, %6 : i32
%8 = arith.minsi %7, %c8_i32 : i32
%9 = arith.remsi %arg12, %8 : i32
%10 = arith.addi %6, %9 : i32
%11 = arith.remsi %arg12, %4 : i32
%12 = arith.divsi %11, %8 : i32
%13 = arith.muli %10, %c32_i32 : i32
%14 = arith.index_cast %13 : i32 to index
%15 = arith.muli %12, %c64_i32 : i32
%16 = arith.index_cast %15 : i32 to index
%17 = arith.index_cast %arg3 : i32 to index
%18 = arith.index_cast %arg6 : i32 to index
%19 = arith.muli %14, %18 : index
%20 = arith.muli %17, %18 : index
%21 = arith.index_cast %arg7 : i32 to index
%22 = arith.index_cast %arg4 : i32 to index
%23 = arith.addi %arg5, %c15_i32 : i32
%24 = arith.divsi %23, %c16_i32 : i32
%25 = arith.muli %arg7, %c16_i32 : i32
%26 = arith.index_cast %25 : i32 to index
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_2 : memref<32x64xf32> to memref<32x64xf32>
%27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_2, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index) : i32 {
%39 = arith.addi %arg18, %16 : index
%40 = arith.remsi %39, %22 : index
%41 = arith.subi %39, %40 : index
%42 = arith.addi %40, %c64 : index
%43 = arith.minsi %42, %22 : index
%44 = arith.subi %43, %40 : index
%reinterpret_cast_6 = memref.reinterpret_cast %arg1 to offset: [%39], sizes: [%c16, %44], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
%45 = arith.subi %c64, %44 : index
%reinterpret_cast_7 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %45], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
%46 = arith.remsi %arg17, %18 : index
%47 = arith.addi %20, %46 : index
%48 = arith.subi %47, %arg17 : index
%49 = arith.divsi %48, %18 : index
%reinterpret_cast_8 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%49, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
%50 = arith.subi %c32, %49 : index
%reinterpret_cast_9 = memref.reinterpret_cast %arg0 to offset: [%46], sizes: [%50, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
%51 = arith.muli %arg15, %c16_i32 : i32
%52 = arith.subi %arg5, %51 : i32
%53 = arith.index_cast %52 : i32 to index
%54 = arith.minsi %53, %c16 : index
%alloc_10 = memref.alloc() : memref<32x16xf16>
%55 = arith.cmpi slt, %54, %c16 : index
scf.if %55 {
scf.for %arg19 = %c0 to %c32 step %c1 {
scf.for %arg20 = %c0 to %c16 step %c1 {
memref.store %cst_1, %alloc_10[%arg19, %arg20] : memref<32x16xf16>
}
}
}
%56 = arith.minsi %49, %c32 : index
%57 = arith.subi %c32, %56 : index
%base_buffer_11, %offset_12, %sizes_13:2, %strides_14:2 = memref.extract_strided_metadata %reinterpret_cast_8 : memref<?x16xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
%reinterpret_cast_15 = memref.reinterpret_cast %base_buffer_11 to offset: [%offset_12], sizes: [%56, %54], strides: [%strides_14#0, %strides_14#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
%base_buffer_16, %offset_17, %sizes_18:2, %strides_19:2 = memref.extract_strided_metadata %reinterpret_cast_9 : memref<?x16xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
%reinterpret_cast_20 = memref.reinterpret_cast %base_buffer_16 to offset: [%offset_17], sizes: [%57, %54], strides: [%strides_19#0, %strides_19#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
%reinterpret_cast_21 = memref.reinterpret_cast %alloc_10 to offset: [0], sizes: [%56, %54], strides: [16, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1]>>
%58 = affine.apply #map()[%56]
%reinterpret_cast_22 = memref.reinterpret_cast %alloc_10 to offset: [%58], sizes: [%57, %54], strides: [16, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1], offset: ?>>
memref.copy %reinterpret_cast_15, %reinterpret_cast_21 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1]>>
memref.copy %reinterpret_cast_20, %reinterpret_cast_22 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1], offset: ?>>
%alloc_23 = memref.alloc() : memref<16x64xf16>
%59 = arith.cmpi slt, %54, %c16 : index
scf.if %59 {
scf.for %arg19 = %c0 to %c16 step %c1 {
scf.for %arg20 = %c0 to %c64 step %c1 {
memref.store %cst_1, %alloc_23[%arg19, %arg20] : memref<16x64xf16>
}
}
}
%60 = arith.minsi %44, %c64 : index
%61 = arith.subi %c64, %60 : index
%base_buffer_24, %offset_25, %sizes_26:2, %strides_27:2 = memref.extract_strided_metadata %reinterpret_cast_6 : memref<16x?xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
%reinterpret_cast_28 = memref.reinterpret_cast %base_buffer_24 to offset: [%offset_25], sizes: [%54, %60], strides: [%strides_27#0, %strides_27#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
%base_buffer_29, %offset_30, %sizes_31:2, %strides_32:2 = memref.extract_strided_metadata %reinterpret_cast_7 : memref<16x?xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
%reinterpret_cast_33 = memref.reinterpret_cast %base_buffer_29 to offset: [%offset_30], sizes: [%54, %61], strides: [%strides_32#0, %strides_32#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
%reinterpret_cast_34 = memref.reinterpret_cast %alloc_23 to offset: [0], sizes: [%54, %60], strides: [64, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1]>>
%reinterpret_cast_35 = memref.reinterpret_cast %alloc_23 to offset: [%60], sizes: [%54, %61], strides: [64, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1], offset: ?>>
memref.copy %reinterpret_cast_28, %reinterpret_cast_34 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1]>>
memref.copy %reinterpret_cast_33, %reinterpret_cast_35 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1], offset: ?>>
%62 = vector.vscale
%63 = arith.muli %62, %c4 : index
%64 = arith.muli %62, %c4 : index
%alloc_36 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_36 : memref<32x64xf32> to memref<32x64xf32>
scf.for %arg19 = %c0 to %c32 step %63 {
scf.for %arg20 = %c0 to %c64 step %64 {
scf.for %arg21 = %c0 to %c16 step %c1 {
%67 = affine.min #map1(%arg19, %63)
%68 = affine.min #map2(%arg20, %64)
%69 = affine.min #map1(%arg19, %63)
%70 = affine.min #map2(%arg20, %64)
%71 = affine.apply #map3()[%arg19, %arg21]
%72 = affine.apply #map4()[%arg21, %arg20]
%73 = affine.apply #map4()[%arg19, %arg20]
%reinterpret_cast_37 = memref.reinterpret_cast %alloc_36 to offset: [%73], sizes: [%69, %70], strides: [64, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
%74 = vector.create_mask %67 : vector<[4]xi1>
%reinterpret_cast_38 = memref.reinterpret_cast %alloc_10 to offset: [%71], sizes: [%67], strides: [16] : memref<32x16xf16> to memref<?xf16, #map5>
%75 = vector.vscale
%76 = arith.muli %75, %c4 : index
%77 = scf.for %arg22 = %c0 to %76 step %c1 iter_args(%arg23 = %cst) -> (vector<[4]xf16>) {
%108 = vector.extractelement %74[%arg22 : index] : vector<[4]xi1>
%109 = scf.if %108 -> (vector<[4]xf16>) {
%110 = memref.load %reinterpret_cast_38[%arg22] : memref<?xf16, #map5>
%111 = vector.insertelement %110, %arg23[%arg22 : index] : vector<[4]xf16>
scf.yield %111 : vector<[4]xf16>
} else {
scf.yield %arg23 : vector<[4]xf16>
}
scf.yield %109 : vector<[4]xf16>
}
%78 = vector.shape_cast %77 : vector<[4]xf16> to vector<[4]x1xf16>
%79 = vector.create_mask %68 : vector<[4]xi1>
%reinterpret_cast_39 = memref.reinterpret_cast %alloc_23 to offset: [%72], sizes: [%68], strides: [1] : memref<16x64xf16> to memref<?xf16, #map6>
%80 = vector.transfer_read %reinterpret_cast_39[%c0], %cst_1, %79 {in_bounds = [true]} : memref<?xf16, #map6>, vector<[4]xf16>
%81 = vector.shape_cast %80 : vector<[4]xf16> to vector<1x[4]xf16>
%82 = vector.create_mask %67, %68 : vector<[4]x[4]xi1>
%83 = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
%c4_40 = arith.constant 4 : index
%84 = vector.vscale
%85 = arith.muli %c4_40, %84 : index
%86 = arith.index_cast %67 : index to i64
%87 = arith.index_cast %85 : index to i64
%88 = arith.minsi %86, %87 : i64
%89 = arith.index_cast %88 : i64 to index
%90 = vector.create_mask %68 : vector<[4]xi1>
%c0_41 = arith.constant 0 : index
%c1_42 = arith.constant 1 : index
%91 = scf.for %arg22 = %c0_41 to %89 step %c1_42 iter_args(%arg23 = %83) -> (vector<[4]x[4]xf32>) {
%108 = arith.addi %c0, %arg22 : index
%109 = arm_sme.load_tile_slice %reinterpret_cast_37[%108, %c0], %90, %arg23, %arg22 {tile_id = 0 : i32} : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<[4]xi1>, vector<[4]x[4]xf32>
scf.yield %109 : vector<[4]x[4]xf32>
}
%92 = arith.extf %78 : vector<[4]x1xf16> to vector<[4]x1xf32>
%93 = arith.extf %81 : vector<1x[4]xf16> to vector<1x[4]xf32>
%94 = vector.transpose %92, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%95 = vector.extract %94[0] : vector<[4]xf32> from vector<1x[4]xf32>
%96 = vector.extract %93[0] : vector<[4]xf32> from vector<1x[4]xf32>
%97 = vector.create_mask %67 : vector<[4]xi1>
%98 = vector.create_mask %68 : vector<[4]xi1>
%99 = arm_sme.outerproduct %95, %96 acc(%91) masks(%97, %98) {tile_id = 0 : i32} : vector<[4]xf32>, vector<[4]xf32>
%c4_43 = arith.constant 4 : index
%100 = vector.vscale
%101 = arith.muli %c4_43, %100 : index
%102 = arith.index_cast %67 : index to i64
%103 = arith.index_cast %101 : index to i64
%104 = arith.minsi %102, %103 : i64
%105 = arith.index_cast %104 : i64 to index
%106 = vector.create_mask %68 : vector<[4]xi1>
%c0_44 = arith.constant 0 : index
%c1_45 = arith.constant 1 : index
scf.for %arg22 = %c0_44 to %105 step %c1_45 {
%108 = arith.addi %c0, %arg22 : index
arm_sme.store_tile_slice %99, %arg22, %106, %reinterpret_cast_37[%108, %c0] {tile_id = 0 : i32} : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<[4]xi1>, vector<[4]x[4]xf32>
}
%107 = affine.apply #map4()[%arg19, %arg20]
%reinterpret_cast_46 = memref.reinterpret_cast %alloc_36 to offset: [%107], sizes: [%69, %70], strides: [64, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
memref.copy %reinterpret_cast_37, %reinterpret_cast_46 : memref<?x?xf32, strided<[64, 1], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
}
}
}
scf.for %arg19 = %c0 to %c32 step %c1 {
scf.for %arg20 = %c0 to %c64 step %c1 {
%67 = memref.load %alloc_36[%arg19, %arg20] : memref<32x64xf32>
%68 = memref.load %arg16[%arg19, %arg20] : memref<32x64xf32>
%69 = arith.addf %67, %68 : f32
memref.store %69, %alloc_36[%arg19, %arg20] : memref<32x64xf32>
}
}
%65 = arith.addi %arg17, %c16 : index
%66 = arith.addi %arg18, %26 : index
scf.yield %alloc_36, %65, %66 : memref<32x64xf32>, index, index
}
%28 = arith.index_cast %arg8 : i32 to index
%29 = arith.muli %14, %28 : index
%30 = arith.addi %29, %16 : index
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf16> to memref<32x64xf16, strided<[?, 1], offset: ?>>
%alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf16>
scf.for %arg15 = %c0 to %c32 step %c1 {
scf.for %arg16 = %c0 to %c64 step %c1 {
%39 = memref.load %27#0[%arg15, %arg16] : memref<32x64xf32>
%40 = arith.truncf %39 : f32 to f16
memref.store %40, %alloc_3[%arg15, %arg16] : memref<32x64xf16>
}
}
%31 = arith.addi %14, %c32 : index
%32 = arith.minsi %31, %17 : index
%33 = arith.subi %32, %14 : index
%34 = arith.addi %16, %c64 : index
%35 = arith.minsi %34, %22 : index
%36 = arith.subi %35, %16 : index
%37 = arith.minsi %33, %c32 : index
%38 = arith.minsi %36, %c64 : index
%reinterpret_cast_4 = memref.reinterpret_cast %alloc_3 to offset: [0], sizes: [%37, %38], strides: [64, 1] : memref<32x64xf16> to memref<?x?xf16, strided<[64, 1]>>
%base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %reinterpret_cast : memref<32x64xf16, strided<[?, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
%reinterpret_cast_5 = memref.reinterpret_cast %base_buffer to offset: [%offset], sizes: [%37, %38], strides: [%strides#0, 1] : memref<f16> to memref<?x?xf16, strided<[?, 1], offset: ?>>
memref.copy %reinterpret_cast_4, %reinterpret_cast_5 : memref<?x?xf16, strided<[64, 1]>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
return
}
} function that compiles the kernel above: def _ttsharedir_to_llir(ttsharedir: str):
with tempfile.TemporaryDirectory() as tmpdir:
ttshared_path = os.path.join(tmpdir, "ttshared.mlir")
llmlir_path = os.path.join(tmpdir, "ll.mlir")
llir_path = os.path.join(tmpdir, "ll.ir")
Path(ttshared_path).write_text(ttsharedir)
mlir_opt_path = _get_llvm_bin_path("mlir-opt")
# TritonShared-MLIR to LLVM-MLIR
subprocess.check_call([
mlir_opt_path,
ttshared_path,
"--one-shot-bufferize=allow-return-allocs-from-loops=true",
"--convert-arm-sme-to-llvm",
"--convert-vector-to-llvm=enable-arm-sve",
"--convert-arith-to-llvm",
"--convert-math-to-llvm",
"--convert-complex-to-llvm",
"--convert-func-to-llvm",
"--convert-index-to-llvm",
"--finalize-memref-to-llvm",
"--convert-scf-to-cf",
"--convert-cf-to-llvm",
"-o", llmlir_path
])
# LLVM-MLIR to LLVM-IR
mlir_translate_path = _get_llvm_bin_path("mlir-translate")
subprocess.check_call([mlir_translate_path, llmlir_path, "--mlir-to-llvmir", "-o", llir_path])
return Path(llir_path).read_text() output kernel before mlir-translate: |
You appear to be doing a f16 -> f32 matmul, that (for this example) requires slightly different tiling and passes. I think you need tile sizes of |
yea I might have forgotten that we are going from f16 ->f32 understood I guess the tiling logic has to be a bit different since this kernel uses a f32 acculumator. edit: was kind of confused about the outer-product-fusion thing, turns out these are pretty new and not in llvm commit 4017f04e that current triton branch uses @nhat-nguyen can this be bumped to triton hash ea9777d? |
got this working now too, have some concerns about future when I have to change dims from 1 -> 2 for widening but that can be worried about later. fmopa2_way is being produced right now have a minor issue with this: ll.mlir:8:10: error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast
%3 = builtin.unrealized_conversion_cast %2 : !llvm.struct<(i64, ptr)> to memref<*xf16>
^ want to know if you have any idea how to remedy this and is the caused by something SME related or is it just an triton-shared/mlir thing? here my first mlir conversion before lowerings to llvm, think the order is right subprocess.check_call([mlir_opt_path, sme_first_pass,
"--canonicalize",
"--eliminate-empty-tensors",
"--convert-linalg-to-loops",
"--empty-tensor-to-alloc-tensor",
"--expand-strided-metadata",
"--arm-sme-vector-legalization",
"--convert-vector-to-arm-sme",
"--arm-sme-outer-product-fusion",
"--arm-sve-legalize-vector-storage",
"--convert-arith-to-arm-sme",
"--allocate-arm-sme-tiles",
"--convert-arm-sme-to-scf",
"--convert-vector-to-scf",
"-o",
mlir_sme_pass]) |
@danikhan632 |
I figured that much, think its got something todo with the way inputs are passed llvm.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: i64, %arg1: !llvm.ptr, %arg2: i64, %arg3: !llvm.ptr, %arg4: i64, %arg5: !llvm.ptr, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32, %arg17: i32) attributes {arm_sme.tiles_in_use = 34952 : i32} {
%0 = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
%1 = llvm.insertvalue %arg4, %0[0] : !llvm.struct<(i64, ptr)>
%2 = llvm.insertvalue %arg5, %1[1] : !llvm.struct<(i64, ptr)>
%3 = builtin.unrealized_conversion_cast %2 : !llvm.struct<(i64, ptr)> to memref<*xf16> like I think the kernel is expecting an i64 value and a pointer to the inputs, but it gets a memref |
I see unrealized_conversion_cast errors when lowering to the llvm dialect. In my case, it's caused by that the user of this cast (%3 above) is not lowered to llvm dialect. I would check the dialect/op of the user and try find the pass(es) to lower it. |
...
%62 = arith.muli %vscale, %c4 : index
%63 = arith.muli %vscale, %c4 : index
%alloc_37 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_37 : memref<32x64xf32> to memref<32x64xf32>
scf.for %arg19 = %c0 to %c32 step %62 {
scf.for %arg20 = %c0 to %c64 step %63 {
scf.for %arg21 = %c0 to %c16 step %c2 {
%alloca = memref.alloca() : memref<vector<2x[4]xf16>>
%alloca_38 = memref.alloca() : memref<vector<2x[4]xi1>>
...
AFTER MLIR PASSES
%33 = "arith.constant"() <{value = 2 : index}> : () -> index
%34 = "builtin.unrealized_conversion_cast"(%33) : (index) -> i64
%35 = "arith.constant"() <{value = dense<0.000000e+00> : vector<[4]xf16>}> : () -> vector<[4]xf16>
%36 = "arith.constant"() <{value = -1 : index}> : () -> index
%37 = "arith.constant"() <{value = dense<false> : vector<2x[4]xi1>}> : () -> vector<2x[4]xi1>
%38 = "builtin.unrealized_conversion_cast"(%37) : (vector<2x[4]xi1>) -> !llvm.array<2 x vector<[4]xi1>>
%39 = "builtin.unrealized_conversion_cast"(%21) : (index) -> i64
%40 = "builtin.unrealized_conversion_cast"(%21) : (index) -> i64
%41 = "llvm.mlir.constant"() <{value = 32 : index}> : () -> i64
...
AND MORE MLIR PASSES
%33 = "arith.constant"() <{value = 64 : i32}> : () -> i32
%34 = "arith.constant"() <{value = 32 : i32}> : () -> i32
%35 = "arith.constant"() <{value = 8 : i32}> : () -> i32
%36 = "arith.constant"() <{value = 4 : index}> : () -> index
%37 = "arith.constant"() <{value = 2 : index}> : () -> index
%38 = "builtin.unrealized_conversion_cast"(%37) : (index) -> i64
%39 = "arith.constant"() <{value = dense<0.000000e+00> : vector<[4]xf16>}> : () -> vector<[4]xf16>
%40 = "arith.constant"() <{value = -1 : index}> : () -> index
%41 = "arith.constant"() <{value = dense<false> : vector<2x[4]xi1>}> : () -> vector<2x[4]xi1>
%42 = "builtin.unrealized_conversion_cast"(%41) : (vector<2x[4]xi1>) -> !llvm.array<2 x vector<[4]xi1>>
%43 = "builtin.unrealized_conversion_cast"(%24) : (index) -> i64
<unknown>:0: error: failed to legalize operation 'builtin.unrealized_conversion_cast' that was explicitly marked illegal
<unknown>:0: note: see current operation: %38 = "builtin.unrealized_conversion_cast"(%37) : (i64) -> index |
I think it's more helpful to look at the users of an It looks to me like the |
yeah I don't think '--convert-arith-to-arm-sme' is really doing anything here, I wanted to vet btw that the kernel that I generated is legitimate and that the only thing that I should have to do is run it through mlir-opt and then through mlir-translate and it should be fine. I also figured that |
I think the allocas for the scalable vectors come from using default lowering of |
LLVM ERROR: Cannot select: t155: i64 = vscale Constant:i64<1024>
t154: i64 = Constant<1024>
In function: matmul_kernel_0d1d2d34567c89c1011c
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0. Program arguments: /home/green/.triton/llvm/llvm-6f44bb77-ubuntu-x64/bin/llc /tmp/tmp7lnmobsz/kernel.ll -o /tmp/tmp7lnmobsz/kernel.o
1. Running pass 'Function Pass Manager' on module '/tmp/tmp7lnmobsz/kernel.ll'.
2. Running pass 'X86 DAG->DAG Instruction Selection' on function '@matmul_kernel_0d1d2d34567c89c1011c' |
What flags are you using? To compile with |
ah I see, I think this is where the sme userspace emulator is needed, I think this is correct, going to switch over to my arm system to test it def _llir_to_bin(llir: str, metadata):
pattern = r"define void @(\w+)\(.+"
matches = re.findall(pattern, llir)
assert len(matches) == 1
metadata["name"] = matches[0]
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "kernel.ll")
dst_path = os.path.join(tmpdir, "kernel.o")
Path(src_path).write_text(llir)
llc_path = _get_llvm_bin_path("llc")
subprocess.check_call(["/usr/bin/qemu-aarch64-static",llc_path, src_path, "--mattr=+sve,+sme","-o", dst_path])
# Actually it's text-format assembly. Use read_text().
return Path(dst_path).read_text() |
I did that and get this error, subprocess.check_call([ "/usr/bin/qemu-aarch64-static",llc_path, src_path, "--mattr=+sve,+sme","-o", dst_path]) output: /tmp/tmp1j2jbysw/kernel.s: Assembler messages:
/tmp/tmp1j2jbysw/kernel.s:126: Error: selected processor does not support `rdvl x8,#1'
/tmp/tmp1j2jbysw/kernel.s:130: Error: selected processor does not support `cntw x24'
/tmp/tmp1j2jbysw/kernel.s:409: Error: selected processor does not support `ptrue p2.s'
/tmp/tmp1j2jbysw/kernel.s:410: Error: selected processor does not support `index z6.s,#0,#1'
/tmp/tmp1j2jbysw/kernel.s:417: Error: selected processor does not support `incw x22'
/tmp/tmp1j2jbysw/kernel.s:420: Error: selected processor does not support `addvl x8,x8,#8'
/tmp/tmp1j2jbysw/kernel.s:438: Error: selected processor does not support `incw x25'
/tmp/tmp1j2jbysw/kernel.s:439: Error: selected processor does not support `addvl x20,x20,#1'
/tmp/tmp1j2jbysw/kernel.s:486: Error: selected processor does not support `index z6.s,#0,#1'
/tmp/tmp1j2jbysw/kernel.s:490: Error: selected processor does not support `ptrue p2.s'
/tmp/tmp1j2jbysw/kernel.s:513: Error: selected processor does not support `mov z0.s,w9'
/tmp/tmp1j2jbysw/kernel.s:514: Error: selected processor does not support `cmpgt p0.s,p2/z,z0.s,z6.s'
/tmp/tmp1j2jbysw/kernel.s:516: Error: selected processor does not support `ld1h {z0.s},p0/z,[x13,x10,lsl#1]'
/tmp/tmp1j2jbysw/kernel.s:519: Error: selected processor does not support `ld1h {z1.s},p0/z,[x11,x10,lsl#1]'
/tmp/tmp1j2jbysw/kernel.s:522: Error: unknown mnemonic `zero' -- `zero {za0.s}'
/tmp/tmp1j2jbysw/kernel.s:530: Error: operand 1 must be a list of SVE vector registers -- `ld1w {za0h.s[w12,0]},p0/z,[x13]'
/tmp/tmp1j2jbysw/kernel.s:536: Error: selected processor does not support `mov z2.s,w8'
/tmp/tmp1j2jbysw/kernel.s:539: Error: selected processor does not support `cmpgt p0.s,p2/z,z2.s,z6.s'
/tmp/tmp1j2jbysw/kernel.s:540: Error: selected processor does not support `mov z2.h,#0'
/tmp/tmp1j2jbysw/kernel.s:541: Error: selected processor does not support `mov z3.s,p0/z,#1'
/tmp/tmp1j2jbysw/kernel.s:554: Error: selected processor does not support `whilels p0.s,xzr,x11'
/tmp/tmp1j2jbysw/kernel.s:555: Error: selected processor does not support `lastb w13,p0,z3.s'
/tmp/tmp1j2jbysw/kernel.s:558: Error: selected processor does not support `mov z4.s,w11'
/tmp/tmp1j2jbysw/kernel.s:559: Error: selected processor does not support `cmpeq p0.s,p2/z,z6.s,z4.s'
/tmp/tmp1j2jbysw/kernel.s:561: Error: selected processor does not support `mov z2.h,p0/m,h4'
/tmp/tmp1j2jbysw/kernel.s:564: Error: selected processor does not support `mov z4.h,#0'
/tmp/tmp1j2jbysw/kernel.s:579: Error: selected processor does not support `whilels p0.s,xzr,x11'
/tmp/tmp1j2jbysw/kernel.s:580: Error: selected processor does not support `lastb w13,p0,z3.s'
/tmp/tmp1j2jbysw/kernel.s:583: Error: selected processor does not support `mov z5.s,w11'
/tmp/tmp1j2jbysw/kernel.s:584: Error: selected processor does not support `cmpeq p0.s,p2/z,z6.s,z5.s'
/tmp/tmp1j2jbysw/kernel.s:586: Error: selected processor does not support `mov z4.h,p0/m,h5'
/tmp/tmp1j2jbysw/kernel.s:589: Error: selected processor does not support `mov z3.s,w8'
/tmp/tmp1j2jbysw/kernel.s:590: Error: selected processor does not support `mov z5.s,w9'
/tmp/tmp1j2jbysw/kernel.s:593: Error: selected processor does not support `cmpgt p1.s,p2/z,z3.s,z6.s'
/tmp/tmp1j2jbysw/kernel.s:594: Error: selected processor does not support `cmpgt p0.s,p2/z,z5.s,z6.s'
/tmp/tmp1j2jbysw/kernel.s:595: Error: selected processor does not support `zip2 z3.s,z2.s,z4.s'
/tmp/tmp1j2jbysw/kernel.s:596: Error: selected processor does not support `zip1 z2.s,z2.s,z4.s'
/tmp/tmp1j2jbysw/kernel.s:597: Error: selected processor does not support `zip2 z4.s,z0.s,z1.s'
/tmp/tmp1j2jbysw/kernel.s:598: Error: selected processor does not support `zip1 z0.s,z0.s,z1.s'
/tmp/tmp1j2jbysw/kernel.s:600: Error: selected processor does not support `zip2 p2.s,p1.s,p1.s'
/tmp/tmp1j2jbysw/kernel.s:602: Error: selected processor does not support `zip1 p1.s,p1.s,p1.s'
/tmp/tmp1j2jbysw/kernel.s:603: Error: selected processor does not support `zip2 p3.s,p0.s,p0.s'
/tmp/tmp1j2jbysw/kernel.s:604: Error: selected processor does not support `uzp1 z1.h,z2.h,z3.h'
/tmp/tmp1j2jbysw/kernel.s:605: Error: selected processor does not support `uzp1 z0.h,z0.h,z4.h'
/tmp/tmp1j2jbysw/kernel.s:606: Error: selected processor does not support `zip1 p4.s,p0.s,p0.s'
/tmp/tmp1j2jbysw/kernel.s:607: Error: selected processor does not support `uzp1 p1.h,p1.h,p2.h'
/tmp/tmp1j2jbysw/kernel.s:608: Error: selected processor does not support `uzp1 p2.h,p4.h,p3.h'
/tmp/tmp1j2jbysw/kernel.s:609: Error: unknown mnemonic `fmopa' -- `fmopa za0.s,p1/m,p2/m,z1.h,z0.h'
/tmp/tmp1j2jbysw/kernel.s:617: Error: operand 1 must be a list of SVE vector registers -- `st1w {za0h.s[w12,0]},p0,[x13]' I took a look at the // %mcr_aarch64_cmd \
//-e=main -entry-point-result=void \
//-march=aarch64 -mattr="+sve,+sme" \
//-shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
// RUN: FileCheck %s
|
7429fb8
to
83932b6
Compare
Generic AdviceIt's best to run the tests as part of the build process of MLIR (or afterwards) and then to copy the build commands from tests. CMake flags to run the SME integration tests are documented here: -DMLIR_INCLUDE_INTEGRATION_TESTS=On
-DMLIR_RUN_ARM_SME_TESTS=On
-DARM_EMULATOR_EXECUTABLE=<path-to-emulator> Then, during/after the build, you can either run all the tests: ninja check-mlir or just selected integration tests: cd <llvm-build-dir>
# Please adjust paths to match your system
bin/llvm-lit -va ../../mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir Note that I am using I would make sure that these tests work for you before trying to run things manually. Specific adviceAs you have noted, the tests will contain sth like this:
This is important - it means that ^^^ defines flags to be passed to subprocess.check_call([ "/usr/bin/qemu-aarch64-static",llc_path, src_path, "--mattr=+sve,+sme","-o", dst_path]) That's incorrect and won't work :) Now, it also looks like you are passing Also, we don't really use As for Suggestion
HTH :) (*) Unless you've cross-compiled it, but I highly doubt it. |
got it, I'm trying to pass llir through llc to compile to binary. I think mlir-cpu-runner will be good for IR tests but looking to run this E2E. Not sure if the MLIR CPU runner can do this -DARM_EMULATOR_EXECUTABLE=<path-to-emulator>
also is this the instruction emulator? https://developer.arm.com/Tools%20and%20Software/Arm%20Instruction%20Emulator |
ArmIE is one emulator, but based on the website it only support SVE and SVE2 (so no SME):
QEMU does support SME: https://qemu-project.gitlab.io/qemu/system/arm/cpu-features.html Btw, I forgot to answer your other question:
These are MLIR runtime libs - you will find them in the LLVM build directory under the
:) |
Hi, @danikhan632 I'm trying this PR out and I'm seeing a correctness issue where the result of the matmul doesn't match the expected output. I'm looking into it right now, but I was wondering if you were able to run the compiled test_matmul with QEMU. |
Yeah I've been able to get it to work with some caveats. |
Could you share those changes? Right now, I modified the generated launcher.cpp to build as an executable and then linking with the generated kernel.o, passing two tensors to |
Sure , give me a sec to push those changes, wanted to know what your output looks like btw, like is it off by a small amount or larger? |
Hmm, a lot of the elements, that differ from the expected output, are 0, but some aren't. |
I don't the widening works as of now so fp16 -> fp32 doesn't work. Change accumulatior to fp16 |
This is the change I'm trying and I'm not seeing any changes in the output. |
Ok let me try and fix that, broke my env so taking me longer than it should |
I've had issues recreating this behavior, could you reach out to [email protected] with more details? |
Since this is an optimization pass to existing ttshare output, I decided to make it its own binary, currently gets past optimization phase and fails on _ttsharedir_to_llir which is to be expected since it needs different mlir-opt flags. These flags have also just been recently updated.
Also trying to introduce bf16/f16 support as well as make the current optimization passes only apply to hardware that can support it.
There are more plans for optimization than just the tile and outerproduct approach as seen here but the current build does produce valid MLIR. Based this off the example shown here.
As of now only SVE2 can tested on real hardware which I don't have access to. SME will have to be emulated.
Not yet anywhere ready in a state to be merged but feedback would be appreciated.
Instructions to build
Same as normal however to see the optimized MLIR,
Usage
this is should cause the test to not compile and fail but the optimized MLIR should be printed in blue test
to turn this off just set
Below is the optimized MLIR produced from test_matmul.py