[Bug]: Failed to lower TorchInductor generated repeat kernel #194

Nullkooland opened this issue Nov 27, 2024
Labels: bug (Something isn't working)

Triton python code

import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[524288], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cpu', index=None, cc='', major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, multi_processor_count=None), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_repeat_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'E9AC141F66F25ECE51A81A7F95D3277C38AC42924354810A6DB4E2121081F21C', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 393216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex % 1024
    x1 = (xindex // 1024)
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + ((512*(x1 % 128)) + (x0 % 512)), None)
    tl.store(tl.make_block_ptr(out_ptr0, shape=[393216], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp0, [XBLOCK]).to(tl.float16))
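For reference, a minimal plain-PyTorch sketch (not part of the generated code) confirming that the kernel's load index reproduces x.repeat([3, 2]) on a (128, 512) input:

import torch

x = torch.randn(128, 512, dtype=torch.float16)
ref = x.repeat(3, 2)  # shape (384, 1024), numel 393216 = xnumel above

xindex = torch.arange(384 * 1024)
x0 = xindex % 1024    # output column
x1 = xindex // 1024   # output row
# Same index expression as the tl.load above: (512 * (x1 % 128)) + (x0 % 512)
out = x.reshape(-1)[(512 * (x1 % 128)) + (x0 % 512)]
assert torch.equal(out.reshape(384, 1024), ref)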

Triton IR

module {
  tt.func public @triton_(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c1_i64 = arith.constant 1 : i64
    %c393216_i64 = arith.constant 393216 : i64
    %cst = arith.constant dense<512> : tensor<1024xi32>
    %cst_0 = arith.constant dense<128> : tensor<1024xi32>
    %cst_1 = arith.constant dense<1024> : tensor<1024xi32>
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %3 = tt.splat %1 : i32 -> tensor<1024xi32>
    %4 = arith.addi %3, %2 : tensor<1024xi32>
    %5 = arith.remsi %4, %cst_1 : tensor<1024xi32>
    %6 = arith.divsi %4, %cst_1 : tensor<1024xi32>
    %7 = arith.remsi %6, %cst_0 : tensor<1024xi32>
    %8 = arith.muli %7, %cst : tensor<1024xi32>
    %9 = arith.remsi %5, %cst : tensor<1024xi32>
    %10 = arith.addi %8, %9 : tensor<1024xi32>
    %11 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<1024x!tt.ptr<f16>>
    %12 = tt.addptr %11, %10 : tensor<1024x!tt.ptr<f16>>, tensor<1024xi32>
    %13 = tt.load %12 : tensor<1024x!tt.ptr<f16>>
    %14 = tt.make_tensor_ptr %arg1, [%c393216_i64], [%c1_i64], [%1] {order = array<i32: 0>} : <tensor<1024xf16>>
    tt.store %14, %13 : !tt.ptr<tensor<1024xf16>>
    tt.return
  }
}
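Note the offset chain feeding the tt.addptr (%12): %6 = arith.divsi is the x1 = xindex // 1024 term, and its result goes through arith.remsi (%7) and arith.muli (%8) into the final offset (%10). As the crash log below shows, PtrAnalysis rejects this divsi in the pointer computation.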

Crash log

PtrAnalysis: encountered addptr operand produced by an unsupported operation
%6 = arith.divsi %4, %cst_1 : tensor<1024xi32>
output/kernels/triton/repeat.mlir:21:11: remark: PtrAnalysis: Failed to rewrite AddPtrOp
    %12 = tt.addptr %11, %10 : tensor<1024x!tt.ptr<f16>>, tensor<1024xi32>
          ^
output/kernels/triton/repeat.mlir:21:11: note: see current operation: %12 = tt.addptr %11, %10 : tensor<1024x!tt.ptr<f16>>, tensor<1024xi32>
output/kernels/triton/repeat.mlir:22:11: remark: PtrAnalysis: pointer is not replace with tts.make_tptr so loadOp cannot be rewritten
    %13 = tt.load %12 : tensor<1024x!tt.ptr<f16>>
          ^
output/kernels/triton/repeat.mlir:22:11: note: see current operation: %13 = tt.load %12 : tensor<1024x!tt.ptr<f16>>
output/kernels/triton/repeat.mlir:22:11: remark: PtrAnalysis: Failed to rewrite LoadOp
    %13 = tt.load %12 : tensor<1024x!tt.ptr<f16>>
          ^
output/kernels/triton/repeat.mlir:22:11: note: see current operation: %13 = tt.load %12 : tensor<1024x!tt.ptr<f16>>
triton-shared-opt: triton_shared/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp:185: std::optional<SmallVector<Value>> (anonymous namespace)::buildCastAndOffsetOps(OpBuilder &, TypeRange, Value, Location): Assertion `castOp && "Unexpected defining op for input of type tt.ptr"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.	Program arguments: triton-shared-opt --triton-to-linalg-experimental output/kernels/triton/repeat.mlir
 #0 0x000055e1d0d953a7 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (triton_shared/triton/python/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x32783a7)
 #1 0x000055e1d0d92ece llvm::sys::RunSignalHandlers() (triton_shared/triton/python/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x3275ece)
 #2 0x000055e1d0d95a5f SignalHandler(int) Signals.cpp:0:0
 #3 0x00007f1c72106520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x00007f1c7215a9fc pthread_kill (/lib/x86_64-linux-gnu/libc.so.6+0x969fc)
 #5 0x00007f1c72106476 gsignal (/lib/x86_64-linux-gnu/libc.so.6+0x42476)
 #6 0x00007f1c720ec7f3 abort (/lib/x86_64-linux-gnu/libc.so.6+0x287f3)
 #7 0x00007f1c720ec71b (/lib/x86_64-linux-gnu/libc.so.6+0x2871b)
 #8 0x00007f1c720fde96 (/lib/x86_64-linux-gnu/libc.so.6+0x39e96)
 #9 0x000055e1cedb6020 (anonymous namespace)::buildCastAndOffsetOps(mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location) triton_shared/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp:200:0
#10 0x000055e1cedb7bec std::_Function_handler<std::optional<llvm::SmallVector<mlir::Value, 6u>> (mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location), std::optional<llvm::SmallVector<mlir::Value, 6u>> (*)(mlir::OpBuilder&, mlir::TypeRange, mlir::Value, mlir::Location)>::_M_invoke(std::_Any_data const&, mlir::OpBuilder&, mlir::TypeRange&&, mlir::Value&&, mlir::Location&&) /usr/bin/../lib/gcc/x86_64-linux-gnu/12/../../../../include/c++/12/bits/std_function.h:290:2
#11 0x000055e1d0a4ddf4 mlir::OneToNTypeConverter::materializeTargetConversion(mlir::OpBuilder&, mlir::Location, mlir::TypeRange, mlir::Value) const (triton_shared/triton/python/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2f30df4)
#12 0x000055e1d0a500d9 mlir::applyPartialOneToNConversion(mlir::Operation*, mlir::OneToNTypeConverter&, mlir::FrozenRewritePatternSet const&) (triton_shared/triton/python/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x2f330d9)
#13 0x000055e1cedb4b25 (anonymous namespace)::StructuredToMemrefPass::convertAddPtrToReinterpretCast() triton_shared/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp:338:16
#14 0x000055e1cedb4b25 (anonymous namespace)::StructuredToMemrefPass::runOnOperation() triton_shared/lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp:360:16
...
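Failure sequence: PtrAnalysis bails on the arith.divsi feeding the offset, so the tt.addptr is never rewritten to tts.make_tptr (per the remarks above); the subsequent StructuredToMemref pass then sees a tensor-of-pointers with an unexpected defining op and trips the assertion in buildCastAndOffsetOps.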

Additional information

The repeat kernel above is the TorchInductor codegen result of:

import torch

device = "cpu"  # not defined in the original snippet; the kernel metadata above shows DeviceProperties(type='cpu')
x = torch.randn(size=(128, 512), dtype=torch.float16, device=device)

@torch.compile(fullgraph=True)
def test_func(x: torch.Tensor) -> torch.Tensor:
    out = x.repeat([3, 2])
    return out

out = test_func(x)  # (128, 512) -> (384, 1024)
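
To reproduce the crash from the IR directly, save the Triton IR above as repeat.mlir and run the command from the stack dump:

triton-shared-opt --triton-to-linalg-experimental output/kernels/triton/repeat.mlir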