
[Bug]: CPU backend fails at the make-ttir or ttir-to-ttsharedir stage with some tests in triton-lang/kernels #185

Open
HuanyuCai opened this issue Oct 12, 2024 · 5 comments
Labels
bug Something isn't working

Comments


HuanyuCai commented Oct 12, 2024

Triton python code

It is a piece of code from https://github.com/triton-lang/kernels/blob/main/test/test_inductor.py.

It seems that almost none of the kernel tests in that repo can run on the CPU backend without crashing; sometimes the crash happens during ast-to-ttir, sometimes during ttir-to-ttsharedir.

import pytest
import torch

import triton
import triton.language as tl

def test_normalization_with_remat(device):
    @triton.jit
    def triton_(
        in_out_ptr0,
        in_out_ptr1,
        in_ptr0,
        in_ptr1,
        in_ptr2,
        in_ptr3,
        xnumel,
        rnumel,
        XBLOCK: tl.constexpr,
        RBLOCK: tl.constexpr,
    ):
        xnumel = 512
        rnumel = 4096
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rbase = tl.arange(0, RBLOCK)[None, :]
        x3 = xindex
        x0 = xindex % 64
        tmp1 = tl.load(in_ptr0 + (x0), xmask)
        tmp3 = tl.load(in_ptr1 + (x0), xmask)
        tmp11 = tl.load(in_ptr2 + (x0), xmask)
        tmp13 = tl.load(in_ptr3 + (x0), xmask)
        _tmp17 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
        for roffset in range(0, rnumel, RBLOCK):
            rindex = roffset + rbase
            rmask = rindex < rnumel
            r2 = rindex
            tmp0 = tl.load(
                in_out_ptr0 + (r2 + (4096 * x3)),
                rmask & xmask,
                eviction_policy="evict_last",
                other=0,
            )
            tmp2 = tmp0 - tmp1
            tmp4 = 1e-05
            tmp5 = tmp3 + tmp4
            tmp6 = tl.sqrt(tmp5)
            tmp7 = 1 / tmp6
            tmp8 = 1.0
            tmp9 = tmp7 * tmp8
            tmp10 = tmp2 * tmp9
            tmp12 = tmp10 * tmp11
            tmp14 = tmp12 + tmp13
            _tmp17 = tl.where(rmask & xmask, _tmp17 + tmp14, _tmp17)
            tl.store(
                in_out_ptr0 + (r2 + (4096 * x3) + tl.zeros([XBLOCK, RBLOCK], tl.int32)),
                tmp14,
                rmask & xmask,
            )
        tmp17 = tl.sum(_tmp17, 1)[:, None]
        tmp18 = 4096.0
        tmp19 = tmp17 / tmp18
        tl.store(in_out_ptr1 + (x3 + tl.zeros([XBLOCK, 1], tl.int32)), tmp19, xmask)

    torch.manual_seed(123)

    buf14 = torch.rand(8, 64, 64, 64, device=device)
    buf16 = torch.rand(8, 1, 64, device=device)
    arg114_1 = torch.rand(64, device=device)
    arg115_1 = torch.rand(64, device=device)
    arg8_1 = torch.rand(64, device=device)
    arg9_1 = torch.rand(64, device=device)
    triton_[(512,)](
        buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048
    )
    torch.testing.assert_close(
        buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0
    )

Triton IR

Part of the TTIR around the tt.load:

%9 = tt.addptr %arg3, %5 : !tt.ptr, i32 loc(#loc9)
%10 = tt.splat %9 : !tt.ptr -> tensor<1x1x!tt.ptr> loc(#loc9)
%11 = tt.load %10, %2 : tensor<1x1x!tt.ptr> loc(#loc10)
%12 = tt.addptr %arg4, %5 : !tt.ptr, i32 loc(#loc11)

Crash log

error: PtrAnalysis: pointer is not replace with tts.make_tptr so loadOp cannot be rewritten
It seems that tt.splat has no rewriteSplatOp implementation that would call createTTSMakeTensorPtrOp.
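
For context, here is a minimal sketch of the pattern that appears to trigger this (a hypothetical reduction of the kernel above, with the kernel name and bounds invented for illustration; not a verified reproducer): a 2D index of shape [XBLOCK, 1] causes the scalar pointer to be splatted into a pointer tensor before the masked load, which matches the tt.splat feeding the tt.load in the TTIR above.

import triton
import triton.language as tl

@triton.jit
def splat_load(in_ptr, out_ptr, XBLOCK: tl.constexpr):
    # 2D index of shape [XBLOCK, 1], as in the reported kernel
    xindex = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < 512
    x0 = xindex % 64
    # in_ptr + x0 splats the scalar pointer to a [XBLOCK, 1] pointer tensor;
    # the masked tt.load on that tensor is what PtrAnalysis fails to rewrite
    val = tl.load(in_ptr + x0, xmask)
    tl.store(out_ptr + x0, val, xmask)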

Additional information

No response

HuanyuCai added the bug label on Oct 12, 2024
@nhat-nguyen (Collaborator)

Thanks so much for reporting the issue. torch-inductor-generated code is definitely an area where we currently have a lot of trouble due to the non-structured pointer access patterns. We have a planned feature that will lower all of these non-structured accesses to gather/scatter, which should hopefully solve most of these problems. I'll update the issue as soon as we have any meaningful progress.
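
For illustration, a minimal sketch of the distinction (a hypothetical example, not taken from the issue): PtrAnalysis can recover accesses whose offsets are affine in tl.program_id and tl.arange, while offsets that are themselves loaded from memory are the non-structured case that needs the planned gather/scatter lowering.

import triton
import triton.language as tl

@triton.jit
def structured(src, dst, BLOCK: tl.constexpr):
    # offsets are an affine function of program_id/arange: analyzable
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    tl.store(dst + offs, tl.load(src + offs))

@triton.jit
def unstructured(src, idx, dst, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # offsets come from memory: a data-dependent gather
    gather_offs = tl.load(idx + offs)
    tl.store(dst + offs, tl.load(src + gather_offs))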


yonucy commented Nov 14, 2024

In order to temporarily fix the MaskAnalysis failure, I added the following to the parseSplat method in MaskAnalysis.cpp:

  if (src.getType().isInteger(1)) {
      Value indexValue = builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), src);
      for (auto s : dstShape) {
          Value newConstantOp = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(s));
          // Use MulIOp to perform the multiplication
          auto mulOp = builder.create<arith::MulIOp>(loc, newConstantOp, indexValue);
          this->dims.push_back(mulOp.getResult());
      }
      return success();
  }

After this addition, the full parseSplat method looks like this:

LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, const Location loc,
                                    OpBuilder &builder) {
  assert(this->isEmpty());

  auto src = splatOp.getSrc();
  auto dst = splatOp.getResult();
  auto dstShape = cast<ShapedType>(dst.getType()).getShape();

  if (!isa<IntegerType>(src.getType())) {
    InFlightDiagnostic diag =
        emitError(loc)
        << "splat source must be an integer scalar for load/store masks";
    return failure();
  }

  if (src.getType().isInteger(1)) {
      Value indexValue = builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), src);
      for (auto s : dstShape) {
          Value newConstantOp = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(s));
          // Use MulIOp to perform the multiplication
          auto mulOp = builder.create<arith::MulIOp>(loc, newConstantOp, indexValue);
          this->dims.push_back(mulOp.getResult());
      }
      return success();
  }

  if (failed(this->parse(src, loc, builder)))
    return failure();

  for (auto s : dstShape)
    this->dims.push_back(builder.getIndexAttr(s));

  return success();
}

I understand that this fixes the error, but there is a problem: although the ttir-to-ttsharedir conversion now succeeds, running python test_inductor.py ends in a segmentation fault with no other output, only the words "segmentation fault".

However, if I add the following code instead, the test runs successfully and produces the correct results:

  if (src.getType().isInteger(1)) {
      for (auto s : dstShape) {
           this->dims.push_back(builder.getIndexAttr(s));
      }
      return success();
  }

I don't know why this happens.

@parsifal-47 (Contributor)

Do you have stack traces enabled? If you have published your branch, I can run it on my setup to see whether I can get more information.
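
For example (a general CPython debugging technique, not specific to this repo), the standard-library faulthandler module can at least show the Python-level location of a native crash:

# at the top of test_inductor.py, or equivalently run:
#   python -X faulthandler test_inductor.py
import faulthandler
faulthandler.enable()  # dumps the Python traceback on SIGSEGV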


yonucy commented Nov 15, 2024

dda47e3

HuanyuCai (Author) commented Dec 11, 2024

dda47e3

Your code handles splat i1 and splat !tt.ptr<> well, as shown below:

%10 = tt.splat %5 : i1 -> tensor<1x1024xi1> loc(#loc10)
%11 = tt.splat %2 : !tt.ptr -> tensor<1x1024x!tt.ptr> loc(#loc11)
