Introduce unstructured-to-memref pass #216

Draft · wants to merge 4 commits into main

Conversation

@nhat-nguyen (Collaborator) commented on Jan 7, 2025

This PR introduces the unstructured-to-memref pass responsible for converting unstructured triton load / store ops to memref load / store ops. This is part of the work to allow triton-shared to lower gather / scatter pointer sequences. The pass is intended to be used after running --fold-unstructured-ptr.

A Triton load op (gather) is lowered to a linalg.generic whose body loads from the offset provided by tts.make_unstructured_tptr. For load ops with a mask, an innermost scf.if returns a default value (or the other operand of tt.load, if provided) when the corresponding mask value is false.

Example of a load:

  func.func @gather_simple_mask_with_other(%arg0: memref<*xf32>, %arg1: memref<*xf32>) {
      %cst = arith.constant -1.000000e+00 : f32
      %cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
      %load_tensor = bufferization.to_tensor %cast restrict : memref<?xf32>
      %empty = tensor.empty() : tensor<64xf32>
      %gather = linalg.generic {
        indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
        iterator_types = ["parallel"]
      } ins(%offset_tensor, %mask_tensor : tensor<64xi32>, tensor<64xi1>)
        outs(%empty : tensor<64xf32>) {
      ^bb0(%offset: i32, %mask: i1, %out: f32):
        %yield = scf.if %mask -> (f32) {
          %index = arith.index_cast %offset : i32 to index
          %extracted = tensor.extract %load_tensor[%index] : tensor<?xf32>
          scf.yield %extracted : f32
        } else {
          scf.yield %cst : f32
        }
        linalg.yield %yield : f32
      } -> tensor<64xf32>

A Triton store op (scatter) is lowered to an affine.for loop nest that stores each value to the corresponding offset provided by tts.make_unstructured_tptr. Store ops with a mask are also supported.

Example of a store:

  func.func @masked_gather_scatter(%arg0: memref<*xf32>, %arg1: memref<*xf32>) {
    %store_memref = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
    affine.for %i = 0 to 4 {
      %mask_val = tensor.extract %mask[%i] : tensor<4xi1>
      scf.if %mask_val {
        %offset_val = tensor.extract %offset_tensor[%i] : tensor<4xi32>
        %store_value = tensor.extract %tensor[%i] : tensor<4xf32>
        %offset_index = arith.index_cast %offset_val : i32 to index
        memref.store %store_value, %store_memref[%offset_index] : memref<?xf32>
      }
    }

Intended lowering pipeline:

  • triton-to-structured (no changes):
    • analyzes structured addptr sequences
      • introduces tts.make_tptr %ptr_arg with offsets and strides
      • introduces tts.load and tts.store
    • leaves unstructured addptr sequences and their corresponding tt.load and tt.store intact
  • fold-unstructured-ptr (Introduce fold-unstructured-ptr pass #210):
    • converts all unstructured addptr sequences into sequences that compute pointer offsets
      • introduces tts.make_unstructured_tptr %ptr_arg %offsets
    • removes all tt.addptr
  • structured-to-memref (to be updated in a different PR):
    • currently converts everything to memref including scalar addptr and kernel arguments
    • will change to just convert ops in the tts dialect to memref with the exception of tts.make_unstructured_tptr
  • unstructured-to-memref (this PR):
    • converts the remaining unstructured tt.load, tt.store, and tts.make_unstructured_tptr into memref
  • triton-ptr-to-memref (Introduce triton-ptr-to-memref pass #211):
    • converts kernel arguments with pointer type to memref
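
For reference, the pipeline above could be driven through triton-shared-opt roughly as sketched below. The flag spellings are taken from the pass names listed above, but the exact spelling of each flag, any per-pass options, and the input file name kernel.mlir are assumptions, not something verified in this PR:

  triton-shared-opt \
    --triton-to-structured \
    --fold-unstructured-ptr \
    --structured-to-memref \
    --unstructured-to-memref \
    --triton-ptr-to-memref \
    kernel.mlir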

@Nullkooland (Contributor)

@nhat-nguyen Hi, I tried your PR and noticed that tt.store on a tensor of unstructured ptrs is lowered to an affine.for; for instance:

%ptrs = "tts.make_unstructured_tptr"(%base_ptr, %offsets) : (!tt.ptr<f16>, tensor<1024xi32>) -> tensor<1024x!tt.ptr<f16>>
tt.store %ptrs, %vals, %mask : tensor<1024x!tt.ptr<f16>>

is converted to:

%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
affine.for %i = 0 to 1024 {
  %cond = tensor.extract %mask[%i] : tensor<1024xi1>
  scf.if %cond {
    %offset_i32 = tensor.extract %offsets[%i] : tensor<1024xi32>
    %val = tensor.extract %vals[%i] : tensor<1024xf16>
    %offset = arith.index_cast %offset_i32 : i32 to index
    memref.store %val, %base_mem[%offset] : memref<?xf16>
  }
}

However, as far as I know, the affine dialect is not actively used today and there is no pass to parallelize it; see MLIR Discourse.

So I suggest converting tt.store to linalg.generic as well, the same as the conversion you implemented for tt.load.
I understand that the linalg.generic op must have an output operand to imply the iteration range, so you could add a dummy output tensor, like:

%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
%cst = arith.constant 0.000000e+00 : f16
%dummy = tensor.empty() : tensor<1024xf16> // dummy output tensor.
%11 = linalg.generic {
  indexing_maps = [#map, #map, #map, #map],
  iterator_types = ["parallel"]
}
ins(%offsets, %vals, %mask : tensor<1024xi32>, tensor<1024xf16>, tensor<1024xi1>)
outs(%dummy: tensor<1024xf16>) {
^bb0(%offset_i32 : i32, %val : f16, %cond : i1, %out_dummy : f16):
  %dummy_val = scf.if %cond -> f16 {
    %offset = arith.index_cast %offset_i32 : i32 to index
    memref.store %val, %base_mem[%offset] : memref<?xf16>
    scf.yield %cst : f16
  } else {
    scf.yield %cst : f16
  }
  linalg.yield %dummy_val : f16
} -> tensor<1024xf16>

Later we can lower the linalg.generic to an scf.parallel loop with one-shot-bufferize and convert-linalg-to-parallel-loops; the dummy output is eliminated by DCE:

%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
scf.parallel (%i) = (%c0) to (%c1024) step (%c1) {
  %offset_i32 = memref.load %offsets_buf[%i] : memref<1024xi32>
  %val = memref.load %vals_buf[%i] : memref<1024xf16>
  %cond = memref.load %mask_buf[%i] : memref<1024xi1>
  scf.if %cond {
    %offset = arith.index_cast %offset_i32 : i32 to index
    memref.store %val, %base_mem[%offset] : memref<?xf16>
  }
  scf.reduce 
}
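
For completeness, one possible way to run that lowering with upstream mlir-opt is sketched below; whether the default one-shot-bufferize options are sufficient for this IR is an assumption, and linalg_store.mlir is a placeholder input file:

  mlir-opt \
    --one-shot-bufferize="bufferize-function-boundaries=1" \
    --convert-linalg-to-parallel-loops \
    --canonicalize --cse \
    linalg_store.mlir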

@nhat-nguyen (Collaborator, Author)

@Nullkooland The issue with using linalg on tensors for these write operations is that the whole op can be removed through canonicalization if we haven't converted to memref yet. I think that is a pretty big drawback -- worse than not being able to leverage the conversion to parallel loops.

With the following IR:

#map = affine_map<(d0) -> (d0)>
module {
  tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
    %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
    %cst = arith.constant 9.900000e+01 : f32
    %dummy_const = arith.constant 1 : i1
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c2_i32 = arith.constant 2 : i32
    %cst_0 = arith.constant dense<4> : tensor<4xi32>
    %cst_1 = arith.constant dense<64> : tensor<4xi32>
    %cst_2 = arith.constant dense<3> : tensor<4xi32>
    %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
    %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>)  : i32 {
      %4 = arith.divsi %arg3, %cst_2 : tensor<4xi32>
      %5 = tt.splat %arg2 : i32 -> tensor<4xi32>
      %6 = arith.addi %4, %5 : tensor<4xi32>
      %7 = arith.cmpi slt, %6, %cst_1 : tensor<4xi32>
      %cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
      %8 = bufferization.to_tensor %cast restrict : memref<?xf32>
      %9 = tensor.empty() : tensor<4xf32>
      %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) outs(%9 : tensor<4xf32>) {
      ^bb0(%in: i32, %in_4: i1, %out: f32):
        %15 = scf.if %in_4 -> (f32) {
          %16 = arith.index_cast %in : i32 to index
          %extracted = tensor.extract %8[%16] : tensor<?xf32>
          scf.yield %extracted : f32
        } else {
          scf.yield %cst : f32
        }
        linalg.yield %15 : f32
      } -> tensor<4xf32>
      %cast_3 = memref.cast %0 : memref<*xf32> to memref<?xf32>
      %11 = tensor.empty() : tensor<4xi1>
      %alloc = memref.alloc() : memref<4xi1>
      linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%10, %6, %7 : tensor<4xf32>, tensor<4xi32>, tensor<4xi1>) outs(%alloc : memref<4xi1>) {
      ^bb0(%in: f32, %in_4: i32, %in_5: i1, %out: i1):
        %15 = arith.index_cast %in_4 : i32 to index
        %yield = scf.if %in_5 -> i1 {
          memref.store %in, %cast_3[%15] : memref<?xf32>
          scf.yield %dummy_const : i1
        } else {
          scf.yield %dummy_const : i1
        }
        linalg.yield %yield : i1
      }
      %13 = arith.addi %6, %cst_0 : tensor<4xi32>
      %14 = arith.addi %arg4, %cst_0 : tensor<4xi32>
      scf.yield %13, %14 : tensor<4xi32>, tensor<4xi32>
    }
    tt.return
  }
}

--canonicalize will end up removing the whole body. Let me know if you know of any way to prevent this. Otherwise, I think just lowering the scatter to the affine loop is the best we can do. The loop nest itself is pretty simple, so we can pattern-match and parallelize it later if necessary.

@Nullkooland (Contributor)

@nhat-nguyen This might be an upstream MLIR bug: the linalg ops do not implement RecursiveMemoryEffects, so memory side effects of ops inside their bodies are not taken into account; see llvm/llvm-project#114045.

You could apply this upstream fix as a patch to the llvm-project dependency and try again to see whether the linalg.generic with memref.store in its body still gets removed. If this works, I guess triton-shared needs to update its triton dependency to a newer version whose llvm-project dependency includes this fix.
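
One way to try that locally is sketched below; the llvm-project checkout path and the fix commit hash are placeholders, not values from this thread:

  # Assumes triton-shared is built against a local llvm-project checkout
  # (e.g. via the MLIR_DIR/LLVM build path your setup uses).
  cd /path/to/llvm-project
  git fetch origin main
  git cherry-pick <fix-commit-sha>   # the fix from llvm/llvm-project#114045
  # ...then rebuild LLVM/MLIR and triton-shared against it.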

@Nullkooland (Contributor)

@nhat-nguyen I tried your example IR (with the minor modification that the outs operand %alloc = memref.alloc() : memref<4xi1> is changed to %alloc = tensor.empty() : tensor<4xi1>, since we cannot mix tensor and memref I/O operands in linalg ops) using a triton-shared-opt built against an llvm-project built from source with the llvm/llvm-project#114045 upstream fix.

triton-shared-opt --canonicalize --cse masked_gather_scatter.mlir

the output IR is:

#map = affine_map<(d0) -> (d0)>
module {
  tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<3> : tensor<4xi32>
    %cst_0 = arith.constant dense<64> : tensor<4xi32>
    %cst_1 = arith.constant dense<4> : tensor<4xi32>
    %c2_i32 = arith.constant 2 : i32
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %true = arith.constant true
    %cst_2 = arith.constant 9.900000e+01 : f32
    %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
    %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
    %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
    %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>)  : i32 {
      %4 = arith.divsi %arg3, %cst : tensor<4xi32>
      %5 = tt.splat %arg2 : i32 -> tensor<4xi32>
      %6 = arith.addi %4, %5 : tensor<4xi32>
      %7 = arith.cmpi slt, %6, %cst_0 : tensor<4xi32>
      %cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
      %8 = bufferization.to_tensor %cast restrict : memref<?xf32>
      %9 = tensor.empty() : tensor<4xf32>
      %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) outs(%9 : tensor<4xf32>) {
      ^bb0(%in: i32, %in_4: i1, %out: f32):
        %15 = scf.if %in_4 -> (f32) {
          %16 = arith.index_cast %in : i32 to index
          %extracted = tensor.extract %8[%16] : tensor<?xf32>
          scf.yield %extracted : f32
        } else {
          scf.yield %cst_2 : f32
        }
        linalg.yield %15 : f32
      } -> tensor<4xf32>
      %cast_3 = memref.cast %0 : memref<*xf32> to memref<?xf32>
      %11 = tensor.empty() : tensor<4xi1>
      %12 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%10, %6, %7 : tensor<4xf32>, tensor<4xi32>, tensor<4xi1>) outs(%11 : tensor<4xi1>) {
      ^bb0(%in: f32, %in_4: i32, %in_5: i1, %out: i1):
        %15 = arith.index_cast %in_4 : i32 to index
        scf.if %in_5 {
          memref.store %in, %cast_3[%15] : memref<?xf32>
        }
        linalg.yield %true : i1
      } -> tensor<4xi1>
      %13 = arith.addi %6, %cst_1 : tensor<4xi32>
      %14 = arith.addi %arg4, %cst_1 : tensor<4xi32>
      scf.yield %13, %14 : tensor<4xi32>, tensor<4xi32>
    }
    tt.return
  }
}

The linalg.generic with memref.store is not removed by canonicalization.

@nhat-nguyen (Collaborator, Author)

That is perfect. Sorry about the incorrect IR (I was playing around with different ways to get this to work and forgot to revert). Looks like we need to update triton, which in turn will update llvm, to get the fix.
