[Bug]: not support scatter tt.load for tmp6 #177

Open
GoodNight-bye opened this issue Sep 20, 2024 · 0 comments
Labels: bug Something isn't working

Triton python code

import torch
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math


# @triton_heuristics.persistent_reduction(
#     size_hints=[128, 1024],
#     reduction_hint=ReductionHint.INNER,
#     filename=__file__,
#     triton_meta={'signature': {0: '*i32', 1: '*fp32', 2: '*i64', 3: '*fp32', 4: '*i64', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: '*fp32', 11: '*fp32', 12: 'i32', 13: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=86, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=82), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13), equal_to_1=())]},
#     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'EF1B67C46BD8509C594D7E16C1F960076FB4922E1F91617C724FA83EF8449AB6', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
# )
@triton.jit
def triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr0, out_ptr3, out_ptr4, out_ptr5, xnumel, rnumel):
    xnumel = 128
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    x0 = xindex
    r1 = rindex
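    # Per-row gather indices into the three embedding tables (21128, 2 and
    # 512 rows) plus the per-column layer-norm weight/bias (768 elements).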
    tmp0 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
    tmp7 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
    tmp15 = tl.load(in_ptr4 + (x0), xmask, eviction_policy='evict_last')
    tmp46 = tl.load(in_ptr6 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp48 = tl.load(in_ptr7 + (r1), rmask, eviction_policy='evict_last', other=0.0)
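    # Wrap negative indices into [0, 21128) and gather a 768-wide row of the
    # first table; this data-dependent load is the tmp6 the crash log flags.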
    tmp1 = tl.full([RBLOCK], 21128, tl.int32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    tl.device_assert(((0 <= tmp4) & (tmp4 < 21128)) | ~(xmask), "index out of bounds: 0 <= tmp4 < 21128")
    tmp6 = tl.load(in_ptr1 + (r1 + (768*tmp4)), rmask & xmask, other=0.0)
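    # Same wrap-and-gather for the 2-row table.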
    tmp8 = tl.full([RBLOCK], 2, tl.int32)
    tmp9 = tmp7 + tmp8
    tmp10 = tmp7 < 0
    tmp11 = tl.where(tmp10, tmp9, tmp7)
    tl.device_assert(((0 <= tmp11) & (tmp11 < 2)) | ~(xmask), "index out of bounds: 0 <= tmp11 < 2")
    tmp13 = tl.load(in_ptr3 + (r1 + (768*tmp11)), rmask & xmask, other=0.0)
    tmp14 = tmp6 + tmp13
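    # And again for the 512-row table.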
    tmp16 = tl.full([RBLOCK], 512, tl.int32)
    tmp17 = tmp15 + tmp16
    tmp18 = tmp15 < 0
    tmp19 = tl.where(tmp18, tmp17, tmp15)
    tl.device_assert(((0 <= tmp19) & (tmp19 < 512)) | ~(xmask), "index out of bounds: 0 <= tmp19 < 512")
    tmp21 = tl.load(in_ptr5 + (r1 + (768*tmp19)), rmask & xmask, other=0.0)
    tmp22 = tmp14 + tmp21
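    # Layer norm over the 768 columns: first the mean of the summed embeddings.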
    tmp23 = tl.broadcast_to(tmp22, [RBLOCK])
    tmp25 = tl.where(rmask & xmask, tmp23, 0)
    tmp26 = tl.broadcast_to(tmp23, [RBLOCK])
    tmp28 = tl.where(rmask & xmask, tmp26, 0)
    tmp29 = triton_helpers.promote_to_tensor(tl.sum(tmp28, 0))
    tmp30 = tl.full([1], 768, tl.int32)
    tmp31 = tmp30.to(tl.float32)
    tmp32 = tmp29 / tmp31
    tmp33 = tmp23 - tmp32
    tmp34 = tmp33 * tmp33
    tmp35 = tl.broadcast_to(tmp34, [RBLOCK])
    tmp37 = tl.where(rmask & xmask, tmp35, 0)
    tmp38 = triton_helpers.promote_to_tensor(tl.sum(tmp37, 0))
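    # Variance / 768, eps = 1e-12, then rsqrt and the affine scale/shift.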
    tmp39 = tmp22 - tmp32
    tmp40 = 768.0
    tmp41 = tmp38 / tmp40
    tmp42 = 1e-12
    tmp43 = tmp41 + tmp42
    tmp44 = libdevice.rsqrt(tmp43)
    tmp45 = tmp39 * tmp44
    tmp47 = tmp45 * tmp46
    tmp49 = tmp47 + tmp48
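    # rsqrt(var + eps) * (1/768), stored per row for the layer-norm backward.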
    tmp50 = 0.0013020833333333333
    tmp51 = tmp44 * tmp50
    tl.store(out_ptr0 + (r1 + (768*x0)), tmp22, rmask & xmask)
    tl.store(out_ptr3 + (r1 + (768*x0)), tmp45, rmask & xmask)
    tl.store(out_ptr4 + (r1 + (768*x0)), tmp49, rmask & xmask)
    tl.store(out_ptr5 + (x0), tmp51, xmask)


def test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0(device):
    # primals_1 = rand_strided((21128, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    # primals_2 = rand_strided((2, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    # primals_3 = rand_strided((512, 768), (768, 1), device='cuda:0', dtype=torch.float32)
    # primals_4 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
    # primals_5 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
    # primals_200 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64)
    # primals_201 = rand_strided((1, 512), (512, 1), device='cuda:0', dtype=torch.int64)
    # primals_202 = rand_strided((1, 128), (128, 1), device='cuda:0', dtype=torch.int32)

    primals_1 = torch.randn((21128, 768), device=device, dtype=torch.float32)
    primals_2 = torch.randn((2, 768), device=device, dtype=torch.float32)
    primals_3 = torch.randn((512, 768), device=device, dtype=torch.float32)
    primals_4 = torch.randn((768, ), device=device, dtype=torch.float32)
    primals_5 = torch.randn((768, ), device=device, dtype=torch.float32)
    primals_200 = torch.randint(0, 2, (1, 512), device=device, dtype=torch.int64)
    primals_201 = torch.randint(0, 512, (1, 512), device=device, dtype=torch.int64)
    primals_202 = torch.randint(0, 21128, (1, 128), device=device, dtype=torch.int32)

    # buf0 = empty_strided_cuda((1, 128, 768), (98304, 768, 1), torch.float32)
    # buf4 = empty_strided_cuda((1, 128, 768), (98304, 768, 1), torch.float32)
    # buf6 = empty_strided_cuda((1, 128, 768), (98304, 768, 1), torch.float32)
    # buf297 = empty_strided_cuda((1, 128, 1), (128, 1, 1), torch.float32)
    buf0 = torch.empty((1, 128, 768), device=device, dtype=torch.float32)
    buf4 = torch.empty((1, 128, 768), device=device, dtype=torch.float32)
    buf6 = torch.empty((1, 128, 768), device=device, dtype=torch.float32)
    buf297 = torch.empty((1, 128, 1), device=device, dtype=torch.float32)

    # triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.run(primals_202, primals_1, primals_200, primals_2, primals_201, primals_3, primals_4, primals_5, buf0, buf4, buf6, buf297, 128, 768, grid=grid(128), stream=stream0)
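    # One program per row: xnumel = 128 rows, each program reducing rnumel = 768 columns.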
    grid = (128,)
    triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0[grid](primals_202, primals_1, primals_200, primals_2, primals_201, primals_3, primals_4, primals_5, buf0, buf4, buf6, buf297, 128, 768)

Triton IR

#loc = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0)
#loc1 = loc(unknown)
module {
  tt.func public @triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0), %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0), %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0), %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0), %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0), %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0), %arg7: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0), %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0), %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0), %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0), %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0), %arg12: i32 {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0), %arg13: i32 {tt.divisibility = 16 : i32} loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":17:0)) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<1xf32> loc(#loc1)
    %cst_0 = arith.constant dense<512> : tensor<1024xi64> loc(#loc1)
    %cst_1 = arith.constant dense<768> : tensor<1024xi64> loc(#loc1)
    %cst_2 = arith.constant dense<0> : tensor<1xi64> loc(#loc1)
    %cst_3 = arith.constant dense<2> : tensor<1024xi64> loc(#loc1)
    %cst_4 = arith.constant dense<0.00130208337> : tensor<1xf32> loc(#loc1)
    %cst_5 = arith.constant dense<9.99999996E-13> : tensor<1xf32> loc(#loc1)
    %cst_6 = arith.constant dense<7.680000e+02> : tensor<1xf32> loc(#loc1)
    %cst_7 = arith.constant dense<768> : tensor<1xi32> loc(#loc1)
    %cst_8 = arith.constant dense<0> : tensor<1xi32> loc(#loc1)
    %cst_9 = arith.constant dense<21128> : tensor<1024xi32> loc(#loc1)
    %cst_10 = arith.constant dense<0.000000e+00> : tensor<1024xf32> loc(#loc1)
    %cst_11 = arith.constant dense<768> : tensor<1024xi32> loc(#loc1)
    %cst_12 = arith.constant dense<128> : tensor<1xi32> loc(#loc1)
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = tt.splat %0 : i32 -> tensor<1xi32> loc(#loc3)
    %2 = arith.cmpi slt, %1, %cst_12 : tensor<1xi32> loc(#loc4)
    %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc5)
    %4 = arith.cmpi slt, %3, %cst_11 : tensor<1024xi32> loc(#loc6)
    %5 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x!tt.ptr<i32>> loc(#loc7)
    %6 = tt.addptr %5, %1 : tensor<1x!tt.ptr<i32>>, tensor<1xi32> loc(#loc7)
    %7 = tt.load %6, %2 evictionPolicy = evict_last : tensor<1x!tt.ptr<i32>> loc(#loc8)
    %8 = tt.splat %arg2 : !tt.ptr<i64> -> tensor<1x!tt.ptr<i64>> loc(#loc9)
    %9 = tt.addptr %8, %1 : tensor<1x!tt.ptr<i64>>, tensor<1xi32> loc(#loc9)
    %10 = tt.load %9, %2 evictionPolicy = evict_last : tensor<1x!tt.ptr<i64>> loc(#loc10)
    %11 = tt.splat %arg4 : !tt.ptr<i64> -> tensor<1x!tt.ptr<i64>> loc(#loc11)
    %12 = tt.addptr %11, %1 : tensor<1x!tt.ptr<i64>>, tensor<1xi32> loc(#loc11)
    %13 = tt.load %12, %2 evictionPolicy = evict_last : tensor<1x!tt.ptr<i64>> loc(#loc12)
    %14 = tt.splat %arg6 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc13)
    %15 = tt.addptr %14, %3 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc13)
    %16 = tt.load %15, %4, %cst_10 evictionPolicy = evict_last : tensor<1024x!tt.ptr<f32>> loc(#loc14)
    %17 = tt.splat %arg7 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc15)
    %18 = tt.addptr %17, %3 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc15)
    %19 = tt.load %18, %4, %cst_10 evictionPolicy = evict_last : tensor<1024x!tt.ptr<f32>> loc(#loc16)
    %20 = tt.broadcast %7 : tensor<1xi32> -> tensor<1024xi32> loc(#loc17)
    %21 = arith.addi %20, %cst_9 : tensor<1024xi32> loc(#loc17)
    %22 = arith.cmpi slt, %7, %cst_8 : tensor<1xi32> loc(#loc18)
    %23 = tt.broadcast %22 : tensor<1xi1> -> tensor<1024xi1> loc(#loc19)
    %24 = arith.select %23, %21, %20 : tensor<1024xi1>, tensor<1024xi32> loc(#loc19)
    %25 = arith.muli %24, %cst_11 : tensor<1024xi32> loc(#loc20)
    %26 = arith.addi %3, %25 : tensor<1024xi32> loc(#loc21)
    %27 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc22)
    %28 = tt.addptr %27, %26 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc22)
    %29 = tt.broadcast %2 : tensor<1xi1> -> tensor<1024xi1> loc(#loc23)
    %30 = arith.andi %4, %29 : tensor<1024xi1> loc(#loc23)
    %31 = tt.load %28, %30, %cst_10 : tensor<1024x!tt.ptr<f32>> loc(#loc24)
    %32 = tt.broadcast %10 : tensor<1xi64> -> tensor<1024xi64> loc(#loc25)
    %33 = arith.addi %32, %cst_3 : tensor<1024xi64> loc(#loc25)
    %34 = arith.cmpi slt, %10, %cst_2 : tensor<1xi64> loc(#loc26)
    %35 = tt.broadcast %34 : tensor<1xi1> -> tensor<1024xi1> loc(#loc27)
    %36 = arith.select %35, %33, %32 : tensor<1024xi1>, tensor<1024xi64> loc(#loc27)
    %37 = arith.muli %36, %cst_1 : tensor<1024xi64> loc(#loc28)
    %38 = arith.extsi %3 : tensor<1024xi32> to tensor<1024xi64> loc(#loc29)
    %39 = arith.addi %38, %37 : tensor<1024xi64> loc(#loc29)
    %40 = tt.splat %arg3 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc30)
    %41 = tt.addptr %40, %39 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64> loc(#loc30)
    %42 = tt.load %41, %30, %cst_10 : tensor<1024x!tt.ptr<f32>> loc(#loc31)
    %43 = arith.addf %31, %42 : tensor<1024xf32> loc(#loc32)
    %44 = tt.broadcast %13 : tensor<1xi64> -> tensor<1024xi64> loc(#loc33)
    %45 = arith.addi %44, %cst_0 : tensor<1024xi64> loc(#loc33)
    %46 = arith.cmpi slt, %13, %cst_2 : tensor<1xi64> loc(#loc34)
    %47 = tt.broadcast %46 : tensor<1xi1> -> tensor<1024xi1> loc(#loc35)
    %48 = arith.select %47, %45, %44 : tensor<1024xi1>, tensor<1024xi64> loc(#loc35)
    %49 = arith.muli %48, %cst_1 : tensor<1024xi64> loc(#loc36)
    %50 = arith.addi %38, %49 : tensor<1024xi64> loc(#loc37)
    %51 = tt.splat %arg5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc38)
    %52 = tt.addptr %51, %50 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64> loc(#loc38)
    %53 = tt.load %52, %30, %cst_10 : tensor<1024x!tt.ptr<f32>> loc(#loc39)
    %54 = arith.addf %43, %53 : tensor<1024xf32> loc(#loc40)
    %55 = arith.select %30, %54, %cst_10 : tensor<1024xi1>, tensor<1024xf32> loc(#loc41)
    %56 = "tt.reduce"(%55) <{axis = 0 : i32}> ({
    ^bb0(%arg14: f32 loc(unknown), %arg15: f32 loc(unknown)):
      %86 = arith.addf %arg14, %arg15 : f32 loc(#loc76)
      tt.reduce.return %86 : f32 loc(#loc71)
    }) : (tensor<1024xf32>) -> f32 loc(#loc71)
    %57 = tt.splat %56 : f32 -> tensor<1xf32> loc(#loc73)
    %58 = arith.addf %57, %cst : tensor<1xf32> loc(#loc73)
    %59 = arith.divf %58, %cst_6 : tensor<1xf32> loc(#loc47)
    %60 = tt.broadcast %59 : tensor<1xf32> -> tensor<1024xf32> loc(#loc48)
    %61 = arith.subf %54, %60 : tensor<1024xf32> loc(#loc48)
    %62 = arith.mulf %61, %61 : tensor<1024xf32> loc(#loc49)
    %63 = arith.select %30, %62, %cst_10 : tensor<1024xi1>, tensor<1024xf32> loc(#loc50)
    %64 = "tt.reduce"(%63) <{axis = 0 : i32}> ({
    ^bb0(%arg14: f32 loc(unknown), %arg15: f32 loc(unknown)):
      %86 = arith.addf %arg14, %arg15 : f32 loc(#loc77)
      tt.reduce.return %86 : f32 loc(#loc74)
    }) : (tensor<1024xf32>) -> f32 loc(#loc74)
    %65 = tt.splat %64 : f32 -> tensor<1xf32> loc(#loc75)
    %66 = arith.addf %65, %cst : tensor<1xf32> loc(#loc75)
    %67 = arith.divf %66, %cst_6 : tensor<1xf32> loc(#loc53)
    %68 = arith.addf %67, %cst_5 : tensor<1xf32> loc(#loc54)
    %69 = tt.extern_elementwise %68 {libname = "", libpath = "", pure = true, symbol = "__nv_rsqrtf"} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc55)
    %70 = tt.broadcast %69 : tensor<1xf32> -> tensor<1024xf32> loc(#loc56)
    %71 = arith.mulf %61, %70 : tensor<1024xf32> loc(#loc56)
    %72 = arith.mulf %71, %16 : tensor<1024xf32> loc(#loc57)
    %73 = arith.addf %72, %19 : tensor<1024xf32> loc(#loc58)
    %74 = arith.mulf %69, %cst_4 : tensor<1xf32> loc(#loc59)
    %75 = arith.muli %1, %cst_7 : tensor<1xi32> loc(#loc60)
    %76 = tt.broadcast %75 : tensor<1xi32> -> tensor<1024xi32> loc(#loc61)
    %77 = arith.addi %3, %76 : tensor<1024xi32> loc(#loc61)
    %78 = tt.splat %arg8 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc62)
    %79 = tt.addptr %78, %77 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc62)
    tt.store %79, %54, %30 : tensor<1024x!tt.ptr<f32>> loc(#loc63)
    %80 = tt.splat %arg9 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc64)
    %81 = tt.addptr %80, %77 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc64)
    tt.store %81, %71, %30 : tensor<1024x!tt.ptr<f32>> loc(#loc65)
    %82 = tt.splat %arg10 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc66)
    %83 = tt.addptr %82, %77 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc66)
    tt.store %83, %73, %30 : tensor<1024x!tt.ptr<f32>> loc(#loc67)
    %84 = tt.splat %arg11 : !tt.ptr<f32> -> tensor<1x!tt.ptr<f32>> loc(#loc68)
    %85 = tt.addptr %84, %1 : tensor<1x!tt.ptr<f32>>, tensor<1xi32> loc(#loc68)
    tt.store %85, %74, %2 : tensor<1x!tt.ptr<f32>> loc(#loc69)
    tt.return loc(#loc70)
  } loc(#loc)
} loc(#loc)
#loc2 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":22:28)
#loc3 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":23:35)
#loc4 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":24:21)
#loc5 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":25:26)
#loc6 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":27:21)
#loc7 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":30:30)
#loc8 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":30:35)
#loc9 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":31:30)
#loc10 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":31:35)
#loc11 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":32:31)
#loc12 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":32:36)
#loc13 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":33:31)
#loc14 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":33:36)
#loc15 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":34:31)
#loc16 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":34:36)
#loc17 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":36:18)
#loc18 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":37:18)
#loc19 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":38:32)
#loc20 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":40:40)
#loc21 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":40:36)
#loc22 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":40:30)
#loc23 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":40:56)
#loc24 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":40:48)
#loc25 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":42:18)
#loc26 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":43:19)
#loc27 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":44:34)
#loc28 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":46:41)
#loc29 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":46:37)
#loc30 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":46:31)
#loc31 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":46:50)
#loc32 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":47:19)
#loc33 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":49:20)
#loc34 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":50:20)
#loc35 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":51:35)
#loc36 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":53:41)
#loc37 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":53:37)
#loc38 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":53:31)
#loc39 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":53:50)
#loc40 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":54:20)
#loc41 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":58:43)
#loc42 = loc("/home/yewei/code/git/triton_evofc/triton/python/triton/language/standard.py":267:36)
#loc43 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":59:59)
#loc44 = loc("/home/yewei/code/git/triton_evofc/triton/python/triton/language/standard.py":256:15)
#loc45 = loc("/home/yewei/miniconda3/envs/compile_triton_shared/lib/python3.10/site-packages/torch/_inductor/runtime/triton_helpers.py":40:15)
#loc46 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":59:45)
#loc47 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":62:20)
#loc48 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":63:20)
#loc49 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":64:20)
#loc50 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":66:43)
#loc51 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":67:59)
#loc52 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":67:45)
#loc53 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":70:20)
#loc54 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":72:20)
#loc55 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":73:28)
#loc56 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":74:20)
#loc57 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":75:20)
#loc58 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":76:20)
#loc59 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":78:20)
#loc60 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":79:35)
#loc61 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":79:31)
#loc62 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":79:25)
#loc63 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":79:48)
#loc64 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":80:25)
#loc65 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":80:48)
#loc66 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":81:25)
#loc67 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":81:48)
#loc68 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":82:25)
#loc69 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":82:37)
#loc70 = loc("/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py":82:4)
#loc71 = loc(callsite(#loc42 at #loc43))
#loc72 = loc(callsite(#loc44 at #loc42))
#loc73 = loc(callsite(#loc45 at #loc46))
#loc74 = loc(callsite(#loc42 at #loc51))
#loc75 = loc(callsite(#loc45 at #loc52))
#loc76 = loc(callsite(#loc72 at #loc43))
#loc77 = loc(callsite(#loc72 at #loc51))

Crash log

PtrAnalysis: encountered addptr operand produced by an unsupported operation
%33 = arith.cmpi slt, %12, %cst_8 : tensor<1xi32>
/home/yewei/code/git/triton_evofc/ev_test/model/nlp/bert128/test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py:40:30: remark: PtrAnalysis: Failed to rewrite AddPtrOp
    tmp6 = tl.load(in_ptr1 + (r1 + (768*tmp4)), rmask & xmask, other=0.0)
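
The failing pattern reduces to a plain gather (the "scatter tt.load" of the title): the offset fed to tt.addptr is computed from a value that was itself loaded from memory (tmp4), so it is not an affine function of the program id and ranges, which is what PtrAnalysis needs to rewrite the pointer. A minimal sketch of the same addressing pattern (hypothetical kernel and names, not taken from the report):

import triton
import triton.language as tl

@triton.jit
def gather_row_kernel(idx_ptr, table_ptr, out_ptr, D: tl.constexpr):
    # One program per output row; the table row to read is itself loaded
    # from memory, so the second tl.load below has a data-dependent address,
    # the same shape as tmp6 = tl.load(in_ptr1 + (r1 + (768*tmp4))) above.
    pid = tl.program_id(0)
    row = tl.load(idx_ptr + pid)                  # data-dependent index
    cols = tl.arange(0, D)
    vals = tl.load(table_ptr + (cols + D * row))  # gather: offset depends on `row`
    tl.store(out_ptr + (cols + D * pid), vals)

In the issue's kernel, `row` corresponds to tmp4/tmp11/tmp19 and `D` to 768.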

Additional information

pytest -s test_triton_per_fused_add_embedding_native_layer_norm_native_layer_norm_backward_0.py

@GoodNight-bye GoodNight-bye added the bug Something isn't working label Sep 20, 2024
@GoodNight-bye GoodNight-bye changed the title from "[Bug]: not support discrete tt.load for tmp6" to "[Bug]: not support scatter tt.load for tmp6" Sep 23, 2024