nixpkgs/pkgs/development/python-modules/openai-triton/llvm15.patch
2023-04-08 02:46:54 +03:00

4618 lines
236 KiB
Diff

From fc7a8e35819bda632bdcf1cf75fd9abe4d4e077a Mon Sep 17 00:00:00 2001
From: Christian Sigg <chsigg@users.noreply.github.com>
Date: Thu, 16 Feb 2023 15:40:53 +0100
Subject: [PATCH] Rebase Triton to LLVM-15. (#1070)
This PR rebases Triton from LLVM-14 to LLVM-15. Most changes are
mechanical, except for the analysis framework changes.
---
CMakeLists.txt | 6 +-
bin/CMakeLists.txt | 2 +-
bin/FileCheck/FileCheck.cpp | 3 +
bin/triton-opt.cpp | 6 +-
bin/triton-translate.cpp | 7 +-
include/triton/Analysis/Alias.h | 21 +-
include/triton/Analysis/Allocation.h | 2 +
include/triton/Analysis/AxisInfo.h | 56 ++-
include/triton/Analysis/Utility.h | 6 +-
include/triton/Conversion/Passes.td | 4 +-
include/triton/Dialect/Triton/IR/Dialect.h | 7 +-
.../triton/Dialect/Triton/IR/TritonDialect.td | 8 +-
include/triton/Dialect/Triton/IR/TritonOps.td | 12 +-
.../triton/Dialect/Triton/IR/TritonTypes.td | 2 +
.../Dialect/Triton/Transforms/Passes.td | 3 +-
include/triton/Dialect/TritonGPU/IR/Dialect.h | 4 +-
.../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 7 +
.../Dialect/TritonGPU/IR/TritonGPUDialect.td | 2 +-
.../Dialect/TritonGPU/IR/TritonGPUOps.td | 13 +-
lib/Analysis/Alias.cpp | 14 +-
lib/Analysis/Allocation.cpp | 30 +-
lib/Analysis/AxisInfo.cpp | 79 ++--
lib/Analysis/CMakeLists.txt | 2 +-
lib/Analysis/Membar.cpp | 2 +-
lib/Analysis/Utility.cpp | 54 +++
.../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 3 -
lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h | 10 +-
.../TritonGPUToLLVM/DotOpToLLVM.cpp | 5 -
.../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 2 -
.../TritonGPUToLLVM/LoadStoreOpToLLVM.cpp | 5 +-
.../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 2 -
.../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 7 +-
.../TritonGPUToLLVM/TritonGPUToLLVMBase.h | 26 +-
.../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 52 +--
lib/Conversion/TritonGPUToLLVM/Utility.h | 5 +-
.../TritonToTritonGPUPass.cpp | 69 ++--
lib/Dialect/Triton/IR/CMakeLists.txt | 10 +-
lib/Dialect/Triton/IR/Ops.cpp | 34 +-
lib/Dialect/Triton/Transforms/Combine.cpp | 6 +-
lib/Dialect/Triton/Transforms/Combine.td | 2 +-
lib/Dialect/TritonGPU/IR/Dialect.cpp | 27 +-
lib/Dialect/TritonGPU/Transforms/Coalesce.cpp | 20 +-
lib/Dialect/TritonGPU/Transforms/Combine.cpp | 2 +-
lib/Dialect/TritonGPU/Transforms/Combine.td | 1 +
.../Transforms/DecomposeConversions.cpp | 2 +-
lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 10 +-
.../Transforms/ReorderInstructions.cpp | 2 +-
.../Transforms/TritonGPUConversion.cpp | 12 +-
.../Transforms/UpdateMmaForVolta.cpp | 6 +-
lib/Dialect/TritonGPU/Transforms/Utility.cpp | 2 +-
lib/Target/LLVMIR/CMakeLists.txt | 3 +-
lib/Target/PTX/PTXTranslation.cpp | 3 +
python/setup.py | 15 +-
python/src/triton.cc | 85 +++--
python/test/unit/language/test_core.py | 2 +-
python/triton/compiler.py | 4 +-
test/Analysis/test-alias.mlir | 24 +-
test/Analysis/test-alignment.mlir | 344 +++++++++---------
test/Analysis/test-allocation.mlir | 32 +-
test/Analysis/test-membar.mlir | 38 +-
test/Conversion/triton_ops.mlir | 10 +-
test/Conversion/triton_to_tritongpu.mlir | 6 +-
test/Conversion/tritongpu_to_llvm.mlir | 94 ++---
test/Target/tritongpu_to_llvmir.mlir | 4 +-
test/Target/tritongpu_to_ptx.mlir | 2 +-
test/Triton/combine.mlir | 40 +-
test/Triton/vecadd.mlir | 4 +-
test/TritonGPU/coalesce.mlir | 2 +-
test/TritonGPU/combine.mlir | 38 +-
test/TritonGPU/loop-pipeline.mlir | 22 +-
test/TritonGPU/matmul.mlir | 4 +-
test/TritonGPU/prefetch.mlir | 4 +-
test/TritonGPU/update-mma-for-volta.mlir | 4 +-
test/lib/Analysis/TestAlias.cpp | 29 +-
test/lib/Analysis/TestAllocation.cpp | 5 +-
test/lib/Analysis/TestAxisInfo.cpp | 51 +--
test/lib/Analysis/TestMembar.cpp | 7 +-
78 files changed, 808 insertions(+), 742 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d0d361fc7c..b281a28400 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,7 @@
cmake_minimum_required(VERSION 3.6)
+
+cmake_policy(SET CMP0116 OLD)
+
include(ExternalProject)
set(CMAKE_CXX_STANDARD 17)
@@ -155,7 +158,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
endif()
endif()
-
# # Triton
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
# if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
@@ -212,7 +214,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
# optimizations
MLIRPass
MLIRTransforms
- MLIRLLVMIR
+ MLIRLLVMDialect
MLIRSupport
MLIRTargetLLVMIRExport
MLIRExecutionEngine
diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt
index 906f635f8b..695b3479fd 100644
--- a/bin/CMakeLists.txt
+++ b/bin/CMakeLists.txt
@@ -48,7 +48,7 @@ llvm_update_compile_flags(triton-translate)
# MLIR core
MLIROptLib
MLIRIR
- MLIRLLVMIR
+ MLIRLLVMDialect
MLIRPass
MLIRSupport
MLIRTransforms
diff --git a/bin/FileCheck/FileCheck.cpp b/bin/FileCheck/FileCheck.cpp
index 819efc3541..9ac6f1b277 100644
--- a/bin/FileCheck/FileCheck.cpp
+++ b/bin/FileCheck/FileCheck.cpp
@@ -19,6 +19,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Process.h"
+#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>
@@ -360,6 +361,8 @@ static std::string GetCheckTypeAbbreviation(Check::FileCheckType Ty) {
return "bad-not";
case Check::CheckBadCount:
return "bad-count";
+ case Check::CheckMisspelled:
+ return "misspelled";
case Check::CheckNone:
llvm_unreachable("invalid FileCheckType");
}
diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp
index 9f3b53b7ae..f96232e1b0 100644
--- a/bin/triton-opt.cpp
+++ b/bin/triton-opt.cpp
@@ -8,7 +8,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/InitAllPasses.h"
-#include "mlir/Support/MlirOptMain.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
namespace mlir {
namespace test {
@@ -33,8 +33,8 @@ int main(int argc, char **argv) {
// TODO: register Triton & TritonGPU passes
mlir::DialectRegistry registry;
registry.insert<mlir::triton::TritonDialect,
- mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
- mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
+ mlir::triton::gpu::TritonGPUDialect, mlir::func::FuncDialect,
+ mlir::math::MathDialect, mlir::arith::ArithmeticDialect,
mlir::scf::SCFDialect, mlir::gpu::GPUDialect>();
return mlir::asMainReturnCode(mlir::MlirOptMain(
diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp
index 05ba15e453..56b5d65857 100644
--- a/bin/triton-translate.cpp
+++ b/bin/triton-translate.cpp
@@ -3,7 +3,7 @@
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
-#include "mlir/Parser.h"
+#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
@@ -38,7 +38,7 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
mlir::DialectRegistry registry;
registry.insert<TritonDialect, triton::gpu::TritonGPUDialect,
mlir::math::MathDialect, arith::ArithmeticDialect,
- StandardOpsDialect, scf::SCFDialect>();
+ scf::SCFDialect>();
context.appendDialectRegistry(registry);
@@ -50,7 +50,8 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
context.loadAllAvailableDialects();
context.allowUnregisteredDialects();
- OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context));
+ OwningOpRef<ModuleOp> module =
+ parseSourceFile<ModuleOp>(sourceMgr, &context);
if (!module) {
llvm::errs() << "Parse MLIR file failed.";
return nullptr;
diff --git a/include/triton/Analysis/Alias.h b/include/triton/Analysis/Alias.h
index fa6b906fc9..631df518bc 100644
--- a/include/triton/Analysis/Alias.h
+++ b/include/triton/Analysis/Alias.h
@@ -2,7 +2,7 @@
#define TRITON_ANALYSIS_ALIAS_H
#include "mlir/Analysis/AliasAnalysis.h"
-#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "llvm/ADT/DenseSet.h"
namespace mlir {
@@ -21,7 +21,7 @@ class AliasInfo {
}
/// The pessimistic value state of a value without alias
- static AliasInfo getPessimisticValueState(MLIRContext *context) {
+ static AliasInfo getPessimisticValueState(MLIRContext *context = nullptr) {
return AliasInfo();
}
static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); }
@@ -29,6 +29,10 @@ class AliasInfo {
/// The union of both arguments
static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs);
+ void print(raw_ostream &os) const {
+ llvm::interleaveComma(allocs, os, [&](Value alloc) { alloc.print(os); });
+ }
+
private:
/// The set of allocated values that are aliased by this lattice.
/// For now, we only consider aliased value produced by the following
@@ -58,9 +62,13 @@ class AliasInfo {
//===----------------------------------------------------------------------===//
// Shared Memory Alias Analysis
//===----------------------------------------------------------------------===//
-class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> {
+class SharedMemoryAliasAnalysis
+ : public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AliasInfo>> {
public:
- using ForwardDataFlowAnalysis<AliasInfo>::ForwardDataFlowAnalysis;
+ using dataflow::SparseDataFlowAnalysis<
+ dataflow::Lattice<AliasInfo>>::SparseDataFlowAnalysis;
+ using dataflow::SparseDataFlowAnalysis<
+ dataflow::Lattice<AliasInfo>>::getLatticeElement;
/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.
/// Given two values, returns their aliasing behavior.
@@ -70,9 +78,10 @@ class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis<AliasInfo> {
ModRefResult getModRef(Operation *op, Value location);
/// Computes if the alloc set of the results are changed.
- ChangeResult
+ void
visitOperation(Operation *op,
- ArrayRef<LatticeElement<AliasInfo> *> operands) override;
+ ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
+ ArrayRef<dataflow::Lattice<AliasInfo> *> results) override;
};
} // namespace mlir
diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h
index b7c136d602..89b77034cc 100644
--- a/include/triton/Analysis/Allocation.h
+++ b/include/triton/Analysis/Allocation.h
@@ -188,6 +188,8 @@ class Allocation {
friend class triton::AllocationAnalysis;
};
+template <typename T> Interval(T, T) -> Interval<T>;
+
} // namespace mlir
#endif // TRITON_ANALYSIS_ALLOCATION_H
diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h
index fdfbd8fbb3..7083b9c43b 100644
--- a/include/triton/Analysis/AxisInfo.h
+++ b/include/triton/Analysis/AxisInfo.h
@@ -1,9 +1,10 @@
#ifndef TRITON_ANALYSIS_AXISINFO_H
#define TRITON_ANALYSIS_AXISINFO_H
-#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "llvm/Support/raw_ostream.h"
+#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -62,7 +63,7 @@ class AxisInfo {
}
/// The pessimistic value state of the contiguity is unknown.
- static AxisInfo getPessimisticValueState(MLIRContext *context) {
+ static AxisInfo getPessimisticValueState(MLIRContext *context = nullptr) {
return AxisInfo();
}
static AxisInfo getPessimisticValueState(Value value);
@@ -70,6 +71,22 @@ class AxisInfo {
/// The gcd of both arguments for each dimension
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
+ void print(raw_ostream &os) const {
+ auto print = [&](StringRef name, DimVectorT vec) {
+ os << name << " = [";
+ llvm::interleaveComma(vec, os);
+ os << "]";
+ };
+ print("contiguity", contiguity);
+ print(", divisibility", divisibility);
+ print(", constancy", constancy);
+ os << ", constant_value = ";
+ if (constantValue)
+ os << *constantValue;
+ else
+ os << "<none>";
+ }
+
private:
/// The _contiguity_ information maps the `d`-th
/// dimension to the length of the shortest
@@ -147,7 +164,8 @@ class AxisInfoVisitor {
}
virtual AxisInfo
- getAxisInfo(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) = 0;
+ getAxisInfo(Operation *op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) = 0;
virtual bool match(Operation *op) = 0;
};
@@ -157,15 +175,16 @@ template <typename OpTy> class AxisInfoVisitorImpl : public AxisInfoVisitor {
public:
using AxisInfoVisitor::AxisInfoVisitor;
- AxisInfo getAxisInfo(Operation *op,
- ArrayRef<LatticeElement<AxisInfo> *> operands) final {
+ AxisInfo
+ getAxisInfo(Operation *op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) final {
return getAxisInfo(cast<OpTy>(op), operands);
}
bool match(Operation *op) final { return isa<OpTy>(op); }
- virtual AxisInfo getAxisInfo(OpTy op,
- ArrayRef<LatticeElement<AxisInfo> *> operands) {
+ virtual AxisInfo
+ getAxisInfo(OpTy op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
llvm_unreachable("Unimplemented getAxisInfo");
}
};
@@ -176,8 +195,9 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
public:
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
- AxisInfo getAxisInfo(OpTy op,
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+ AxisInfo
+ getAxisInfo(OpTy op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto lhsInfo = operands[0]->getValue();
auto rhsInfo = operands[1]->getValue();
auto rank = lhsInfo.getRank();
@@ -230,7 +250,8 @@ class AxisInfoVisitorList {
(visitors.emplace_back(std::make_unique<Ts>()), ...);
}
- AxisInfo apply(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
+ AxisInfo apply(Operation *op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
for (auto &visitor : visitors)
if (visitor->match(op))
return visitor->getAxisInfo(op, operands);
@@ -241,16 +262,19 @@ class AxisInfoVisitorList {
std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
};
-class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
+class AxisInfoAnalysis
+ : public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>> {
private:
AxisInfoVisitorList visitors;
public:
- AxisInfoAnalysis(MLIRContext *context);
+ AxisInfoAnalysis(DataFlowSolver &solver);
+ using dataflow::SparseDataFlowAnalysis<
+ dataflow::Lattice<AxisInfo>>::getLatticeElement;
- ChangeResult
- visitOperation(Operation *op,
- ArrayRef<LatticeElement<AxisInfo> *> operands) override;
+ void visitOperation(Operation *op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
+ ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
unsigned getPtrContiguity(Value ptr);
@@ -261,4 +285,4 @@ class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
} // namespace mlir
-#endif
\ No newline at end of file
+#endif
diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index c5ac137dc1..ee7fadb59d 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -1,6 +1,7 @@
#ifndef TRITON_ANALYSIS_UTILITY_H
#define TRITON_ANALYSIS_UTILITY_H
+#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <algorithm>
@@ -12,7 +13,7 @@ namespace mlir {
class ReduceOpHelper {
public:
explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
- srcTy = op.operand().getType().cast<RankedTensorType>();
+ srcTy = op.getOperand().getType().cast<RankedTensorType>();
}
ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
@@ -103,6 +104,9 @@ SetVector<Operation *>
multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr,
TransitiveFilter forwardFilter = nullptr);
+// Create a basic DataFlowSolver with constant and dead code analysis included.
+std::unique_ptr<DataFlowSolver> createDataFlowSolver();
+
} // namespace mlir
#endif // TRITON_ANALYSIS_UTILITY_H
diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td
index 70bb20b78e..be00eb2dac 100644
--- a/include/triton/Conversion/Passes.td
+++ b/include/triton/Conversion/Passes.td
@@ -12,7 +12,6 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
let dependentDialects = ["mlir::arith::ArithmeticDialect",
"mlir::math::MathDialect",
- "mlir::StandardOpsDialect",
// TODO: Does this pass depend on SCF?
"mlir::scf::SCFDialect",
"mlir::triton::TritonDialect",
@@ -41,8 +40,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
"mlir::tensor::TensorDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect",
- "mlir::NVVM::NVVMDialect",
- "mlir::StandardOpsDialect"];
+ "mlir::NVVM::NVVMDialect"];
let options = [
Option<"computeCapability", "compute-capability",
diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h
index e8012a51df..15869e262e 100644
--- a/include/triton/Dialect/Triton/IR/Dialect.h
+++ b/include/triton/Dialect/Triton/IR/Dialect.h
@@ -1,14 +1,15 @@
#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
#define TRITON_DIALECT_TRITON_IR_DIALECT_H_
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
-
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
#include "triton/Dialect/Triton/IR/Traits.h"
diff --git a/include/triton/Dialect/Triton/IR/TritonDialect.td b/include/triton/Dialect/Triton/IR/TritonDialect.td
index 07b069e14f..d98ce73884 100644
--- a/include/triton/Dialect/Triton/IR/TritonDialect.td
+++ b/include/triton/Dialect/Triton/IR/TritonDialect.td
@@ -25,12 +25,9 @@ def Triton_Dialect : Dialect {
let dependentDialects = [
"arith::ArithmeticDialect",
"math::MathDialect",
- "StandardOpsDialect",
"scf::SCFDialect",
-
- // Since LLVM 15
- // "cf::ControlFlowDialect",
- // "func::FuncDialect"
+ "cf::ControlFlowDialect",
+ "func::FuncDialect"
];
let extraClassDeclaration = [{
@@ -38,6 +35,7 @@ def Triton_Dialect : Dialect {
}];
let hasConstantMaterializer = 1;
+ let useDefaultTypePrinterParser = 1;
}
include "triton/Dialect/Triton/IR/TritonTypes.td"
diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index 779e0b648c..0a69211179 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -141,11 +141,7 @@ def TT_LoadOp : TT_Op<"load",
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
];
- // let assemblyFormat = "operands attr-dict `:` type($result)";
- let parser = [{ return mlir::triton::parseLoadOp(parser, result); }];
-
- let printer = [{ return mlir::triton::printLoadOp(p, *this); }];
-
+ let hasCustomAssemblyFormat = 1;
let hasCanonicalizer = 1;
}
@@ -170,11 +166,7 @@ def TT_StoreOp : TT_Op<"store",
"triton::EvictionPolicy":$evict)>,
];
- // let assemblyFormat = "operands attr-dict `:` type($value)";
- let parser = [{ return mlir::triton::parseStoreOp(parser, result); }];
-
- let printer = [{ return mlir::triton::printStoreOp(p, *this); }];
-
+ let hasCustomAssemblyFormat = 1;
let hasCanonicalizer = 1;
}
diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td
index 66d2a7b9a9..2fe2fd077d 100644
--- a/include/triton/Dialect/Triton/IR/TritonTypes.td
+++ b/include/triton/Dialect/Triton/IR/TritonTypes.td
@@ -1,6 +1,7 @@
#ifndef TRITON_TYPES
#define TRITON_TYPES
+include "mlir/IR/AttrTypeBase.td"
include "triton/Dialect/Triton/IR/TritonDialect.td"
//
@@ -58,6 +59,7 @@ def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
}]>
];
+ let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td
index 8f77aed774..a25cdc5680 100644
--- a/include/triton/Dialect/Triton/Transforms/Passes.td
+++ b/include/triton/Dialect/Triton/Transforms/Passes.td
@@ -16,8 +16,7 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
let constructor = "mlir::triton::createCombineOpsPass()";
- let dependentDialects = ["mlir::arith::ArithmeticDialect",
- /*SelectOp*/"mlir::StandardOpsDialect"];
+ let dependentDialects = ["mlir::arith::ArithmeticDialect"];
}
#endif
diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index b4c8daec7b..dfc5f53ab1 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -1,19 +1,17 @@
#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
-#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
// TritonGPU depends on Triton
#include "triton/Dialect/Triton/IR/Dialect.h"
-
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonGPU/IR/Traits.h"
#define GET_ATTRDEF_CLASSES
-#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc"
#define GET_OP_CLASSES
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index 0242c3cc17..af2aeb03a8 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -1,6 +1,7 @@
#ifndef TRITONGPU_ATTRDEFS
#define TRITONGPU_ATTRDEFS
+include "mlir/IR/AttrTypeBase.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
@@ -136,6 +137,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
];
let extraClassDeclaration = extraBaseClassDeclaration;
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -273,6 +275,7 @@ for
// ArrayRefParameter<"unsigned">:$sizePerCTA
);
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -422,6 +425,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
static constexpr int numBitsToHoldMmaV1ID{5};
}];
+ let hasCustomAssemblyFormat = 1;
}
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
@@ -456,6 +460,8 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
template<class T>
SmallVector<T> paddedShape(ArrayRef<T> shape) const;
}];
+
+ let hasCustomAssemblyFormat = 1;
}
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
@@ -492,6 +498,7 @@ section 9.7.13.4.1 for more details.
];
+ let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = extraBaseClassDeclaration;
}
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
index 87ec1d36c6..6489a721b4 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
@@ -30,7 +30,7 @@ def TritonGPU_Dialect : Dialect {
}
}];
-
+ let useDefaultAttributePrinterParser = 1;
}
#endif
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index 510f8d0183..7aba11dc75 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -59,7 +59,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
// This is needed because these ops don't
// handle encodings
// e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
-def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
+def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "integer comparison operation";
@@ -73,7 +73,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
let results = (outs TT_BoolLike:$result);
}
-def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
+def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "floating-point comparison operation";
@@ -88,8 +88,8 @@ def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
}
// TODO: migrate to arith::SelectOp on LLVM16
-def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
- SameOperandsAndResultShape,
+def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
+ SameOperandsAndResultShape,
SameOperandsAndResultEncoding]> {
let summary = "select operation";
@@ -188,10 +188,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
}
}];
- // The custom parser could be replaced with oilist in LLVM-16
- let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
-
- let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
+ let hasCustomAssemblyFormat = 1;
}
def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>, // Allocate shared memory
diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp
index a39e4de9aa..208fdd4afc 100644
--- a/lib/Analysis/Alias.cpp
+++ b/lib/Analysis/Alias.cpp
@@ -18,8 +18,9 @@ AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) {
return ret;
}
-ChangeResult SharedMemoryAliasAnalysis::visitOperation(
- Operation *op, ArrayRef<LatticeElement<AliasInfo> *> operands) {
+void SharedMemoryAliasAnalysis::visitOperation(
+ Operation *op, ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
+ ArrayRef<dataflow::Lattice<AliasInfo> *> results) {
AliasInfo aliasInfo;
bool pessimistic = true;
if (maybeSharedAllocationOp(op)) {
@@ -44,14 +45,11 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
}
if (pessimistic) {
- return markAllPessimisticFixpoint(op->getResults());
+ return markAllPessimisticFixpoint(results);
}
// Join all lattice elements
- ChangeResult result = ChangeResult::NoChange;
- for (Value value : op->getResults()) {
- result |= getLatticeElement(value).join(aliasInfo);
- }
- return result;
+ for (auto *result : results)
+ propagateIfChanged(result, result->join(aliasInfo));
}
AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 712c08c475..b4de8dcd9d 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -1,4 +1,5 @@
#include "triton/Analysis/Allocation.h"
+#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -33,10 +34,8 @@ constexpr int kPtrBitWidth = 64;
static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
- auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
auto srcDotLayout = srcLayout.dyn_cast<DotOperandEncodingAttr>();
- auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
auto dstDotLayout = dstLayout.dyn_cast<DotOperandEncodingAttr>();
assert(!(srcMmaLayout && dstMmaLayout) &&
@@ -224,14 +223,12 @@ class AllocationAnalysis {
}
void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
- LatticeElement<AliasInfo> *latticeElement =
- analysis.lookupLatticeElement(value);
- if (latticeElement) {
- auto &info = latticeElement->getValue();
- if (!info.getAllocs().empty()) {
- for (auto alloc : info.getAllocs()) {
- allocation->addAlias(value, alloc);
- }
+ dataflow::Lattice<AliasInfo> *latticeElement =
+ analysis.getLatticeElement(value);
+ if (latticeElement && !latticeElement->isUninitialized()) {
+ AliasInfo &info = latticeElement->getValue();
+ for (auto alloc : info.getAllocs()) {
+ allocation->addAlias(value, alloc);
}
}
}
@@ -244,14 +241,19 @@ class AllocationAnalysis {
getScratchValueSize(op);
});
// Get the alias values
- SharedMemoryAliasAnalysis aliasAnalysis(operation->getContext());
- aliasAnalysis.run(operation);
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
+ SharedMemoryAliasAnalysis *aliasAnalysis =
+ solver->load<SharedMemoryAliasAnalysis>();
+ if (failed(solver->initializeAndRun(operation))) {
+ // TODO: return an error instead of bailing out.
+ llvm_unreachable("failed to run SharedMemoryAliasAnalysis");
+ }
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
for (auto operand : op->getOperands()) {
- getValueAlias(operand, aliasAnalysis);
+ getValueAlias(operand, *aliasAnalysis);
}
for (auto value : op->getResults()) {
- getValueAlias(value, aliasAnalysis);
+ getValueAlias(value, *aliasAnalysis);
}
});
}
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index 0b7142b04d..4af46c3fbb 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -1,4 +1,4 @@
-#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/Support/raw_ostream.h"
@@ -52,7 +52,7 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
BlockArgument blockArg = value.dyn_cast<BlockArgument>();
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
Operation *op = blockArg.getOwner()->getParentOp();
- if (FuncOp fun = dyn_cast<FuncOp>(op)) {
+ if (func::FuncOp fun = dyn_cast<func::FuncOp>(op)) {
Attribute attr =
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
if (attr)
@@ -136,8 +136,9 @@ class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
- AxisInfo getAxisInfo(OpTy op,
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+ AxisInfo
+ getAxisInfo(OpTy op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
return operands[0]->getValue();
}
};
@@ -147,8 +148,9 @@ class MakeRangeOpAxisInfoVisitor final
public:
using AxisInfoVisitorImpl<triton::MakeRangeOp>::AxisInfoVisitorImpl;
- AxisInfo getAxisInfo(triton::MakeRangeOp op,
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+ AxisInfo
+ getAxisInfo(triton::MakeRangeOp op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto start = op.start();
auto end = op.end();
return AxisInfo(/*contiguity=*/{end - start},
@@ -162,8 +164,9 @@ class ConstantOpAxisInfoVisitor final
public:
using AxisInfoVisitorImpl<arith::ConstantOp>::AxisInfoVisitorImpl;
- AxisInfo getAxisInfo(arith::ConstantOp op,
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+ AxisInfo
+ getAxisInfo(arith::ConstantOp op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto intAttr = op.getValue().dyn_cast<IntegerAttr>();
auto boolAttr = op.getValue().dyn_cast<BoolAttr>();
if (intAttr || boolAttr) {
@@ -416,8 +419,9 @@ class SplatOpAxisInfoVisitor final
public:
using AxisInfoVisitorImpl<triton::SplatOp>::AxisInfoVisitorImpl;
- AxisInfo getAxisInfo(triton::SplatOp op,
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+ AxisInfo
+ getAxisInfo(triton::SplatOp op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
Type _retTy = *op->result_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
AxisInfo opInfo = operands[0]->getValue();
@@ -439,8 +443,9 @@ class ExpandDimsOpAxisInfoVisitor final
public:
using AxisInfoVisitorImpl<triton::ExpandDimsOp>::AxisInfoVisitorImpl;
- AxisInfo getAxisInfo(triton::ExpandDimsOp op,
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+ AxisInfo
+ getAxisInfo(triton::ExpandDimsOp op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
AxisInfo opInfo = operands[0]->getValue();
AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
@@ -458,8 +463,9 @@ class BroadcastOpAxisInfoVisitor final
public:
using AxisInfoVisitorImpl<triton::BroadcastOp>::AxisInfoVisitorImpl;
- AxisInfo getAxisInfo(triton::BroadcastOp op,
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+ AxisInfo
+ getAxisInfo(triton::BroadcastOp op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
Type _retTy = *op->result_type_begin();
Type _opTy = *op->operand_type_begin();
TensorType retTy = _retTy.cast<TensorType>();
@@ -486,8 +492,9 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
- AxisInfo getAxisInfo(OpTy op,
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+ AxisInfo
+ getAxisInfo(OpTy op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
if (!resTy)
return AxisInfo();
@@ -596,8 +603,9 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
- AxisInfo getAxisInfo(OpTy op,
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+ AxisInfo
+ getAxisInfo(OpTy op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
if (!resTy)
return AxisInfo();
@@ -757,8 +765,9 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
- AxisInfo getAxisInfo(OpTy op,
- ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+ AxisInfo
+ getAxisInfo(OpTy op,
+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto lhsInfo = operands[0]->getValue();
auto rhsInfo = operands[1]->getValue();
std::optional<int64_t> constantValue;
@@ -786,8 +795,8 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
// AxisInfoAnalysis
//===----------------------------------------------------------------------===//
-AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context)
- : ForwardDataFlowAnalysis<AxisInfo>(context) {
+AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
+ : dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) {
// UnrealizedConversionCast:
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
// in the process of a PartialConversion, where UnrealizedConversionCast
@@ -819,7 +828,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context)
visitors.append<LogicalOpAxisInfoVisitor<arith::AndIOp>,
LogicalOpAxisInfoVisitor<arith::OrIOp>,
LogicalOpAxisInfoVisitor<arith::XOrIOp>>();
- visitors.append<SelectOpAxisInfoVisitor<mlir::SelectOp>,
+ visitors.append<SelectOpAxisInfoVisitor<mlir::arith::SelectOp>,
SelectOpAxisInfoVisitor<triton::gpu::SelectOp>>();
visitors.append<ShLIOpAxisInfoVisitor, ShROpAxisInfoVisitor<arith::ShRUIOp>,
ShROpAxisInfoVisitor<arith::ShRSIOp>>();
@@ -829,11 +838,12 @@ AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context)
MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
}
-ChangeResult AxisInfoAnalysis::visitOperation(
- Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
+void AxisInfoAnalysis::visitOperation(
+ Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
+ ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
AxisInfo curr = visitors.apply(op, operands);
if (curr.getRank() == 0) {
- return markAllPessimisticFixpoint(op->getResults());
+ return markAllPessimisticFixpoint(results);
}
// override with hint
auto newContiguity = curr.getContiguity();
@@ -854,11 +864,8 @@ ChangeResult AxisInfoAnalysis::visitOperation(
curr = mlir::AxisInfo(newContiguity, newDivisibility, newConstancy,
curr.getConstantValue());
// join all lattice elements
- ChangeResult result = ChangeResult::NoChange;
- for (Value value : op->getResults()) {
- result |= getLatticeElement(value).join(curr);
- }
- return result;
+ for (auto *result : results)
+ propagateIfChanged(result, result->join(curr));
}
unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) {
@@ -884,7 +891,10 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
- auto axisInfo = lookupLatticeElement(ptr)->getValue();
+ dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(ptr);
+ if (!latticeElement || latticeElement->isUninitialized())
+ return 1;
+ auto axisInfo = latticeElement->getValue();
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getOrder(layout);
auto maxMultipleBytes = axisInfo.getDivisibility(order[0]);
@@ -900,8 +910,11 @@ unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
+ dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(mask);
+ if (!latticeElement || latticeElement->isUninitialized())
+ return 1;
+ auto maskAxis = latticeElement->getValue();
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
- auto maskAxis = lookupLatticeElement(mask)->getValue();
auto alignment = std::max<unsigned>(maskAxis.getConstancy(maskOrder[0]), 1);
return alignment;
}
diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt
index afbc692510..1f761f845c 100644
--- a/lib/Analysis/CMakeLists.txt
+++ b/lib/Analysis/CMakeLists.txt
@@ -8,7 +8,7 @@ add_mlir_library(TritonAnalysis
DEPENDS
TritonTableGen
TritonGPUAttrDefsIncGen
-
+
LINK_LIBS PUBLIC
MLIRAnalysis
)
diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp
index acc885e827..910274b2ac 100644
--- a/lib/Analysis/Membar.cpp
+++ b/lib/Analysis/Membar.cpp
@@ -2,7 +2,7 @@
#include "triton/Analysis/Alias.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
-#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
namespace mlir {
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index d9e917e731..6ea52df272 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -1,5 +1,8 @@
#include "triton/Analysis/Utility.h"
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Matchers.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <deque>
@@ -325,4 +328,55 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
return multiRootTopologicalSort(slice);
}
+namespace {
+// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis
+// interacts with constant propagation, but SparseConstantPropagation
+// doesn't seem to be sufficient.
+struct ConstantAnalysis : public DataFlowAnalysis {
+ using DataFlowAnalysis::DataFlowAnalysis;
+
+ LogicalResult initialize(Operation *top) override {
+ WalkResult result = top->walk([&](Operation *op) {
+ if (failed(visit(op)))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ return success(!result.wasInterrupted());
+ }
+
+ LogicalResult visit(ProgramPoint point) override {
+ Operation *op = point.get<Operation *>();
+ Attribute value;
+ if (matchPattern(op, m_Constant(&value))) {
+ auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
+ op->getResult(0));
+ propagateIfChanged(constant, constant->join(dataflow::ConstantValue(
+ value, op->getDialect())));
+ return success();
+ }
+ setAllToUnknownConstants(op->getResults());
+ for (Region &region : op->getRegions())
+ setAllToUnknownConstants(region.getArguments());
+ return success();
+ }
+
+ /// Set all given values as not constants.
+ void setAllToUnknownConstants(ValueRange values) {
+ dataflow::ConstantValue unknownConstant(nullptr, nullptr);
+ for (Value value : values) {
+ auto *constant =
+ getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value);
+ propagateIfChanged(constant, constant->join(unknownConstant));
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
+ auto solver = std::make_unique<DataFlowSolver>();
+ solver->load<dataflow::DeadCodeAnalysis>();
+ solver->load<ConstantAnalysis>();
+ return solver;
+}
+
} // namespace mlir
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 6a46265bd7..e352eb3698 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -159,9 +159,6 @@ struct ConvertLayoutOpConversion
Value smemBase) const {
auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
auto layout = type.getEncoding();
- auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
- auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
- auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
auto rank = type.getRank();
auto sizePerThread = getSizePerThread(layout);
auto accumSizePerThread = product<unsigned>(sizePerThread);
diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
index 4b89965aa9..1d9e00519b 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
+++ b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
@@ -7,10 +7,8 @@
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
-#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
@@ -422,9 +420,9 @@ struct MMA16816ConversionHelper {
MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout,
Value thread, ConversionPatternRewriter &rewriter,
TypeConverter *typeConverter, Location loc)
- : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
- rewriter(rewriter), typeConverter(typeConverter), loc(loc),
- ctx(mmaLayout.getContext()), wpt(mmaLayout.getWarpsPerCTA()) {
+ : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()), thread(thread),
+ helper(mmaLayout), rewriter(rewriter), typeConverter(typeConverter),
+ loc(loc), ctx(mmaLayout.getContext()) {
helper.deduceMmaType(dotOperand);
Value _32 = i32_val(32);
diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
index 0f8070ca9f..e4bd47c411 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
@@ -115,8 +115,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
auto DTensorTy = D.getType().cast<RankedTensorType>();
auto AShape = ATensorTy.getShape();
auto BShape = BTensorTy.getShape();
- auto DShape = DTensorTy.getShape();
- auto wpt = mmaLayout.getWarpsPerCTA();
bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
@@ -221,7 +219,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
ConversionPatternRewriter &rewriter) const {
auto *ctx = rewriter.getContext();
auto loc = op.getLoc();
- auto threadId = getThreadId(rewriter, loc);
auto A = op.a();
auto B = op.b();
@@ -230,12 +227,10 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
auto aTensorTy = A.getType().cast<RankedTensorType>();
auto bTensorTy = B.getType().cast<RankedTensorType>();
- auto cTensorTy = C.getType().cast<RankedTensorType>();
auto dTensorTy = D.getType().cast<RankedTensorType>();
auto aShape = aTensorTy.getShape();
auto bShape = bTensorTy.getShape();
- auto cShape = cTensorTy.getShape();
BlockedEncodingAttr dLayout =
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
index deb71b9597..0b9e67674b 100644
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -61,7 +61,6 @@ struct FpToFpOpConversion
convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
- auto ctx = rewriter.getContext();
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
@@ -153,7 +152,6 @@ struct FpToFpOpConversion
convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
- auto ctx = rewriter.getContext();
auto bf16x2VecTy = vec_ty(i16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
Value bf16x2Vec1 = undef(bf16x2VecTy);
diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
index 9a8b4702bc..bae675f0cb 100644
--- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -109,7 +109,8 @@ struct LoadOpConversion
DenseElementsAttr constAttr;
int64_t splatVal = 0;
if (other && valueElemTy.isa<IntegerType>() &&
- matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) {
+ matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() &&
+ constAttr.getElementType().isa<IntegerType>()) {
otherIsSplatConstInt = true;
splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
}
@@ -333,7 +334,6 @@ struct StoreOpConversion
elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
elem = bitcast(elem, valueElemTy);
- Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
}
llWord = bitcast(llWord, valArgTy);
@@ -387,7 +387,6 @@ struct AtomicCASOpConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
- Value ptr = op.ptr();
Value llPtr = adaptor.ptr();
Value llCmp = adaptor.cmp();
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index 69abd889be..1c973dc196 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -286,7 +286,6 @@ struct ReduceOpConversion
auto srcTy = op.operand().getType().cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
- auto srcRank = srcTy.getRank();
auto order = getOrder(srcLayout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
@@ -351,7 +350,6 @@ struct ReduceOpConversion
Value zero = i32_val(0);
Value laneZero = icmp_eq(laneIdAxis, zero);
- Value warpZero = icmp_eq(warpIdAxis, zero);
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 5b77150b1a..78cfa076bd 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -11,11 +11,11 @@ using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
-struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
- using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
+struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
+ using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+ matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op.getNumOperands();
@@ -476,7 +476,6 @@ struct ExtractSliceOpConversion
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
auto elemPtrTy = ptr_ty(llvmElemTy, 3);
- auto resTy = op.getType().dyn_cast<RankedTensorType>();
smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset),
strideVals, offsetVals);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
index bb10d5b24a..00e399f848 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -4,6 +4,7 @@
// TODO: refactor so that it doesn't fail if Allocation.h
// is included after utility.h (due to conflict in `store` macro
// and <atomic>
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "triton/Analysis/Allocation.h"
//
@@ -39,15 +40,15 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
// TODO(Superjomn): remove the code when MLIR v15.0 is included.
// All the rights are reserved by the LLVM community.
-struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
+struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
private:
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
- static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
- bool filterArgAttrs,
+ static void filterFuncAttributes(func::FuncOp op, bool filterArgAttrs,
SmallVectorImpl<NamedAttribute> &result) {
- for (const auto &attr : attrs) {
+
+ for (const auto &attr : op->getAttrs()) {
if (attr.getName() == SymbolTable::getSymbolAttrName() ||
attr.getName() == FunctionOpInterface::getTypeAttrName() ||
attr.getName() == "std.varargs" ||
@@ -65,27 +66,27 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
}
protected:
- using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
+ using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;
// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
// to this legalization pattern.
LLVM::LLVMFuncOp
- convertFuncOpToLLVMFuncOp(FuncOp funcOp,
+ convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
ConversionPatternRewriter &rewriter) const {
// Convert the original function arguments. They are converted using the
// LLVMTypeConverter provided to this legalization pattern.
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = getTypeConverter()->convertFunctionSignature(
- funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
+ funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
+ result);
if (!llvmType)
return nullptr;
// Propagate argument/result attributes to all converted arguments/result
// obtained after converting a given original argument/result.
SmallVector<NamedAttribute, 4> attributes;
- filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
- attributes);
+ filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, attributes);
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
assert(!resAttrDicts.empty() && "expected array to be non-empty");
auto newResAttrDicts =
@@ -131,7 +132,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
}
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
- /*dsoLocal*/ false, attributes);
+ /*dsoLocal*/ false, LLVM::CConv::C, attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
@@ -191,8 +192,8 @@ class ConvertTritonGPUOpToLLVMPatternBase {
const Allocation *allocation,
Value smem,
IndexCacheInfo indexCacheInfo)
- : converter(&typeConverter), indexCacheInfo(indexCacheInfo),
- allocation(allocation), smem(smem) {}
+ : converter(&typeConverter), allocation(allocation), smem(smem),
+ indexCacheInfo(indexCacheInfo) {}
LLVMTypeConverter *getTypeConverter() const { return converter; }
@@ -861,7 +862,6 @@ class ConvertTritonGPUOpToLLVMPatternBase {
ArrayRef<int64_t> shape) const {
auto parent = sliceLayout.getParent();
unsigned dim = sliceLayout.getDim();
- size_t rank = shape.size();
auto parentIndices =
emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
unsigned numIndices = parentIndices.size();
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
index ff1af09835..6f66af4e34 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
@@ -1,10 +1,11 @@
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
+#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
+#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
-#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Pass/Pass.h"
@@ -40,7 +41,6 @@ class TritonLLVMConversionTarget : public ConversionTarget {
addIllegalDialect<triton::TritonDialect>();
addIllegalDialect<triton::gpu::TritonGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
- addIllegalDialect<mlir::StandardOpsDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
};
@@ -51,7 +51,7 @@ class TritonLLVMFunctionConversionTarget : public ConversionTarget {
: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
addLegalDialect<NVVM::NVVMDialect>();
- addIllegalOp<mlir::FuncOp>();
+ addIllegalOp<mlir::func::FuncOp>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
};
@@ -69,7 +69,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
: FuncOpConversionBase(converter, benefit), numWarps(numWarps) {}
LogicalResult
- matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
+ matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (!newFuncOp)
@@ -133,7 +133,8 @@ class ConvertTritonGPUToLLVM
decomposeBlockedToDotOperand(mod);
// Step 2
- decomposeInsertSliceAsyncOp(mod);
+ if (failed(decomposeInsertSliceAsyncOp(mod)))
+ return signalPassFailure();
// Step 3
Allocation allocation(mod);
@@ -142,7 +143,7 @@ class ConvertTritonGPUToLLVM
// Step 4
RewritePatternSet scf_patterns(context);
- mlir::populateLoopToStdConversionPatterns(scf_patterns);
+ mlir::populateSCFToControlFlowConversionPatterns(scf_patterns);
mlir::ConversionTarget scf_target(*context);
scf_target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp,
scf::WhileOp, scf::ExecuteRegionOp>();
@@ -159,8 +160,10 @@ class ConvertTritonGPUToLLVM
return signalPassFailure();
// Step 6 - get axis and shared memory info
- AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
- axisInfoAnalysis.run(mod);
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
+ AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
+ if (failed(solver->initializeAndRun(mod)))
+ return signalPassFailure();
initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
mod->setAttr("triton_gpu.shared",
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
@@ -178,38 +181,39 @@ class ConvertTritonGPUToLLVM
// Normal conversions
populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
- axisInfoAnalysis, &allocation, smem,
+ *axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// ConvertLayoutOp
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
- axisInfoAnalysis, &allocation, smem,
+ *axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// DotOp
populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
- axisInfoAnalysis, &allocation, smem,
+ *axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// ElementwiseOp
populateElementwiseOpToLLVMPatterns(typeConverter, patterns, numWarps,
- axisInfoAnalysis, &allocation, smem,
+ *axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// LoadStoreOp
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
- axisInfoAnalysis, &allocation, smem,
+ *axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// ReduceOp
populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
- axisInfoAnalysis, &allocation, smem,
+ *axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// ViewOp
populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
- axisInfoAnalysis, &allocation, smem,
+ *axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// Add arith/math's patterns to help convert scalar expression to LLVM.
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
- mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
+ mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
+ patterns);
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
@@ -306,9 +310,11 @@ class ConvertTritonGPUToLLVM
});
}
- void decomposeInsertSliceAsyncOp(ModuleOp mod) const {
- AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
- axisInfoAnalysis.run(mod);
+ LogicalResult decomposeInsertSliceAsyncOp(ModuleOp mod) const {
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
+ AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
+ if (failed(solver->initializeAndRun(mod)))
+ return failure();
// TODO(Keren): This is a hacky knob that may cause performance regression
// when decomposition has been performed. We should remove this knob once we
// have thorough analysis on async wait. Currently, we decompose
@@ -342,7 +348,7 @@ class ConvertTritonGPUToLLVM
auto resSharedLayout =
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
auto resElemTy = dstTy.getElementType();
- unsigned inVec = axisInfoAnalysis.getPtrContiguity(src);
+ unsigned inVec = axisInfoAnalysis->getPtrContiguity(src);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
auto maxBitWidth =
@@ -400,11 +406,11 @@ class ConvertTritonGPUToLLVM
} else if (decomposed) {
// Wait for all previous async ops
OpBuilder builder(asyncWaitOp);
- auto newAsyncWaitOp =
- builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
+ builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
asyncWaitOp.erase();
}
});
+ return success();
}
};
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h
index d35dac28c5..11976908cf 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.h
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.h
@@ -220,10 +220,7 @@ struct SharedMemoryObject {
ConversionPatternRewriter &rewriter)
: base(base) {
strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
-
- for (auto idx : order) {
- offsets.emplace_back(i32_val(0));
- }
+ offsets.append(order.size(), i32_val(0));
}
SmallVector<Value> getElems() const {
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
index fe42202c34..5f230f787f 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -1,10 +1,10 @@
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -59,10 +59,13 @@ class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
Type retType = getTypeConverter()->convertType(op.getType());
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
assert(value);
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- op, retType,
- value.reshape(retType) // This is a hack. We just want to add encoding
- );
+ if (value.getElementType().isInteger(1) && value.isSplat())
+ // Workaround until https://reviews.llvm.org/D133743 is included.
+ value = DenseElementsAttr::get(retType, value.getSplatValue<bool>());
+ else
+ // This is a hack. We just want to add encoding
+ value = value.reshape(retType);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value);
return success();
}
};
@@ -127,12 +130,12 @@ void populateArithmeticPatternsAndLegality(
}
// this shouldn't exist if mlir's SelectOp checked encodings properly
-class StdSelectPattern : public OpConversionPattern<SelectOp> {
+class StdSelectPattern : public OpConversionPattern<arith::SelectOp> {
public:
- using OpConversionPattern<SelectOp>::OpConversionPattern;
+ using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
+ matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
@@ -148,8 +151,8 @@ void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
MLIRContext *context = patterns.getContext();
// Rewrite rule
patterns.add<StdSelectPattern>(typeConverter, context);
- target.addLegalOp<ReturnOp>(); // this is ok because all functions are inlined
- // by the frontend
+ target.addLegalOp<func::ReturnOp>(); // this is ok because all functions are
+ // inlined by the frontend
}
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
@@ -455,18 +458,19 @@ struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
- patterns.add< // TODO: view should have custom pattern that views the layout
- TritonGenericPattern<triton::ViewOp>,
- TritonGenericPattern<triton::BitcastOp>,
- TritonGenericPattern<triton::FpToFpOp>,
- TritonGenericPattern<triton::IntToPtrOp>,
- TritonGenericPattern<triton::PtrToIntOp>,
- TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
- TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
- TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
- TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
- TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
- TritonAtomicRMWPattern>(typeConverter, context);
+ patterns
+ .insert< // TODO: view should have custom pattern that views the layout
+ TritonGenericPattern<triton::ViewOp>,
+ TritonGenericPattern<triton::BitcastOp>,
+ TritonGenericPattern<triton::FpToFpOp>,
+ TritonGenericPattern<triton::IntToPtrOp>,
+ TritonGenericPattern<triton::PtrToIntOp>,
+ TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
+ TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
+ TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
+ TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
+ TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
+ TritonAtomicRMWPattern>(typeConverter, context);
}
//
@@ -623,29 +627,28 @@ void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
// CF
-class CFBranchPattern : public OpConversionPattern<BranchOp> {
+class CFBranchPattern : public OpConversionPattern<cf::BranchOp> {
public:
- using OpConversionPattern<BranchOp>::OpConversionPattern;
+ using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(BranchOp op, BranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto converter = getTypeConverter();
- auto newOp = rewriter.replaceOpWithNewOp<BranchOp>(op, op.getSuccessor(),
- adaptor.getOperands());
+ auto newOp = rewriter.replaceOpWithNewOp<cf::BranchOp>(
+ op, op.getSuccessor(), adaptor.getOperands());
return success();
}
};
-class CFCondBranchPattern : public OpConversionPattern<CondBranchOp> {
+class CFCondBranchPattern : public OpConversionPattern<cf::CondBranchOp> {
public:
- using OpConversionPattern<CondBranchOp>::OpConversionPattern;
+ using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(CondBranchOp op, CondBranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto converter = getTypeConverter();
- auto newOp = rewriter.replaceOpWithNewOp<CondBranchOp>(
+ auto newOp = rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
op, adaptor.getCondition(), op.getTrueDest(),
adaptor.getTrueDestOperands(), op.getFalseDest(),
adaptor.getFalseDestOperands());
diff --git a/lib/Dialect/Triton/IR/CMakeLists.txt b/lib/Dialect/Triton/IR/CMakeLists.txt
index 2d679b21fd..705554ba6b 100644
--- a/lib/Dialect/Triton/IR/CMakeLists.txt
+++ b/lib/Dialect/Triton/IR/CMakeLists.txt
@@ -10,11 +10,7 @@ add_mlir_dialect_library(TritonIR
LINK_LIBS PUBLIC
MLIRIR
- MLIRArithmetic
- MLIRSCF
-
- # Since LLVM 15
- # MLIRFunc
- # else
- MLIRStandard
+ MLIRArithmeticDialect
+ MLIRSCFDialect
+ MLIRFuncDialect
)
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index 3aadbfa0c0..86570359c5 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -1,10 +1,9 @@
-#include "triton/Dialect/Triton/IR/Dialect.h"
-#include "triton/Dialect/Triton/IR/Types.h"
-
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Types.h"
namespace mlir {
namespace triton {
@@ -38,8 +37,8 @@ static Type getPointerTypeSameShape(Type type) {
}
// Parser & printer for assembly forms
-ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 4> allOperands;
+ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
Type resultTypes[1];
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
@@ -73,18 +72,18 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
return success();
}
-void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) {
+void LoadOp::print(OpAsmPrinter &printer) {
printer << " ";
- printer << loadOp.getOperation()->getOperands();
+ printer << getOperation()->getOperands();
// "operand_segment_sizes" can be deduced, so we don't print it.
- printer.printOptionalAttrDict(loadOp->getAttrs(),
- {loadOp.operand_segment_sizesAttrName()});
+ printer.printOptionalAttrDict(getOperation()->getAttrs(),
+ {operand_segment_sizesAttrName()});
printer << " : ";
- printer.printStrippedAttrOrType(loadOp.result().getType());
+ printer.printStrippedAttrOrType(getResult().getType());
}
-ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 4> allOperands;
+ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
Type valueType;
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
@@ -104,12 +103,12 @@ ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
return success();
}
-void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) {
+void StoreOp::print(OpAsmPrinter &printer) {
printer << " ";
- printer << storeOp.getOperation()->getOperands();
- printer.printOptionalAttrDict(storeOp->getAttrs(), /*elidedAttrs=*/{});
+ printer << getOperation()->getOperands();
+ printer.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{});
printer << " : ";
- printer.printStrippedAttrOrType(storeOp.value().getType());
+ printer.printStrippedAttrOrType(value().getType());
}
} // namespace triton
@@ -319,7 +318,8 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
if (!constOperand)
return {};
auto shapedType = getType().cast<ShapedType>();
- auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
+ auto ret = SplatElementsAttr::get(
+ shapedType, ArrayRef<Attribute>(constOperand.getValue()));
return ret;
}
diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp
index 2261472170..11570283d6 100644
--- a/lib/Dialect/Triton/Transforms/Combine.cpp
+++ b/lib/Dialect/Triton/Transforms/Combine.cpp
@@ -57,13 +57,13 @@ DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
public:
CombineSelectMaskedLoadPattern(mlir::MLIRContext *context)
- : mlir::RewritePattern(mlir::SelectOp::getOperationName(), 3, context,
- {triton::LoadOp::getOperationName()}) {}
+ : mlir::RewritePattern(mlir::arith::SelectOp::getOperationName(), 3,
+ context, {triton::LoadOp::getOperationName()}) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
- auto selectOp = llvm::dyn_cast<mlir::SelectOp>(op);
+ auto selectOp = llvm::dyn_cast<mlir::arith::SelectOp>(op);
if (!selectOp)
return mlir::failure();
diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td
index 14f286b26e..ded0e346e6 100644
--- a/lib/Dialect/Triton/Transforms/Combine.td
+++ b/lib/Dialect/Triton/Transforms/Combine.td
@@ -1,9 +1,9 @@
#ifndef TRITON_PATTERNS
#define TRITON_PATTERNS
-include "mlir/Dialect/StandardOps/IR/Ops.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
include "triton/Dialect/Triton/IR/TritonOps.td"
+include "mlir/IR/PatternBase.td"
// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d)
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 1fbc609e88..bfc3f3d3da 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1,14 +1,14 @@
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
#include <numeric>
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "triton/Analysis/Utility.h"
-#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
-
using namespace mlir;
using namespace mlir::triton::gpu;
@@ -366,7 +366,6 @@ template SmallVector<int64_t>
SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
- size_t rank = shape.size();
auto parent = getParent();
return ::getElemsPerThread(parent, paddedShape(shape));
}
@@ -655,9 +654,9 @@ void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
// InsertSliceAsyncOp
//===----------------------------------------------------------------------===//
-ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 8> allOperands;
+ParseResult InsertSliceAsyncOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<OpAsmParser::UnresolvedOperand, 8> allOperands;
Type srcType, dstType;
SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
@@ -696,18 +695,16 @@ ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser,
return success();
}
-void printInsertSliceAsyncOp(OpAsmPrinter &printer,
- InsertSliceAsyncOp insertSliceAsyncOp) {
+void InsertSliceAsyncOp::print(OpAsmPrinter &printer) {
printer << " ";
- printer << insertSliceAsyncOp.getOperation()->getOperands();
+ printer << getOperation()->getOperands();
// "operand_segment_sizes" can be deduced, so we don't print it.
- printer.printOptionalAttrDict(
- insertSliceAsyncOp->getAttrs(),
- {insertSliceAsyncOp.operand_segment_sizesAttrName()});
+ printer.printOptionalAttrDict(getOperation()->getAttrs(),
+ {operand_segment_sizesAttrName()});
printer << " : ";
- printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType());
+ printer.printStrippedAttrOrType(src().getType());
printer << " -> ";
- printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
+ printer.printStrippedAttrOrType(result().getType());
}
//===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
index 82407980d3..ee6009f44a 100644
--- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
@@ -27,7 +27,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
auto origType = ptr.getType().cast<RankedTensorType>();
// Get the shape of the tensor.
size_t rank = origType.getRank();
- AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
+ dataflow::Lattice<AxisInfo> *latticeElement =
+ axisInfo.getLatticeElement(ptr);
+ AxisInfo info = latticeElement && !latticeElement->isUninitialized()
+ ? latticeElement->getValue()
+ : AxisInfo();
// Get the contiguity order of `ptr`
auto order = argSort(info.getContiguity());
// The desired divisibility is the maximum divisibility
@@ -40,7 +44,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
for (Value val : op->getResults()) {
if (val.getType() != origType)
continue;
- auto valInfo = axisInfo.lookupLatticeElement(val);
+ auto valInfo = axisInfo.getLatticeElement(val);
auto currOrder = argSort(valInfo->getValue().getContiguity());
if (order == currOrder)
withSameOrder.insert(val);
@@ -55,7 +59,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
unsigned perThread = 1;
for (Value val : withSameOrder) {
- AxisInfo info = axisInfo.lookupLatticeElement(val)->getValue();
+ AxisInfo info = axisInfo.getLatticeElement(val)->getValue();
unsigned maxMultipleBytes = info.getDivisibility(order[0]);
unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u);
unsigned maxContig = info.getContiguity(order[0]);
@@ -123,8 +127,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
void runOnOperation() override {
Operation *op = getOperation();
// Run axis info analysis
- AxisInfoAnalysis axisInfo(&getContext());
- axisInfo.run(op);
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
+ AxisInfoAnalysis *axisInfo = solver->load<AxisInfoAnalysis>();
+ if (failed(solver->initializeAndRun(op)))
+ return signalPassFailure();
// For each i/o operation, we determine what layout
// the pointers should have for best memory coalescing
@@ -146,10 +152,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
if (!ty || !ty.getElementType().isa<PointerType>())
return;
- AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue();
+ AxisInfo info = axisInfo->getLatticeElement(ptr)->getValue();
auto mod = curr->getParentOfType<ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
- auto convertType = getTypeConverter(axisInfo, ptr, numWarps);
+ auto convertType = getTypeConverter(*axisInfo, ptr, numWarps);
layoutMap[ptr] = convertType;
});
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
index efa37ff2dc..089ce3996c 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -1,6 +1,6 @@
#include "Utility.h"
#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.td b/lib/Dialect/TritonGPU/Transforms/Combine.td
index 6bf1b14866..6a7b10dbcb 100644
--- a/lib/Dialect/TritonGPU/Transforms/Combine.td
+++ b/lib/Dialect/TritonGPU/Transforms/Combine.td
@@ -3,5 +3,6 @@
include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td"
include "triton/Dialect/Triton/IR/TritonOps.td"
+include "mlir/IR/PatternBase.td"
#endif
diff --git a/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp b/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp
index 4bd3bc76bf..b2f8defd81 100644
--- a/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp
@@ -1,5 +1,5 @@
#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
index 9b2f42231e..85f746c1dc 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -2,6 +2,7 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/TypeUtilities.h"
#include "triton/Analysis/AxisInfo.h"
+#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -160,15 +161,18 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
LogicalResult LoopPipeliner::initialize() {
Block *loop = forOp.getBody();
- AxisInfoAnalysis axisInfoAnalysis(forOp.getContext());
- axisInfoAnalysis.run(forOp->getParentOfType<ModuleOp>());
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
+ AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
+ if (failed(solver->initializeAndRun(forOp->getParentOfType<ModuleOp>()))) {
+ return failure();
+ }
// can we use forOp.walk(...) here?
SmallVector<triton::LoadOp, 2> allLoads;
for (Operation &op : *loop)
if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
auto ptr = loadOp.ptr();
- unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
+ unsigned vec = axisInfoAnalysis->getPtrContiguity(ptr);
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
continue;
diff --git a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp
index 0e7dbe5264..b95a4f50a6 100644
--- a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp
@@ -1,5 +1,5 @@
#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
index 37ac710995..762e887f36 100644
--- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
@@ -82,12 +82,12 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect, math::MathDialect,
- triton::TritonDialect, StandardOpsDialect,
- scf::SCFDialect>([&](Operation *op) {
- if (typeConverter.isLegal(op))
- return true;
- return false;
- });
+ triton::TritonDialect, scf::SCFDialect>(
+ [&](Operation *op) {
+ if (typeConverter.isLegal(op))
+ return true;
+ return false;
+ });
// We have requirements for the data layouts
addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
diff --git a/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp b/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp
index c229104286..c911fd4a5c 100644
--- a/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp
@@ -1,5 +1,5 @@
#include "Utility.h"
-#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -118,8 +118,8 @@ void setOpResultType(Operation *op, ArrayRef<Type> newTypes) {
.get("value")
.dyn_cast<mlir::DenseElementsAttr>();
if (attr) {
- auto newAttr = mlir::DenseElementsAttr::getFromRawBuffer(
- newType, attr.getRawData(), true);
+ auto newAttr =
+ mlir::DenseElementsAttr::getFromRawBuffer(newType, attr.getRawData());
op->setAttr("value", newAttr);
}
}
diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index ed15f02f67..6400f1633a 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -1,5 +1,5 @@
#include "Utility.h"
-#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt
index f1bbd0bf4e..ac8973ad19 100644
--- a/lib/Target/LLVMIR/CMakeLists.txt
+++ b/lib/Target/LLVMIR/CMakeLists.txt
@@ -6,8 +6,7 @@ add_mlir_translation_library(TritonLLVMIR
LINK_LIBS PUBLIC
MLIRIR
- MLIRLLVMIR
- MLIRSCFToStandard
+ MLIRLLVMDialect
MLIRSupport
MLIRTargetLLVMIRExport
)
diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp
index 4cb0d8193c..6a5453a6e7 100644
--- a/lib/Target/PTX/PTXTranslation.cpp
+++ b/lib/Target/PTX/PTXTranslation.cpp
@@ -1,11 +1,14 @@
#include "triton/Target/PTX/PTXTranslation.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
+#include <optional>
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
diff --git a/python/setup.py b/python/setup.py
index 2ac3accd25..4530b36714 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -57,19 +57,10 @@ def get_pybind11_package_info():
def get_llvm_package_info():
# download if nothing is installed
system = platform.system()
- if system == "Darwin":
- system_suffix = "apple-darwin"
- elif system == "Linux":
- vglibc = tuple(map(int, platform.libc_ver()[1].split('.')))
- vglibc = vglibc[0] * 100 + vglibc[1]
- linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7'
- system_suffix = f"linux-gnu-{linux_suffix}"
- else:
- raise RuntimeError(f"unsupported system: {system}")
+ system_suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system]
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
- release_suffix = "assert" if use_assert_enabled_llvm else "release"
- name = f'llvm+mlir-14.0.6-x86_64-{system_suffix}-{release_suffix}'
- url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-14.0.6-f28c006a5895/{name}.tar.xz"
+ name = 'llvm+mlir-15.0.7-x86_64-{}-{}'.format(system_suffix, "assert" if use_assert_enabled_llvm else "release")
+ url = "https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-15.0.7-8dfdcc7b7bf6/{}.tar.xz".format(name)
return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
diff --git a/python/src/triton.cc b/python/src/triton.cc
index c40b117a55..f190eacc34 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -8,9 +8,10 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
-#include "mlir/Parser.h"
+#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
@@ -195,7 +196,7 @@ void init_triton_ir(py::module &&m) {
std::string attrName = name + "_arg" + std::to_string(id);
mlir::Block *owner = arg.getOwner();
if (owner->isEntryBlock() &&
- !mlir::isa<mlir::FuncOp>(owner->getParentOp())) {
+ !mlir::isa<mlir::func::FuncOp>(owner->getParentOp())) {
owner->getParentOp()->setAttr(attrName, attr);
}
}
@@ -348,7 +349,7 @@ void init_triton_ir(py::module &&m) {
return str;
})
.def("push_back",
- [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void {
+ [](mlir::ModuleOp &self, mlir::func::FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
.def("has_function",
@@ -358,16 +359,18 @@ void init_triton_ir(py::module &&m) {
return false;
})
.def("get_function",
- [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp {
- return self.lookupSymbol<mlir::FuncOp>(funcName);
- })
- .def("get_single_function", [](mlir::ModuleOp &self) -> mlir::FuncOp {
- llvm::SmallVector<mlir::FuncOp> funcs;
- self.walk([&](mlir::FuncOp func) { funcs.push_back(func); });
- if (funcs.size() != 1)
- throw std::runtime_error("Expected a single function");
- return funcs[0];
- });
+ [](mlir::ModuleOp &self,
+ std::string &funcName) -> mlir::func::FuncOp {
+ return self.lookupSymbol<mlir::func::FuncOp>(funcName);
+ })
+ .def("get_single_function",
+ [](mlir::ModuleOp &self) -> mlir::func::FuncOp {
+ llvm::SmallVector<mlir::func::FuncOp> funcs;
+ self.walk([&](mlir::func::FuncOp func) { funcs.push_back(func); });
+ if (funcs.size() != 1)
+ throw std::runtime_error("Expected a single function");
+ return funcs[0];
+ });
m.def("make_attr",
[](const std::vector<int> &values, mlir::MLIRContext &context) {
@@ -388,47 +391,48 @@ void init_triton_ir(py::module &&m) {
registry.insert<mlir::triton::TritonDialect,
mlir::triton::gpu::TritonGPUDialect,
mlir::math::MathDialect, mlir::arith::ArithmeticDialect,
- mlir::StandardOpsDialect, mlir::scf::SCFDialect>();
+ mlir::func::FuncDialect, mlir::scf::SCFDialect>();
context.appendDialectRegistry(registry);
context.loadAllAvailableDialects();
// parse module
- mlir::OwningOpRef<mlir::ModuleOp> module(
- mlir::parseSourceFile(inputFilename, &context));
+ mlir::OwningOpRef<mlir::ModuleOp> module =
+ mlir::parseSourceFile<mlir::ModuleOp>(inputFilename, &context);
+ if (!module)
+ throw std::runtime_error("Parse MLIR file failed.");
// locations are incompatible with ptx < 7.5 !
module->walk([](mlir::Operation *op) {
op->setLoc(mlir::UnknownLoc::get(op->getContext()));
});
- if (!module)
- throw std::runtime_error("Parse MLIR file failed.");
return module->clone();
},
ret::take_ownership);
- py::class_<mlir::FuncOp, mlir::OpState>(m, "function")
+ py::class_<mlir::func::FuncOp, mlir::OpState>(m, "function")
// .def_property_readonly("attrs", &ir::function::attrs)
// .def("add_attr", &ir::function::add_attr);
.def("args",
- [](mlir::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
+ [](mlir::func::FuncOp &self, unsigned idx) -> mlir::BlockArgument {
return self.getArgument(idx);
})
.def(
"add_entry_block",
- [](mlir::FuncOp &self) -> mlir::Block * {
+ [](mlir::func::FuncOp &self) -> mlir::Block * {
return self.addEntryBlock();
},
ret::reference)
.def(
"set_arg_attr",
- [](mlir::FuncOp &self, int arg_no, const std::string &name, int val) {
+ [](mlir::func::FuncOp &self, int arg_no, const std::string &name,
+ int val) {
// set arg attributes "name" to value "val"
auto attrTy = mlir::IntegerType::get(self.getContext(), 32);
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
},
ret::reference)
- .def_property_readonly("type", &mlir::FuncOp::getType)
- .def("reset_type", &mlir::FuncOp::setType);
+ .def_property_readonly("type", &mlir::func::FuncOp::getFunctionType)
+ .def("reset_type", &mlir::func::FuncOp::setType);
py::class_<mlir::OpBuilder::InsertPoint>(m, "InsertPoint");
@@ -445,13 +449,13 @@ void init_triton_ir(py::module &&m) {
.def("ret",
[](mlir::OpBuilder &self, std::vector<mlir::Value> &vals) -> void {
auto loc = self.getUnknownLoc();
- self.create<mlir::ReturnOp>(loc, vals);
+ self.create<mlir::func::ReturnOp>(loc, vals);
})
.def("call",
- [](mlir::OpBuilder &self, mlir::FuncOp &func,
+ [](mlir::OpBuilder &self, mlir::func::FuncOp &func,
std::vector<mlir::Value> &args) -> mlir::OpState {
auto loc = self.getUnknownLoc();
- return self.create<mlir::CallOp>(loc, func, args);
+ return self.create<mlir::func::CallOp>(loc, func, args);
})
// insertion block/point
.def("set_insertion_point_to_start",
@@ -618,15 +622,16 @@ void init_triton_ir(py::module &&m) {
.def("get_or_insert_function",
[](mlir::OpBuilder &self, mlir::ModuleOp &module,
std::string &funcName, mlir::Type &funcType,
- std::string &visibility) -> mlir::FuncOp {
+ std::string &visibility) -> mlir::func::FuncOp {
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
- return llvm::dyn_cast<mlir::FuncOp>(funcOperation);
+ return llvm::dyn_cast<mlir::func::FuncOp>(funcOperation);
auto loc = self.getUnknownLoc();
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
llvm::SmallVector<mlir::NamedAttribute> attrs = {
mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
self.getStringAttr(visibility))};
- return self.create<mlir::FuncOp>(loc, funcName, funcTy, attrs);
+ return self.create<mlir::func::FuncOp>(loc, funcName, funcTy,
+ attrs);
}
throw std::runtime_error("invalid function type");
})
@@ -658,15 +663,15 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self, mlir::Value condition,
mlir::Block *trueDest, mlir::Block *falseDest) {
auto loc = self.getUnknownLoc();
- self.create<mlir::CondBranchOp>(loc, condition, trueDest,
- falseDest);
+ self.create<mlir::cf::CondBranchOp>(loc, condition, trueDest,
+ falseDest);
return;
})
.def("create_branch",
[](mlir::OpBuilder &self, mlir::Block *dest,
std::vector<mlir::Value> &args) {
auto loc = self.getUnknownLoc();
- self.create<mlir::BranchOp>(loc, dest, args);
+ self.create<mlir::cf::BranchOp>(loc, dest, args);
return;
})
// Structured control flow
@@ -792,14 +797,14 @@ void init_triton_ir(py::module &&m) {
.def("create_to_index",
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
auto loc = self.getUnknownLoc();
- return self.create<mlir::arith::IndexCastOp>(loc, input,
- self.getIndexType());
+ return self.create<mlir::arith::IndexCastOp>(
+ loc, self.getIndexType(), input);
})
.def("create_index_to_si",
[](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
auto loc = self.getUnknownLoc();
- return self.create<mlir::arith::IndexCastOp>(loc, input,
- self.getI32Type());
+ return self.create<mlir::arith::IndexCastOp>(
+ loc, self.getI32Type(), input);
})
.def("create_fmul",
[](mlir::OpBuilder &self, mlir::Value &lhs,
@@ -1316,8 +1321,8 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self, mlir::Value &condition,
mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value {
auto loc = self.getUnknownLoc();
- return self.create<mlir::SelectOp>(loc, condition, trueValue,
- falseValue);
+ return self.create<mlir::arith::SelectOp>(loc, condition,
+ trueValue, falseValue);
})
.def("create_printf",
[](mlir::OpBuilder &self, const std::string &prefix,
@@ -1429,7 +1434,7 @@ void init_triton_ir(py::module &&m) {
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
})
.def("add_scf_to_cfg", [](mlir::PassManager &self) {
- self.addPass(mlir::createLowerToCFGPass());
+ self.addPass(mlir::createConvertSCFToCFPass());
});
}
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 432544a8a4..018f544714 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1918,7 +1918,7 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
#dst = {dst_layout}
""" + """
module attributes {"triton_gpu.num-warps" = 4 : i32} {
- func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
+ func.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 5d167634df..c36589037c 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -1514,14 +1514,14 @@ def make_hash(fn, **kwargs):
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
-# - ^\s*func\s+ : match the start of the string, any leading whitespace, the keyword func,
+# - ^\s*func\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
# and any following whitespace
# - (public\s+)? : optionally match the keyword public and any following whitespace
# - (@\w+) : match an @ symbol followed by one or more word characters
# (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
-mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
+mlir_prototype_pattern = r'^\s*func\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
"ttir": mlir_prototype_pattern,
diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir
index b3d5673f85..bb21615e68 100644
--- a/test/Analysis/test-alias.mlir
+++ b/test/Analysis/test-alias.mlir
@@ -11,7 +11,7 @@
// CHECK-LABEL: matmul_loop
// There shouldn't be any aliasing with the dot op encoding.
-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
@@ -36,7 +36,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
}
// CHECK-LABEL: alloc
-func @alloc(%A : !tt.ptr<f16>) {
+func.func @alloc(%A : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
@@ -46,7 +46,7 @@ func @alloc(%A : !tt.ptr<f16>) {
}
// CHECK-LABEL: convert
-func @convert(%A : !tt.ptr<f16>) {
+func.func @convert(%A : !tt.ptr<f16>) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
// CHECK: %0 -> %0
%cst1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
@@ -54,7 +54,7 @@ func @convert(%A : !tt.ptr<f16>) {
}
// CHECK-LABEL: trans
-func @trans(%A : !tt.ptr<f16>) {
+func.func @trans(%A : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
// CHECK: %0 -> %cst
@@ -63,7 +63,7 @@ func @trans(%A : !tt.ptr<f16>) {
}
// CHECK-LABEL: insert_slice_async
-func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
+func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
@@ -76,7 +76,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
}
// CHECK-LABEL: insert_slice
-func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
+func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
@@ -90,7 +90,7 @@ func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
}
// CHECK-LABEL: extract_slice
-func @extract_slice(%A : !tt.ptr<f16>) {
+func.func @extract_slice(%A : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
%index = arith.constant 0 : index
@@ -100,7 +100,7 @@ func @extract_slice(%A : !tt.ptr<f16>) {
}
// CHECK-LABEL: if_cat
-func @if_cat(%i1 : i1) {
+func.func @if_cat(%i1 : i1) {
// CHECK: %cst -> %cst
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK: %cst_0 -> %cst_0
@@ -119,7 +119,7 @@ func @if_cat(%i1 : i1) {
}
// CHECK-LABEL: if_alias
-func @if_alias(%i1 : i1) {
+func.func @if_alias(%i1 : i1) {
// CHECK: %cst -> %cst
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: %cst_0 -> %cst_0
@@ -134,7 +134,7 @@ func @if_alias(%i1 : i1) {
}
// CHECK-LABEL: for
-func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
// CHECK: %cst -> %cst
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: %cst_0 -> %cst_0
@@ -154,7 +154,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p
}
// CHECK-LABEL: for_if
-func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
+func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: %cst -> %cst
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: %cst_0 -> %cst_0
@@ -180,7 +180,7 @@ func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !t
}
// CHECK-LABEL: for_if_for
-func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
+func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: %cst -> %cst
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: %cst_0 -> %cst_0
diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir
index 0ab34c7a78..af8ea6f856 100644
--- a/test/Analysis/test-alignment.mlir
+++ b/test/Analysis/test-alignment.mlir
@@ -1,288 +1,288 @@
-// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s
+// RUN: triton-opt %s -test-print-alignment -split-input-file -o %t 2>&1 | FileCheck %s
-// CHECK-LABEL: cast
-func @cast() {
- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1]
+// CHECK-LABEL: @cast
+func.func @cast() {
+ // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
%cst = arith.constant 1 : i32
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
%0 = arith.extsi %cst : i32 to i64
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
%cst_tensor = arith.constant dense<1> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
%1 = tt.bitcast %cst_tensor : tensor<128xi32> -> tensor<128xi64>
return
}
// -----
-// CHECK-LABEL: add
-func @add() {
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+// CHECK-LABEL: @add
+func.func @add() {
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
%1 = arith.constant dense<1> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = <none>
%2 = arith.addi %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [127]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 127
%3 = arith.constant dense<127> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
%4 = arith.addi %1, %3 : tensor<128xi32>
return
}
// -----
-// CHECK-LABEL: sub
-func @sub() {
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+// CHECK-LABEL: @sub
+func.func @sub() {
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
%1 = arith.constant dense<1> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = <none>
%2 = arith.subi %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [129]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 129
%3 = arith.constant dense<129> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
%4 = arith.subi %3, %1 : tensor<128xi32>
return
}
// -----
-// CHECK-LABEL: mul
-func @mul() {
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+// CHECK-LABEL: @mul
+func.func @mul() {
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
%1 = arith.constant dense<1> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%2 = arith.muli %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
%3 = arith.constant dense<128> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
%4 = arith.muli %3, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [2]
+ // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2
%5 = arith.constant dense<2> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [256] ; Constancy: [128] ; ConstantValue: [256]
+ // CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [128], constant_value = 256
%6 = arith.muli %4, %5 : tensor<128xi32>
return
}
// -----
-// CHECK-LABEL: div
-func @div() {
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+// CHECK-LABEL: @div
+func.func @div() {
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
%1 = arith.constant dense<1> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%2 = arith.divsi %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%3 = arith.divui %1, %0 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
%4 = arith.constant dense<64> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
%5 = arith.divsi %0, %4 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%6 = arith.divsi %4, %0 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
%7 = arith.divsi %4, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66]
+ // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66
%8 = arith.constant dense<66> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [2] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [2], constant_value = <none>
%9 = arith.divui %0, %8 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [8192] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128], divisibility = [8192], constancy = [1], constant_value = <none>
%10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [64] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [64], constant_value = <none>
%11 = arith.divsi %10, %4 : tensor<128xi32>
- return
+ return
}
// -----
-// CHECK-LABEL: rem
-func @rem() {
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+// CHECK-LABEL: @rem
+func.func @rem() {
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1
%1 = arith.constant dense<1> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
%2 = arith.remsi %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%3 = arith.remui %1, %0 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
%4 = arith.constant dense<64> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [64] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [64], divisibility = [64], constancy = [1], constant_value = <none>
%5 = arith.remsi %0, %4 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [1], constant_value = <none>
%6 = arith.remsi %4, %0 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66]
+ // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66
%7 = arith.constant dense<66> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [2] ; Divisibility: [2] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [2], divisibility = [2], constancy = [1], constant_value = <none>
%8 = arith.remui %0, %7 : tensor<128xi32>
- return
+ return
}
// -----
-// CHECK-LABEL: broadcast
-func @broadcast() {
- // CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
+// CHECK-LABEL: @broadcast
+func.func @broadcast() {
+ // CHECK: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
%0 = arith.constant dense<64> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 1] ; ConstantValue: [64]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 1], constant_value = 64
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 128] ; ConstantValue: [64]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 128], constant_value = 64
%2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32>
return
}
// -----
-// CHECK-LABEL: splat
-func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
- // CHECK: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 128] ; ConstantValue: [None]
+// CHECK-LABEL: @splat
+func.func @splat(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
+ // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
%0 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
return
}
// -----
-// CHECK-LABEL: cmp
-func @cmp() {
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+// CHECK-LABEL: @cmp
+func.func @cmp() {
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
%1 = arith.constant dense<0> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
%2 = arith.cmpi eq, %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
%3 = arith.cmpi slt, %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%4 = arith.cmpi sle, %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
%5 = arith.cmpi sge, %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
+ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
%6 = arith.constant dense<8> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
%7 = arith.cmpi sgt, %0, %6 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [0]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 0
%8 = arith.cmpi sgt, %1, %6 : tensor<128xi32>
return
}
// -----
-// CHECK-LABEL: logic
-func @logic() {
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+// CHECK-LABEL: @logic
+func.func @logic() {
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64]
+ // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
%1 = arith.constant dense<64> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
%2 = arith.divsi %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
+ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
%3 = arith.constant dense<8> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [134217728] ; Constancy: [8] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [134217728], constancy = [8], constant_value = <none>
%4 = arith.divsi %0, %3 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%5 = arith.andi %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%6 = arith.ori %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%7 = arith.xori %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
%8 = arith.andi %2, %4 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
%9 = arith.ori %2, %4 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
%10 = arith.xori %2, %4 : tensor<128xi32>
return
}
// -----
-// CHECK-LABEL: select
-func @select() {
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+// CHECK-LABEL: @select
+func.func @select() {
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
%1 = arith.constant dense<0> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
%2 = arith.cmpi eq, %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
%3 = arith.cmpi slt, %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0]
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0
%4 = arith.constant 0 : i1
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
%7 = tt.splat %4 : (i1) -> tensor<128xi1>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0]
- %5 = select %4, %3, %7 : tensor<128xi1>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0
+ %5 = arith.select %4, %3, %7 : tensor<128xi1>
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = <none>
%8 = "triton_gpu.select"(%7, %3, %2) : (tensor<128xi1>, tensor<128xi1>, tensor<128xi1>) -> tensor<128xi1>
return
}
// -----
-func @shift() {
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+func.func @shift() {
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
+ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
%1 = arith.constant dense<8> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4]
+ // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4
%2 = arith.constant dense<4> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [274877906944] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [274877906944], constancy = [1], constant_value = <none>
%3 = arith.shli %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [67108864] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [67108864], constancy = [1], constant_value = <none>
%4 = arith.shrsi %0, %2 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128]
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
%5 = arith.shli %1, %2 : tensor<128xi32>
return
}
// -----
-func @max_min() {
- // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+func.func @max_min() {
+ // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128], divisibility = [64], constancy = [1], constant_value = <none>
%1 = tt.make_range {end = 192 : i32, start = 64 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%2 = arith.maxsi %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%3 = arith.minsi %0, %1 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8]
+ // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
%4 = arith.constant dense<8> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4]
+ // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4
%5 = arith.constant dense<4> : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [8]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 8
%6 = arith.maxsi %4, %5 : tensor<128xi32>
return
}
// -----
-// CHECK-LABEL: for
-func @for() {
- // CHECK: Contiguity: [1, 1] ; Divisibility: [4611686018427387904, 4611686018427387904] ; Constancy: [128, 32] ; ConstantValue: [0]
+// CHECK-LABEL: @for
+func.func @for() {
+ // CHECK: contiguity = [1, 1], divisibility = [4611686018427387904, 4611686018427387904], constancy = [128, 32], constant_value = 0
%a_init = arith.constant dense<0> : tensor<128x32xi32>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [1]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = 1
%b_init = arith.constant dense<1> : tensor<128x32xi32>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
%c_init = arith.constant dense<4> : tensor<128x32xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128]
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
%ub = arith.constant 128 : index
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0]
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0
%lb = arith.constant 0 : index
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [16]
+ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16
%step = arith.constant 16 : index
%a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) {
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = <none>
%t = arith.index_cast %iv : index to i32
- // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None]
- // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None]
- // CHECK: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4]
+ // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none>
+ // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none>
+ // CHECK: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
scf.yield %b, %a, %c : tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>
}
return
@@ -290,53 +290,53 @@ func @for() {
// -----
-// CHECK-LABEL: permute_2d
-func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
- // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 128] ; ConstantValue: [1]
+// CHECK-LABEL: @permute_2d
+func.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
+ // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 128], constant_value = 1
%cst = arith.constant dense<true> : tensor<128x128xi1>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = <none>
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
%3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [17179869184, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [17179869184, 16], constancy = [1, 1], constant_value = <none>
%4 = arith.muli %2, %3 : tensor<128x1xi32>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = <none>
%6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = <none>
%7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = <none>
%8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [128, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = <none>
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [1, 1], constant_value = <none>
%10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = <none>
%11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
%13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
- // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = <none>
%14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = <none>
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [1, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [1, 1], constant_value = <none>
%16 = arith.muli %14, %15 : tensor<1x128xi32>
- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 128], constant_value = <none>
%17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [128, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [128, 1], constant_value = <none>
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
- // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
%19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
- // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
tt.store %19, %20, %cst : tensor<128x128xf32>
return
@@ -347,29 +347,29 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
module {
// This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer.
-// CHECK-LABEL: store_constant_align
-func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
+// CHECK-LABEL: @store_constant_align
+func.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
+ // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%pid = tt.get_program_id {axis = 0 : i32} : i32
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128]
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
%c128_i32 = arith.constant 128 : i32
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = <none>
%1 = arith.muli %pid, %c128_i32 : i32
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = <none>
%3 = tt.splat %1 : (i32) -> tensor<128xi32>
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128], divisibility = [128], constancy = [1], constant_value = <none>
%4 = arith.addi %3, %2 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value = <none>
%5 = tt.splat %addr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
- // CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [128], divisibility = [16], constancy = [1], constant_value = <none>
%6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value = <none>
%9 = tt.splat %n : (i32) -> tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [16], constant_value = <none>
%mask = arith.cmpi slt, %4, %9 : tensor<128xi32>
- // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None]
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%cst = arith.constant dense<0.0> : tensor<128xf32>
tt.store %5, %cst, %mask : tensor<128xf32>
return
@@ -381,8 +381,8 @@ func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n:
// This IR is dumped from vecadd test.
// Note, the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of mask.
-// CHECK-LABEL: vecadd_mask_align_16
-func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
+// CHECK-LABEL: @vecadd_mask_align_16
+func.func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c64_i32 : i32
@@ -394,13 +394,13 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
+ // CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [16], constant_value = <none>
%mask = arith.cmpi slt, %4, %9 : tensor<64xi32>
%11 = tt.load %6, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
%12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
%13 = arith.addf %11, %12 : tensor<64xf32>
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
- // CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr<f32>>, tensor<64xi32> )
+ // CHECK: tt.addptr %{{.*}} => contiguity = [64], divisibility = [16], constancy = [1], constant_value = <none>
%15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
tt.store %15, %13, %mask : tensor<64xf32>
return
@@ -410,8 +410,8 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ar
// This IR is dumped from vecadd test.
// Note, there is no divisibility hint for %n_elements, Triton should assume its divisibility to be 1 by default.
-// CHECK-LABEL: vecadd_mask_align_1
-func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
+// CHECK-LABEL: @vecadd_mask_align_1
+func.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c64_i32 : i32
@@ -423,7 +423,7 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x!tt.ptr<f32>>
%8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32>
- // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> )
+ // CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%10 = arith.cmpi slt, %4, %9 : tensor<64xi32>
%11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
%12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir
index efb00c404d..f79222aa7b 100644
--- a/test/Analysis/test-allocation.mlir
+++ b/test/Analysis/test-allocation.mlir
@@ -13,7 +13,7 @@
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_loop
-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
@@ -46,7 +46,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
// Shared memory is available after a tensor's liveness range ends
// CHECK-LABEL: reusable
-func @reusable(%A : !tt.ptr<f16>) {
+func.func @reusable(%A : !tt.ptr<f16>) {
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
%cst3 = arith.constant dense<true> : tensor<32x128xi1, #AL>
@@ -78,7 +78,7 @@ func @reusable(%A : !tt.ptr<f16>) {
// %cst1->%cst4
// %cst3->%g->%h->%i
// CHECK-LABEL: preallocate
-func @preallocate(%A : !tt.ptr<f16>) {
+func.func @preallocate(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 1024, size = 512
@@ -113,7 +113,7 @@ func @preallocate(%A : !tt.ptr<f16>) {
// Unused tensors are immediately released
// CHECK-LABEL: unused
-func @unused(%A : !tt.ptr<f16>) {
+func.func @unused(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 1024
%cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 0, size = 512
@@ -128,7 +128,7 @@ func @unused(%A : !tt.ptr<f16>) {
// cst0 is alive through the entire function, it cannot be released before the end of the function
// CHECK-LABEL: longlive
-func @longlive(%A : !tt.ptr<f16>) {
+func.func @longlive(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 512, size = 512
@@ -156,7 +156,7 @@ func @longlive(%A : !tt.ptr<f16>) {
}
// CHECK-LABEL: alloc
-func @alloc(%A : !tt.ptr<f16>) {
+func.func @alloc(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
@@ -167,7 +167,7 @@ func @alloc(%A : !tt.ptr<f16>) {
}
// CHECK-LABEL: scratch
-func @scratch() {
+func.func @scratch() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
// CHECK: scratch offset = 0, size = 512
%b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0>
@@ -176,7 +176,7 @@ func @scratch() {
}
// CHECK-LABEL: trans
-func @trans(%A : !tt.ptr<f16>) {
+func.func @trans(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 1024
%tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
%b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
@@ -184,7 +184,7 @@ func @trans(%A : !tt.ptr<f16>) {
}
// CHECK-LABEL: insert_slice_async
-func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
+func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
@@ -197,7 +197,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
}
// CHECK-LABEL: extract_slice
-func @extract_slice(%A : !tt.ptr<f16>) {
+func.func @extract_slice(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
%index = arith.constant 0 : index
@@ -209,7 +209,7 @@ func @extract_slice(%A : !tt.ptr<f16>) {
// B0 -> (B1) -> B0
// Memory used by B1 can be reused by B0.
// CHECK-LABEL: if
-func @if(%i1 : i1) {
+func.func @if(%i1 : i1) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 512, size = 512
@@ -233,7 +233,7 @@ func @if(%i1 : i1) {
// B0 -> (B1) -> (B2) -> B0
// Memory used by B0 cannot be reused by B1 or B2.
// CHECK-LABEL: if_else
-func @if_else(%i1 : i1) {
+func.func @if_else(%i1 : i1) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 512, size = 512
@@ -260,7 +260,7 @@ func @if_else(%i1 : i1) {
// Block arguments and yields are memory aliases that do not trigger a new
// allocation.
// CHECK-LABEL: for
-func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 8192
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 8192, size = 8192
@@ -275,7 +275,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p
}
// CHECK-LABEL: for_if_slice
-func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
+func.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: offset = 0, size = 8192
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 8192, size = 8192
@@ -296,7 +296,7 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %
// c0 cannot be released in the loop
// CHECK-LABEL: for_use_ancestor
-func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
+func.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: offset = 0, size = 8192
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 8192, size = 8192
@@ -316,7 +316,7 @@ func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16
// a_shared_init, b_shared_init, and c_shared_init's liveness ranges are span over the entire function before cst2.
// So they cannot be reused by cst0 and cst1, but can be reused by cst2.
// CHECK-LABEL: for_if_for
-func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
+func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %i1 : i1) {
// CHECK: offset = 0, size = 8192
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: offset = 8192, size = 8192
diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir
index 7199e5f53d..17880b2094 100644
--- a/test/Analysis/test-membar.mlir
+++ b/test/Analysis/test-membar.mlir
@@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_loop
// There shouldn't be any membar with the dot op encoding.
-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
@@ -42,7 +42,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
}
// CHECK-LABEL: raw_single_block
-func @raw_single_block(%A : !tt.ptr<f16>) {
+func.func @raw_single_block(%A : !tt.ptr<f16>) {
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
@@ -54,7 +54,7 @@ func @raw_single_block(%A : !tt.ptr<f16>) {
}
// CHECK-LABEL: war_single_block
-func @war_single_block(%A : !tt.ptr<f16>) {
+func.func @war_single_block(%A : !tt.ptr<f16>) {
%cst1 = arith.constant dense<true> : tensor<128x32xi1, #AL>
%cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL>
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
@@ -70,7 +70,7 @@ func @war_single_block(%A : !tt.ptr<f16>) {
}
// CHECK-LABEL: scratch
-func @scratch() {
+func.func @scratch() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK: Membar 1
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
@@ -81,7 +81,7 @@ func @scratch() {
}
// CHECK-LABEL: async_wait
-func @async_wait() {
+func.func @async_wait() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK: Membar 1
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
@@ -92,7 +92,7 @@ func @async_wait() {
}
// CHECK-LABEL: alloc
-func @alloc() {
+func.func @alloc() {
%cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
%a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: Membar 2
@@ -101,7 +101,7 @@ func @alloc() {
}
// CHECK-LABEL: extract_slice
-func @extract_slice() {
+func.func @extract_slice() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
%index = arith.constant 0 : index
%cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
@@ -113,14 +113,14 @@ func @extract_slice() {
}
// CHECK-LABEL: trans
-func @trans() {
+func.func @trans() {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
%b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T>
return
}
// CHECK-LABEL: insert_slice_async
-func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
+func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
@@ -135,7 +135,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
}
// CHECK-LABEL: insert_slice
-func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
+func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
%mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
%other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
@@ -153,7 +153,7 @@ func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
// If branch inserted a barrier for %cst0 and %cst1, but else didn't, then the barrier should be inserted in the parent region
// CHECK-LABEL: multi_blocks
-func @multi_blocks(%i1 : i1) {
+func.func @multi_blocks(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
scf.if %i1 {
@@ -174,7 +174,7 @@ func @multi_blocks(%i1 : i1) {
// Both branches inserted a barrier for %cst0 and %cst1, then the barrier doesn't need to be inserted in the parent region
// CHECK-LABEL: multi_blocks_join_barrier
-func @multi_blocks_join_barrier(%i1 : i1) {
+func.func @multi_blocks_join_barrier(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
scf.if %i1 {
@@ -192,7 +192,7 @@ func @multi_blocks_join_barrier(%i1 : i1) {
// Read yielded tensor requires a barrier
// CHECK-LABEL: multi_blocks_yield
-func @multi_blocks_yield(%i1 : i1) {
+func.func @multi_blocks_yield(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) {
@@ -212,7 +212,7 @@ func @multi_blocks_yield(%i1 : i1) {
// Conservatively add a barrier as if the branch (%i1) is never taken
// CHECK-LABEL: multi_blocks_noelse
-func @multi_blocks_noelse(%i1 : i1) {
+func.func @multi_blocks_noelse(%i1 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
scf.if %i1 {
@@ -226,7 +226,7 @@ func @multi_blocks_noelse(%i1 : i1) {
// Conservatively add a barrier as if the branch (%i2) is never taken
// CHECK-LABEL: multi_blocks_nested_scf
-func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
+func.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
scf.if %i1 {
@@ -247,7 +247,7 @@ func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) {
}
// CHECK-LABEL: for
-func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
@@ -262,7 +262,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p
// Although a_shared and b_shared are synced before entering the loop,
// they are reassociated with aliases (c_shared) and thus require a barrier.
// CHECK-LABEL: for_alias
-func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: Membar 2
@@ -282,7 +282,7 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
// Although cst2 is not an argument of scf.yield, its memory is reused by cst1.
// So we need a barrier both before and after cst1
// CHECK-LABEL: for_reuse
-func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+func.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: Membar 2
@@ -302,7 +302,7 @@ func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
// CHECK-LABEL: for_reuse_nested
-func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+func.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
%b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
// CHECK-NEXT: Membar 2
diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir
index e9ee502435..0e979b148d 100644
--- a/test/Conversion/triton_ops.mlir
+++ b/test/Conversion/triton_ops.mlir
@@ -1,6 +1,6 @@
// RUN: triton-opt %s | FileCheck %s
-func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
+func.func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
// scalar -> scalar
// CHECK: i64 -> !tt.ptr<f32>
%0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
@@ -35,7 +35,7 @@ func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
return
}
-func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
+func.func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
// scalar -> scalar
// CHECK: !tt.ptr<f32>
%0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>, i32
@@ -54,7 +54,7 @@ func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
return
}
-func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
+func.func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
// Test if Load/Store ops can handle scalar values
%other = arith.constant 0.0e+0 : f32
@@ -76,7 +76,7 @@ func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ma
return
}
-func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
+func.func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
// Test if reduce ops infer types correctly
// CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
@@ -101,7 +101,7 @@ func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
return
}
-func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
+func.func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
// Test if reduce ops infer types correctly
%v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32>
%v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32>
diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir
index a160bc8815..b461ca542f 100644
--- a/test/Conversion/triton_to_tritongpu.mlir
+++ b/test/Conversion/triton_to_tritongpu.mlir
@@ -1,6 +1,6 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s
-func @ops() {
+func.func @ops() {
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>
@@ -11,7 +11,7 @@ func @ops() {
// -----
-func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
+func.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// Test if LoadOp is lowered properly (see #771)
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
%mask = arith.constant dense<true> : tensor<128xi1>
@@ -30,7 +30,7 @@ func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// -----
-func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
+func.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// Test if the total number of threadsPerWarp is 32
// Test if the total number of warps is 2
// CHECK: #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index e9e7d5a340..507b362c99 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -4,7 +4,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in module attribute multiples 32
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}}
- func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
+ func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
return
}
@@ -15,7 +15,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_load
- func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
+ func.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK: llvm.inline_asm
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
@@ -28,7 +28,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: vectorized_load
- func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
+ func.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.b32
// CHECK: llvm.inline_asm
@@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: vectorized_load_f16
- func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
+ func.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.b16
// CHECK: llvm.inline_asm
@@ -59,7 +59,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other
- func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
+ func.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
@@ -72,7 +72,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: masked_load_const_other_vec
- func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
+ func.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
@@ -84,7 +84,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: global_load_store_no_vec
- func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
+ func.func @global_load_store_no_vec(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
@@ -128,7 +128,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: global_load_store_vec4
- func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
+ func.func @global_load_store_vec4(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
@@ -165,7 +165,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
// Note, the %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes it's divisibility is 1, this should effect the mask's alignment and further restrict the load/store ops' vector width to be 1.
module attributes {"triton_gpu.num-warps" = 2 : i32} {
- func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
+ func.func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c64_i32 : i32
@@ -195,7 +195,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: global_load_store_vec2
- func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
+ func.func @global_load_store_vec2(%arg0: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 8 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
@@ -240,7 +240,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: global_load_store_vec8
- func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
+ func.func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
@@ -283,7 +283,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_view_broadcast
- func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
+ func.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) {
// CHECK: llvm.mlir.undef
// CHECK: %[[T0:.*]] = llvm.extractvalue
// CHECK: %[[T1:.*]] = llvm.extractvalue
@@ -307,7 +307,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_make_range
- func @basic_make_range() {
+ func.func @basic_make_range() {
// CHECK: nvvm.read.ptx.sreg.tid.x
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue
@@ -322,7 +322,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_addf
- func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
+ func.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
// CHECK: llvm.fadd
// CHECK: llvm.fadd
%1 = arith.addf %arg0, %arg1 : tensor<256xf32,#blocked0>
@@ -335,7 +335,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_addi
- func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
+ func.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
// CHECK: llvm.add
// CHECK: llvm.add
%1 = arith.addi %arg0, %arg1 : tensor<256xi32,#blocked0>
@@ -347,7 +347,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_program_id
- func @basic_program_id() {
+ func.func @basic_program_id() {
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
return
@@ -359,7 +359,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_addptr
- func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
+ func.func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
// CHECK: llvm.getelementptr
// CHECK: llvm.getelementptr
%0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
@@ -373,7 +373,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.mlir.global external @global_smem
// CHECK-LABEL: basic_alloc_tensor
- func @basic_alloc_tensor() {
+ func.func @basic_alloc_tensor() {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK-NEXT: llvm.bitcast
// CHECK-NEXT: llvm.mlir.constant
@@ -390,7 +390,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.mlir.global external @global_smem
// CHECK-LABEL: basic_extract_slice
- func @basic_extract_slice() {
+ func.func @basic_extract_slice() {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: llvm.extractvalue
// CHECK-NEXT: llvm.extractvalue
@@ -423,7 +423,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_async_wait
- func @basic_async_wait() {
+ func.func @basic_async_wait() {
// CHECK: cp.async.wait_group 0x4
triton_gpu.async_wait {num = 4: i32}
return
@@ -442,7 +442,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_fallback
- func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
+ func.func @basic_insert_slice_async_fallback(%arg0: !tt.ptr<f16> {tt.divisibility = 1 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
@@ -481,7 +481,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_v4
- func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 32 : i32}) {
+ func.func @basic_insert_slice_async_v4(%arg0: !tt.ptr<f32> {tt.divisibility = 32 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
@@ -523,7 +523,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_v1
- func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
+ func.func @basic_insert_slice_async_v1(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
%off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2>
@@ -568,7 +568,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_insert_slice_async_v1_multictas
- func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
+ func.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}) {
%off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1>
%off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0>
%off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #slice2d1>) -> tensor<32x1xi32, #block2>
@@ -619,7 +619,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: basic_splat
- func @basic_splat(%ptr: !tt.ptr<f32>) {
+ func.func @basic_splat(%ptr: !tt.ptr<f32>) {
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue
// CHECK: llvm.insertvalue
@@ -633,7 +633,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_store
- func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
+ func.func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: llvm.inline_asm
@@ -650,7 +650,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked
- func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
+ func.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
@@ -697,7 +697,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked_vec
- func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
+ func.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
@@ -720,7 +720,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
- func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
+ func.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
@@ -751,7 +751,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_dot
- func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
+ func.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
%AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
%BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0>
// CHECK: llvm.inline_asm
@@ -775,7 +775,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// TODO: problems in MLIR's parser on slice layout
// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
// module attributes {"triton_gpu.num-warps" = 1 : i32} {
-// func @make_range_sliced_layout() {
+// func.func @make_range_sliced_layout() {
// %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>>
// return
// }
@@ -788,7 +788,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mmav2_block
- func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
+ func.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
@@ -808,7 +808,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mmav1_block
- func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
+ func.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
@@ -831,7 +831,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_blocked_shared
- func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
+ func.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<8xf32>, 3>
// CHECK: llvm.store
@@ -847,7 +847,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice0
- func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
+ func.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
// CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
return
@@ -860,7 +860,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice1
- func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
+ func.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
return
@@ -873,7 +873,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked_to_blocked_ptr
- func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
+ func.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
// CHECK: llvm.ptrtoint
// CHECK: llvm.store
// CHECK: nvvm.barrier0
@@ -892,7 +892,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
- func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+ func.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
@@ -918,7 +918,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
- func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+ func.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<32x64xf16, #shared0>, %b:tensor<64x64xf16, #shared1>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma>
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
@@ -941,7 +941,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#blocked}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
- func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+ func.func @matmul_fmadot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
// CHECK: llvm.intr.fmuladd
@@ -965,7 +965,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_tf32dot
- func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+ func.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
// CHECK: llvm.inline_asm
@@ -1000,7 +1000,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32
- func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
+ func.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: atom.global.gpu.add.f32
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
@@ -1012,7 +1012,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
-func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
+func.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
%blockidx = tt.get_program_id {axis=0:i32} : i32
%blockidy = tt.get_program_id {axis=1:i32} : i32
%blockidz = tt.get_program_id {axis=2:i32} : i32
@@ -1032,7 +1032,7 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
- func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
+ func.func @test_get_num_program(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.nctaid.x
// CHECK: nvvm.read.ptx.sreg.nctaid.y
// CHECK: nvvm.read.ptx.sreg.nctaid.z
@@ -1052,7 +1052,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: test_index_cache
- func @test_index_cache() {
+ func.func @test_index_cache() {
// CHECK: nvvm.read.ptx.sreg.tid.x
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
@@ -1066,7 +1066,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_base_index_cache
- func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
+ func.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
// CHECK: nvvm.read.ptx.sreg.tid.x
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
// CHECK-NOT: nvvm.read.ptx.sreg.tid.x
@@ -1080,7 +1080,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_index_cache_different_block
- func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
+ func.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
// CHECK: nvvm.read.ptx.sreg.tid.x
%0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
scf.if %arg1 {
diff --git a/test/Target/tritongpu_to_llvmir.mlir b/test/Target/tritongpu_to_llvmir.mlir
index cafff3ca60..114d3a9eb2 100644
--- a/test/Target/tritongpu_to_llvmir.mlir
+++ b/test/Target/tritongpu_to_llvmir.mlir
@@ -4,11 +4,11 @@
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
// CHECK: define void @test_empty_kernel
// CHECK: !nvvm.annotations
-// CHECK: !{void (i32, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128}
+// CHECK: !{ptr @test_empty_kernel, !"maxntidx", i32 128}
module attributes {"triton_gpu.num-warps" = 4 : i32} {
-func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
+func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
return
}
diff --git a/test/Target/tritongpu_to_ptx.mlir b/test/Target/tritongpu_to_ptx.mlir
index 404e970a29..12742ad9e2 100644
--- a/test/Target/tritongpu_to_ptx.mlir
+++ b/test/Target/tritongpu_to_ptx.mlir
@@ -6,7 +6,7 @@
module attributes {"triton_gpu.num-warps" = 4 : i32} {
-func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
+func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
return
}
diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir
index 050a3f7565..5ef6790e69 100644
--- a/test/Triton/combine.mlir
+++ b/test/Triton/combine.mlir
@@ -2,10 +2,10 @@
// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s
// CHECK-LABEL: @test_combine_dot_add_pattern
-func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
- // CHECK: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
- // CHECK: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
- // CHECK: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
+func.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+ // CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
+ // CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
+ // CHECK-DAG: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
%a = arith.constant dense<1.0> : tensor<128x128xf32>
%b = arith.constant dense<2.0> : tensor<128x128xf32>
%zero = arith.constant dense<0.0> : tensor<128x128xf32>
@@ -24,7 +24,7 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
// COM: CHECK-LABEL: @test_combine_addptr_pattern
-func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
+func.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
%off0 = arith.constant 10 : i32
%off1 = arith.constant 15 : i32
@@ -47,46 +47,46 @@ func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>>
// CHECK-LABEL: @test_combine_select_masked_load_pattern
-func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
+func.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
%mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
%false_val = arith.constant dense<0.0> : tensor<8xf32>
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
%x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
- %0 = select %cond, %x, %false_val : tensor<8xf32>
+ %0 = arith.select %cond, %x, %false_val : tensor<8xf32>
// CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
%y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
- %1 = select %cond, %y, %false_val : tensor<8xf32>
+ %1 = arith.select %cond, %y, %false_val : tensor<8xf32>
// CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
-func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
+func.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
%false_val = arith.constant dense<0.0> : tensor<8xf32>
// Case 1: value at the "load" position is not an "op". Select should not be canonicalized.
- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
- %0 = select %cond0, %dummy_load, %false_val : tensor<8xf32>
+ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
+ %0 = arith.select %cond0, %dummy_load, %false_val : tensor<8xf32>
// Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized.
%real_load0 = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
- %1 = select %cond0, %real_load0, %false_val : tensor<8xf32>
+ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
+ %1 = arith.select %cond0, %real_load0, %false_val : tensor<8xf32>
// Case 3: condition of "broadcast" is not the same as the condition of "select". Select should not be canonicalized.
%cond0_ = tt.broadcast %cond0 : (i1) -> tensor<8xi1>
%real_load1 = tt.load %ptr, %cond0_, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
- // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
- %2 = select %cond1, %real_load1, %false_val : tensor<8xf32>
+ // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
+ %2 = arith.select %cond1, %real_load1, %false_val : tensor<8xf32>
return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
-func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
+func.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
%const = arith.constant dense<1.0> : tensor<8xf32>
%bst_out = tt.broadcast %const : (tensor<8xf32>) -> tensor<8x2xf32>
@@ -96,7 +96,7 @@ func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
}
// CHECK-LABEL: @test_canonicalize_masked_load_pattern
-func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
+func.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
%true_mask = arith.constant dense<true> : tensor<8xi1>
%false_mask = arith.constant dense<false> : tensor<8xi1>
%other_val = arith.constant dense<0.0> : tensor<8xf32>
@@ -117,7 +117,7 @@ func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (te
}
// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
-func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
+func.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
%other_val = arith.constant dense<0.0> : tensor<8xf32>
// Case: value at the "mask" position is not an "op". Load should not be canonicalized.
@@ -130,7 +130,7 @@ func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %
}
// CHECK-LABEL: @test_canonicalize_masked_store_pattern
-func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
+func.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
%true_mask = arith.constant dense<true> : tensor<8xi1>
%false_mask = arith.constant dense<false> : tensor<8xi1>
@@ -144,7 +144,7 @@ func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val:
}
// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
-func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
+func.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
// Case: value at the "mask" position is not an "op". Store should not be canonicalized.
// CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
tt.store %ptr, %val, %mask : tensor<8xf32>
diff --git a/test/Triton/vecadd.mlir b/test/Triton/vecadd.mlir
index 0b69ef3054..f5019b1cdd 100644
--- a/test/Triton/vecadd.mlir
+++ b/test/Triton/vecadd.mlir
@@ -1,7 +1,7 @@
// RUN: triton-opt %s -verify-diagnostics
module {
- func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
+ func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
%0 = tt.get_program_id {axis = 0 : i32} : i32
%c256_i32 = arith.constant 256 : i32
%1 = arith.muli %0, %c256_i32 : i32
@@ -43,7 +43,7 @@ module {
}
}
// module {
-// func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
+// func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32, %arg5: i32) {
// %c64 = arith.constant 64 : index
// %c32 = arith.constant 32 : index
// %c0 = arith.constant 0 : index
diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir
index 60e359f527..51cccccfbd 100644
--- a/test/TritonGPU/coalesce.mlir
+++ b/test/TritonGPU/coalesce.mlir
@@ -19,7 +19,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]>
// CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]>
// CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]]
-func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
+func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg1: i32 {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32}) {
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index 2c009ffa48..7e9cb9d504 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -9,7 +9,7 @@
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK-LABEL: cst
-func @cst() -> tensor<1024xi32, #layout1> {
+func.func @cst() -> tensor<1024xi32, #layout1> {
%cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
@@ -18,7 +18,7 @@ func @cst() -> tensor<1024xi32, #layout1> {
}
// CHECK-LABEL: range
-func @range() -> tensor<1024xi32, #layout1> {
+func.func @range() -> tensor<1024xi32, #layout1> {
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
@@ -27,7 +27,7 @@ func @range() -> tensor<1024xi32, #layout1> {
}
// CHECK-LABEL: splat
-func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
+func.func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
%0 = tt.splat %arg0 : (i32) -> tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
@@ -36,7 +36,7 @@ func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
}
// CHECK-LABEL: remat
-func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
+func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
%1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
%2 = arith.muli %0, %1 : tensor<1024xi32, #layout0>
@@ -56,7 +56,7 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
}
// CHECK-LABEL: remat_load_store
-func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
+func.func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
@@ -70,7 +70,7 @@ func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
// Don't rematerialize vectorized loads
// CHECK-LABEL: remat_expensive
-func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
+func.func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout1>
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout1>
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout1>, tensor<64xi32, #layout1>
@@ -85,7 +85,7 @@ func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
// Don't rematerialize loads when original and target layouts are different
// CHECK-LABEL: remat_multi_layout
-func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
+func.func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
@@ -100,7 +100,7 @@ func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
// Always rematerialize single value loads
// CHECK-LABEL: remat_single_value
-func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
+func.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%0 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<1x!tt.ptr<i32>, #layout1>
%1 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
@@ -111,7 +111,7 @@ func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
}
// CHECK-LABEL: if
-func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
+func.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1>
%0 = tt.get_program_id {axis = 0 : i32} : i32
@@ -128,7 +128,7 @@ func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
}
// CHECK-LABEL: if_convert_else_not
-func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
+func.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
@@ -149,7 +149,7 @@ func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16
}
// CHECK-LABEL: if_not_else_convert
-func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
+func.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
@@ -170,7 +170,7 @@ func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16
}
// CHECK-LABEL: if_else_both_convert
-func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
+func.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
%c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
@@ -200,7 +200,7 @@ func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16
#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK-LABEL: transpose
-func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
+func.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
// CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
@@ -241,7 +241,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
}
// CHECK-LABEL: loop
-func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
+func.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>)
// CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[row_layout]]>
@@ -295,7 +295,7 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
}
// CHECK-LABEL: vecadd
-func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
+func.func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
// CHECK-NOT: triton_gpu.convert_layout
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
@@ -327,7 +327,7 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3
// Select has args with different element types
// CHECK-LABEL: select
-func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
+func.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
%cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
%cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2>
@@ -378,7 +378,7 @@ func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f6
// Make sure the following IR doesn't hang the compiler.
// CHECK-LABEL: long_func
-func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
+func.func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg13: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg14: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0>
%cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0>
%cst_1 = arith.constant dense<0.999499976> : tensor<1024xf32, #blocked0>
@@ -775,7 +775,7 @@ func public @long_func(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1:
// A mnist model from torch inductor.
// Check if topological sort is working correct and there's no unnecessary convert
// CHECK-LABEL: mnist
-func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
+func.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
// CHECK-NOT: triton_gpu.convert_layout
%cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2>
%cst_0 = arith.constant dense<10> : tensor<1x16xi32, #blocked3>
@@ -862,7 +862,7 @@ func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.
#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
// cmpf and cmpi have different operands and result types
// CHECK-LABEL: cmp
-func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
+func.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
%c64 = arith.constant 64 : index
%c2048 = arith.constant 2048 : index
%c0 = arith.constant 0 : index
diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir
index 6ee3b15fbc..663f2da7b0 100644
--- a/test/TritonGPU/loop-pipeline.mlir
+++ b/test/TritonGPU/loop-pipeline.mlir
@@ -10,7 +10,7 @@
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
-// CHECK: func @matmul_loop
+// CHECK: func.func @matmul_loop
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
@@ -46,8 +46,8 @@
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
-func @matmul_loop(%lb : index, %ub : index, %step : index,
- %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
+func.func @matmul_loop(%lb : index, %ub : index, %step : index,
+ %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
// A ptrs
%a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
@@ -61,7 +61,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index,
%b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL>
%b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL>
%b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
-
+
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
@@ -88,7 +88,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index,
}
-// CHECK: func @matmul_loop_nested
+// CHECK: func.func @matmul_loop_nested
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
@@ -118,8 +118,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index,
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
-func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
- %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
+func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
+ %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
scf.for %iv0 = %lb to %ub step %step {
// A ptrs
@@ -134,7 +134,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
%b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL>
%b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL>
%b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
-
+
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
%b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
@@ -161,7 +161,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
}
-// CHECK: func @matmul_loop_single_pipeline
+// CHECK: func.func @matmul_loop_single_pipeline
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
@@ -183,8 +183,8 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
-func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
- %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
+func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
+ %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
%B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
// A ptrs
%a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir
index 9bd5318e1e..01dc3f0ab1 100644
--- a/test/TritonGPU/matmul.mlir
+++ b/test/TritonGPU/matmul.mlir
@@ -4,7 +4,7 @@
// CHECK: offset = 49152, size = 49152
// CHECK: size = 98304
module {
-func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
+func.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
%cst = arith.constant dense<true> : tensor<64x64xi1>
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
@@ -22,7 +22,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
%7 = arith.muli %6, %c8_i32 : i32
%8 = arith.subi %2, %7 : i32
%9 = arith.cmpi slt, %8, %c8_i32 : i32
- %10 = select %9, %8, %c8_i32 : i32
+ %10 = arith.select %9, %8, %c8_i32 : i32
%11 = arith.remsi %0, %10 : i32
%12 = arith.addi %7, %11 : i32
%13 = arith.remsi %0, %5 : i32
diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir
index 52b4dddec1..b427547890 100644
--- a/test/TritonGPU/prefetch.mlir
+++ b/test/TritonGPU/prefetch.mlir
@@ -11,7 +11,7 @@
#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
-// CHECK: func @matmul_loop
+// CHECK: func.func @matmul_loop
// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[A0:.*]][0, 0] [128, 16]
// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]]
// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[B0:.*]][0, 0] [16, 128]
@@ -28,7 +28,7 @@
// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = tensor.extract_slice {{.*}}[0, 0] [16, 128]
// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]]
-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
diff --git a/test/TritonGPU/update-mma-for-volta.mlir b/test/TritonGPU/update-mma-for-volta.mlir
index d587fffcca..7571ec6185 100644
--- a/test/TritonGPU/update-mma-for-volta.mlir
+++ b/test/TritonGPU/update-mma-for-volta.mlir
@@ -15,7 +15,7 @@
// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}>
module attributes {"triton_gpu.num-warps" = 16 : i32} {
// CHECK-LABEL: dot_mmav1
- func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
+ func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0>
%AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a>
%BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b>
@@ -50,7 +50,7 @@ module attributes {"triton_gpu.num-warps" = 16 : i32} {
module attributes {"triton_gpu.num-warps" = 16 : i32} {
// CHECK-LABEL: dot_mmav1
- func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
+ func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0>
%AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a>
%BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b>
diff --git a/test/lib/Analysis/TestAlias.cpp b/test/lib/Analysis/TestAlias.cpp
index 88a4118fe9..3fd0cfd0d3 100644
--- a/test/lib/Analysis/TestAlias.cpp
+++ b/test/lib/Analysis/TestAlias.cpp
@@ -9,10 +9,10 @@ using namespace mlir;
namespace {
struct TestAliasPass
- : public PassWrapper<TestAliasPass, OperationPass<FuncOp>> {
+ : public PassWrapper<TestAliasPass, OperationPass<func::FuncOp>> {
+
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
- // LLVM15+
- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
static void print(StringRef name, SmallVector<std::string, 4> &vals,
raw_ostream &os) {
if (vals.empty())
@@ -39,23 +39,24 @@ struct TestAliasPass
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << opName << "\n";
- SharedMemoryAliasAnalysis analysis(&getContext());
- analysis.run(operation);
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
+ SharedMemoryAliasAnalysis *analysis =
+ solver->load<SharedMemoryAliasAnalysis>();
+ if (failed(solver->initializeAndRun(operation)))
+ return signalPassFailure();
AsmState state(operation->getParentOfType<ModuleOp>());
// Get operation ids of value's aliases
auto getAllocOpNames = [&](Value value) {
- LatticeElement<AliasInfo> *latticeElement =
- analysis.lookupLatticeElement(value);
+ dataflow::Lattice<AliasInfo> *latticeElement =
+ analysis->getLatticeElement(value);
SmallVector<std::string, 4> opNames;
- if (latticeElement) {
+ if (latticeElement && !latticeElement->isUninitialized()) {
auto &info = latticeElement->getValue();
- if (!info.getAllocs().empty()) {
- for (auto &alias : info.getAllocs()) {
- auto opName =
- getValueOperandName(alias.getDefiningOp()->getResult(0), state);
- opNames.push_back(std::move(opName));
- }
+ for (auto &alias : info.getAllocs()) {
+ auto opName =
+ getValueOperandName(alias.getDefiningOp()->getResult(0), state);
+ opNames.push_back(std::move(opName));
}
}
// Ensure deterministic output
diff --git a/test/lib/Analysis/TestAllocation.cpp b/test/lib/Analysis/TestAllocation.cpp
index 84108c4d36..35e42242bd 100644
--- a/test/lib/Analysis/TestAllocation.cpp
+++ b/test/lib/Analysis/TestAllocation.cpp
@@ -6,10 +6,9 @@ using namespace mlir;
namespace {
struct TestAllocationPass
- : public PassWrapper<TestAllocationPass, OperationPass<FuncOp>> {
+ : public PassWrapper<TestAllocationPass, OperationPass<func::FuncOp>> {
- // LLVM15+
- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
StringRef getArgument() const final { return "test-print-allocation"; }
StringRef getDescription() const final {
diff --git a/test/lib/Analysis/TestAxisInfo.cpp b/test/lib/Analysis/TestAxisInfo.cpp
index a5205bb0a0..22347c32f0 100644
--- a/test/lib/Analysis/TestAxisInfo.cpp
+++ b/test/lib/Analysis/TestAxisInfo.cpp
@@ -1,25 +1,15 @@
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/AxisInfo.h"
+#include "triton/Analysis/Utility.h"
using namespace mlir;
namespace {
struct TestAxisInfoPass
- : public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>> {
+ : public PassWrapper<TestAxisInfoPass, OperationPass<func::FuncOp>> {
- // LLVM15+
- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
-
- void print(const std::string &name, raw_ostream &os, ArrayRef<int64_t> vals) {
- os << name << ": [";
- for (size_t d = 0; d < vals.size(); d++) {
- if (d != 0)
- os << ", ";
- os << vals[d];
- }
- os << "]";
- }
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);
StringRef getArgument() const final { return "test-print-alignment"; }
StringRef getDescription() const final {
@@ -30,38 +20,19 @@ struct TestAxisInfoPass
Operation *operation = getOperation();
auto &os = llvm::errs();
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
- os << opName << "\n";
- AxisInfoAnalysis analysis(&getContext());
- analysis.run(operation);
+ os << "@" << opName << "\n";
+
+ std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
+ AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
+ if (failed(solver->initializeAndRun(operation)))
+ return signalPassFailure();
operation->walk([&](Operation *op) {
if (op->getNumResults() < 1)
return;
for (Value result : op->getResults()) {
- // std::ostringstream oss;
- // result.print(oss);
- // os << " => ";
- LatticeElement<AxisInfo> *latticeElement =
- analysis.lookupLatticeElement(result);
- if (!latticeElement) {
- os << "None\n";
- return;
- }
- AxisInfo &info = latticeElement->getValue();
- print("Contiguity", os, info.getContiguity());
- os << " ; ";
- print("Divisibility", os, info.getDivisibility());
- os << " ; ";
- print("Constancy", os, info.getConstancy());
- os << " ; ";
- auto constantValue = info.getConstantValue();
- os << "ConstantValue: [";
- if (constantValue.has_value())
- os << constantValue.value();
- else
- os << "None";
- os << "] ( ";
result.print(os);
- os << " ) ";
+ os << " => ";
+ analysis->getLatticeElement(result)->getValue().print(os);
os << "\n";
}
});
diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp
index df4279fe24..ab9b9f3fb7 100644
--- a/test/lib/Analysis/TestMembar.cpp
+++ b/test/lib/Analysis/TestMembar.cpp
@@ -1,4 +1,4 @@
-#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/Allocation.h"
@@ -9,10 +9,9 @@ using namespace mlir;
namespace {
struct TestMembarPass
- : public PassWrapper<TestMembarPass, OperationPass<FuncOp>> {
+ : public PassWrapper<TestMembarPass, OperationPass<func::FuncOp>> {
- // LLVM15+
- // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass);
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass);
StringRef getArgument() const final { return "test-print-membar"; }
StringRef getDescription() const final {