Introduce unstructured-to-memref pass #216
base: main
Conversation
@nhat-nguyen Hi, I tried your PR and noticed that

```mlir
%ptrs = "tts.make_unstructured_tptr"(%base_ptr, %offsets) : (!tt.ptr<f16>, tensor<1024xi32>) -> tensor<1024x!tt.ptr<f16>>
tt.store %ptrs, %vals, %mask : tensor<1024x!tt.ptr<f16>>
```

is converted to:

```mlir
%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
affine.for %i = 0 to 1024 {
  %cond = tensor.extract %mask[%i] : tensor<1024xi1>
  scf.if %cond {
    %offset_i32 = tensor.extract %offsets[%i] : tensor<1024xi32>
    %val = tensor.extract %vals[%i] : tensor<1024xf16>
    %offset = arith.index_cast %offset_i32 : i32 to index
    memref.store %val, %base_mem[%offset] : memref<?xf16>
  }
}
```

However, as far as I know this `affine.for` loop cannot easily be turned into a parallel loop, so I suggest converting it to a `linalg.generic` instead:

```mlir
%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
%cst = arith.constant 0.000000e+00 : f16
%dummy = tensor.empty() : tensor<1024xf16> // dummy output tensor.
%11 = linalg.generic {
indexing_maps = [#map, #map, #map, #map],
iterator_types = ["parallel"]
}
ins(%offsets, %vals, %mask : tensor<1024xi32>, tensor<1024xf16>, tensor<1024xi1>)
outs(%dummy: tensor<1024xf16>) {
^bb0(%offset_i32 : i32, %val : f16, %cond : i1, %out_dummy : f16):
%dummy_val = scf.if %cond -> f16 {
%offset = arith.index_cast %offset_i32 : i32 to index
memref.store %val, %base_mem[%offset] : memref<?xf16>
scf.yield %cst : f16
} else {
scf.yield %cst : f16
}
linalg.yield %dummy_val : f16
} -> tensor<1024xf16>
```

Later we can lower this `linalg.generic` to an `scf.parallel` loop:

```mlir
%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
scf.parallel (%i) = (%c0) to (%c1024) step (%c1) {
%offset_i32 = memref.load %offsets_buf[%i] : memref<1024xi32>
%val = memref.load %vals_buf[%i] : memref<1024xf16>
%cond = memref.load %mask_buf[%i] : memref<1024xi1>
scf.if %cond {
%offset = arith.index_cast %offset_i32 : i32 to index
memref.store %val, %base_mem[%offset] : memref<?xf16>
}
scf.reduce
}
```
@Nullkooland The issue with using linalg on tensor for these write operations is that the whole op can be removed through canonicalization if we haven't converted to memref yet. I think that is a pretty big drawback -- worse than not being able to leverage the conversion to parallel loop. With the following IR:

```mlir
#map = affine_map<(d0) -> (d0)>
module {
tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
%0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
%1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
%cst = arith.constant 9.900000e+01 : f32
%dummy_const = arith.constant 1 : i1
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%cst_0 = arith.constant dense<4> : tensor<4xi32>
%cst_1 = arith.constant dense<64> : tensor<4xi32>
%cst_2 = arith.constant dense<3> : tensor<4xi32>
%2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
%3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
%4 = arith.divsi %arg3, %cst_2 : tensor<4xi32>
%5 = tt.splat %arg2 : i32 -> tensor<4xi32>
%6 = arith.addi %4, %5 : tensor<4xi32>
%7 = arith.cmpi slt, %6, %cst_1 : tensor<4xi32>
%cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
%8 = bufferization.to_tensor %cast restrict : memref<?xf32>
%9 = tensor.empty() : tensor<4xf32>
%10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) outs(%9 : tensor<4xf32>) {
^bb0(%in: i32, %in_4: i1, %out: f32):
%15 = scf.if %in_4 -> (f32) {
%16 = arith.index_cast %in : i32 to index
%extracted = tensor.extract %8[%16] : tensor<?xf32>
scf.yield %extracted : f32
} else {
scf.yield %cst : f32
}
linalg.yield %15 : f32
} -> tensor<4xf32>
%cast_3 = memref.cast %0 : memref<*xf32> to memref<?xf32>
%11 = tensor.empty() : tensor<4xi1>
%alloc = memref.alloc() : memref<4xi1>
linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%10, %6, %7 : tensor<4xf32>, tensor<4xi32>, tensor<4xi1>) outs(%alloc : memref<4xi1>) {
^bb0(%in: f32, %in_4: i32, %in_5: i1, %out: i1):
%15 = arith.index_cast %in_4 : i32 to index
%yield = scf.if %in_5 -> i1 {
memref.store %in, %cast_3[%15] : memref<?xf32>
scf.yield %dummy_const : i1
} else {
scf.yield %dummy_const : i1
}
linalg.yield %yield : i1
}
%13 = arith.addi %6, %cst_0 : tensor<4xi32>
%14 = arith.addi %arg4, %cst_0 : tensor<4xi32>
scf.yield %13, %14 : tensor<4xi32>, tensor<4xi32>
}
tt.return
}
}
```
@nhat-nguyen This might be an upstream MLIR bug. You could apply the upstream fix patch to the LLVM version that triton-shared builds against.
@nhat-nguyen I tried your example IR (with a minor modification so that the second `linalg.generic` writes to the `tensor<4xi1>` outs instead of a `memref`) and ran:

```
triton-shared-opt --canonicalize --cse masked_gather_scatter.mlir
```

The output IR is:

```mlir
#map = affine_map<(d0) -> (d0)>
module {
tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
%cst = arith.constant dense<3> : tensor<4xi32>
%cst_0 = arith.constant dense<64> : tensor<4xi32>
%cst_1 = arith.constant dense<4> : tensor<4xi32>
%c2_i32 = arith.constant 2 : i32
%c1_i32 = arith.constant 1 : i32
%c0_i32 = arith.constant 0 : i32
%true = arith.constant true
%cst_2 = arith.constant 9.900000e+01 : f32
%0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
%1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
%2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
%3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
%4 = arith.divsi %arg3, %cst : tensor<4xi32>
%5 = tt.splat %arg2 : i32 -> tensor<4xi32>
%6 = arith.addi %4, %5 : tensor<4xi32>
%7 = arith.cmpi slt, %6, %cst_0 : tensor<4xi32>
%cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
%8 = bufferization.to_tensor %cast restrict : memref<?xf32>
%9 = tensor.empty() : tensor<4xf32>
%10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) outs(%9 : tensor<4xf32>) {
^bb0(%in: i32, %in_4: i1, %out: f32):
%15 = scf.if %in_4 -> (f32) {
%16 = arith.index_cast %in : i32 to index
%extracted = tensor.extract %8[%16] : tensor<?xf32>
scf.yield %extracted : f32
} else {
scf.yield %cst_2 : f32
}
linalg.yield %15 : f32
} -> tensor<4xf32>
%cast_3 = memref.cast %0 : memref<*xf32> to memref<?xf32>
%11 = tensor.empty() : tensor<4xi1>
%12 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%10, %6, %7 : tensor<4xf32>, tensor<4xi32>, tensor<4xi1>) outs(%11 : tensor<4xi1>) {
^bb0(%in: f32, %in_4: i32, %in_5: i1, %out: i1):
%15 = arith.index_cast %in_4 : i32 to index
scf.if %in_5 {
memref.store %in, %cast_3[%15] : memref<?xf32>
}
linalg.yield %true : i1
} -> tensor<4xi1>
%13 = arith.addi %6, %cst_1 : tensor<4xi32>
%14 = arith.addi %arg4, %cst_1 : tensor<4xi32>
scf.yield %13, %14 : tensor<4xi32>, tensor<4xi32>
}
tt.return
}
}
```

The `linalg.generic` containing the `memref.store` is preserved after canonicalization.
That is perfect. Sorry about the incorrect IR (I was playing around with different ways to get this to work and forgot to revert). Looks like we need to update triton, which in turn will update llvm, to get the fix.
This PR introduces the `unstructured-to-memref` pass responsible for converting unstructured triton load / store ops to memref load / store ops. This is part of the work to allow triton-shared to lower gather / scatter pointer sequences. The pass is intended to be used after running `--fold-unstructured-ptr`.
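For reference, a sketch of the kind of IR the pass consumes after `--fold-unstructured-ptr`, based on the store example discussed in this conversation; the `tt.load` line and the `f16` element type are illustrative:

```mlir
// Unstructured pointer sequence folded into a single op (illustrative shapes).
%ptrs = "tts.make_unstructured_tptr"(%base_ptr, %offsets) : (!tt.ptr<f16>, tensor<1024xi32>) -> tensor<1024x!tt.ptr<f16>>
// Gather and scatter through the unstructured pointers.
%vals = tt.load %ptrs, %mask : tensor<1024x!tt.ptr<f16>>
tt.store %ptrs, %vals, %mask : tensor<1024x!tt.ptr<f16>>
```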
The Triton load op (gather) is lowered to a `linalg.generic` whose body contains a load from the offset provided by `tts.make_unstructured_tptr`. For a load op with mask, an inner-most `scf.if` is used to return a default value (or the `other` operand of `tt.load` if provided) when the corresponding mask value is false.

Example of a load:
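A sketch of the lowered load, modeled on the gather IR shown earlier in this conversation; the names, the 1024-element shape, and the `%other` default value are illustrative:

```mlir
#map = affine_map<(d0) -> (d0)>

// Reinterpret the base pointer as a dynamically-sized buffer and wrap it in a tensor.
%base = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
%src = bufferization.to_tensor %base restrict : memref<?xf16>
%init = tensor.empty() : tensor<1024xf16>
%res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
    ins(%offsets, %mask : tensor<1024xi32>, tensor<1024xi1>)
    outs(%init : tensor<1024xf16>) {
^bb0(%offset_i32: i32, %cond: i1, %out: f16):
  %val = scf.if %cond -> (f16) {
    // Gather the element at the requested offset.
    %idx = arith.index_cast %offset_i32 : i32 to index
    %elem = tensor.extract %src[%idx] : tensor<?xf16>
    scf.yield %elem : f16
  } else {
    // Default value, or the `other` operand of tt.load if provided.
    scf.yield %other : f16
  }
  linalg.yield %val : f16
} -> tensor<1024xf16>
```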
The Triton store op (scatter) is lowered to an `affine.for` loop nest that stores each value to the offset provided by `tts.make_unstructured_tptr`. Store ops with mask are also supported.

Example of a store:
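A sketch of the lowered store, mirroring the masked-store IR quoted at the top of this conversation; the names and the 1024-element shape are illustrative:

```mlir
// Reinterpret the base pointer as a dynamically-sized buffer.
%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
affine.for %i = 0 to 1024 {
  %cond = tensor.extract %mask[%i] : tensor<1024xi1>
  // Only store where the mask is set.
  scf.if %cond {
    %offset_i32 = tensor.extract %offsets[%i] : tensor<1024xi32>
    %val = tensor.extract %vals[%i] : tensor<1024xf16>
    %offset = arith.index_cast %offset_i32 : i32 to index
    memref.store %val, %base_mem[%offset] : memref<?xf16>
  }
}
```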
Intended lowering pipeline:
1. triton-to-structured (structured pointer analysis): introduces `tts.make_tptr %ptr_arg with offsets and strides`, `tts.load`, and `tts.store`; leaves unstructured `tt.load` and `tt.store` intact
2. `fold-unstructured-ptr` pass (#210): converts the remaining unstructured pointer sequences (`tt.addptr`) into `tts.make_unstructured_tptr %ptr_arg %offsets`
3. structured-to-memref: converts the `tts` dialect to `memref`, with the exception of `tts.make_unstructured_tptr`
4. unstructured-to-memref (this PR): converts unstructured `tt.load`, `tt.store`, and `tts.make_unstructured_tptr` into memref
5. `triton-ptr-to-memref` pass (#211): converts kernel arguments of `tt.ptr` type into memref