diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index a3c41f2e052cd206bb5c627a29e0fad9aea307fa..748ebb4a9b510da7526959d356a9e7040de8b790 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -2984,6 +2984,7 @@ let Predicates = [HasSVEorSME] in { } defm Pat_Store_P16 : unpred_store_predicate; + defm Pat_Store_P4 : unpred_store_predicate; multiclass unpred_load_predicate { def _fi : Pat<(Ty (load (am_sve_fi GPR64sp:$base, simm9:$offset))), @@ -2994,6 +2995,7 @@ let Predicates = [HasSVEorSME] in { } defm Pat_Load_P16 : unpred_load_predicate; + defm Pat_Load_P4 : unpred_load_predicate; multiclass ld1 { diff --git a/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h b/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h new file mode 100644 index 0000000000000000000000000000000000000000..0ff92bc85668c230cfa540d2459d18978117033d --- /dev/null +++ b/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h @@ -0,0 +1,27 @@ +//===- PtrToLLVM.h - Ptr to LLVM dialect conversion -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H +#define MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H + +#include + +namespace mlir { +class DialectRegistry; +class LLVMTypeConverter; +class RewritePatternSet; +namespace ptr { +/// Populate the convert to LLVM patterns for the `ptr` dialect. +void populatePtrToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); +/// Register the convert to LLVM interface for the `ptr` dialect. +void registerConvertPtrToLLVMInterface(DialectRegistry ®istry); +} // namespace ptr +} // namespace mlir + +#endif // MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h index 9dc262cc72ed0012402f21445a2fa3c421f185f4..1be7140ac71df8606f38ebd84413487a4c54bb75 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -57,6 +57,12 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns); /// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts. void populateExpandBFloat16Patterns(RewritePatternSet &patterns); +/// Add patterns to expand Arith f8E5M2 patterns to lower level bitcasts/shifts. +void populateExpandF8E5M2Patterns(RewritePatternSet &patterns); + +// Add patterns to expand Arith f8E4M3 patterns to lower level bitcasts/shifts. +void populateExpandF8E4M3FNPatterns(RewritePatternSet &patterns); + /// Add patterns to expand Arith ops. void populateArithExpandOpsPatterns(RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td index 1517f71f1a7c9080043c520140826c649b23fd41..eba1ec43055004df4df6963dc6f48f6957483817 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -17,6 +17,10 @@ def ArithExpandOpsPass : Pass<"arith-expand"> { let options = [ Option<"includeBf16", "include-bf16", "bool", /*default=*/"false", "Enable the BF16 expansion patterns">, + Option<"includeF8E5M2", "include-f8e5m2", "bool", /*default=*/"false", + "Enable the F8E5M2 expansion patterns">, + Option<"includeF8E4M3FN", "include-f8e4m3fn", "bool", /*default=*/"false", + "Enable the F8E4M3FN expansion patterns">, ]; } diff --git a/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt index 9f57627c321fb0c74b3e4a404e3c36bd435f64a7..cb1e9d01821a2cf352b79c28c44da4ddd33dd3e9 100644 --- a/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.h new file mode 100644 index 0000000000000000000000000000000000000000..61212afbbd9ee47c295460bb1b11e6b75d386c7a --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.h @@ -0,0 +1,31 @@ +//===- ArmSMEVectorTransformOps.h - Vector transform ops --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARM_SME_VECTOR_TRANSFORMOPS_H +#define MLIR_DIALECT_ARM_SME_VECTOR_TRANSFORMOPS_H + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +//===----------------------------------------------------------------------===// +// ArmSME Vector Transform Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace arm_sme { +void registerTransformDialectExtension(DialectRegistry ®istry); + +} // namespace arm_sve +} // namespace mlir + +#endif // MLIR_DIALECT_ARM_SME_VECTOR_TRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.td new file mode 100644 index 0000000000000000000000000000000000000000..12d0fa995ab5aaa1faf76ca04e74f5bbc61ec837 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.td @@ -0,0 +1,15 @@ +//===- ArmSMEVectorTransformOps.td - Arm SME transform ops--*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef ARMSME_VECTOR_TRANSFORM_OPS +#define ARMSME_VECTOR_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" + +#endif // ARMSME_VECTOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/ArmSME/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9e7990b68230a2965ffe050b22a61c8af13dd58b --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/TransformOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS ArmSMEVectorTransformOps.td) +mlir_tablegen(ArmSMEVectorTransformOps.h.inc -gen-op-decls) +mlir_tablegen(ArmSMEVectorTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRArmSMEVectorTransformOpsIncGen) + +add_mlir_doc(ArmSMEVectorTransformOps ArmSMEVectorTransformOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt index 9f57627c321fb0c74b3e4a404e3c36bd435f64a7..cb1e9d01821a2cf352b79c28c44da4ddd33dd3e9 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h new file mode 100644 index 0000000000000000000000000000000000000000..7f22cd1fe6435af0258c9dec78423b48678b8c8a --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h @@ -0,0 +1,31 @@ +//===- ArmSVEVectorTransformOps.h - Vector transform ops --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H +#define MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +//===----------------------------------------------------------------------===// +// ArmSVE Vector Transform Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace arm_sve { +void registerTransformDialectExtension(DialectRegistry ®istry); + +} // namespace arm_sve +} // namespace mlir + +#endif // MLIR_DIALECT_ARM_SVE_VECTOR_TRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td new file mode 100644 index 0000000000000000000000000000000000000000..00c69cfd6562a672b6321816748f37f031f3d5db --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td @@ -0,0 +1,15 @@ +//===- ArmSVEVectorTransformOps.td - Arm SVE transform ops--*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef ARMSVE_VECTOR_TRANSFORM_OPS +#define ARMSVE_VECTOR_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" + +#endif // ARMSVE_VECTOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ce8d8fea7f188ac450b82869d39929536b68391b --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS ArmSVEVectorTransformOps.td) +mlir_tablegen(ArmSVEVectorTransformOps.h.inc -gen-op-decls) +mlir_tablegen(ArmSVEVectorTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRArmSVEVectorTransformOpsIncGen) + +add_mlir_doc(ArmSVEVectorTransformOps ArmSVEVectorTransformOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td index 443e3128b4acb349d3f969f14ad3c2953f286227..1b6faa9387e5ec050c21943298bb9d62017a2b8b 100644 --- a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td +++ b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td @@ -81,6 +81,9 @@ def DLTI_DataLayoutSpecAttr : /// Returns the endiannes identifier. StringAttr getEndiannessIdentifier(MLIRContext *context) const; + + /// Returns the default memory space identifier. + StringAttr getDefaultMemorySpaceIdentifier(MLIRContext *context) const; /// Returns the alloca memory space identifier. StringAttr getAllocaMemorySpaceIdentifier(MLIRContext *context) const; diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td index e26fbdb146645c410f0bf5df932c721964a7265e..65c69d38c537425a563780a1f55c8978b2b39e7d 100644 --- a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td +++ b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td @@ -55,6 +55,9 @@ def DLTI_Dialect : Dialect { constexpr const static ::llvm::StringLiteral kDataLayoutStackAlignmentKey = "dlti.stack_alignment"; + + constexpr const static ::llvm::StringLiteral + kDataLayoutDefaultMemorySpaceKey = "dlti.default_memory_space"; }]; let useDefaultAttributePrinterParser = 1; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index d2d1fbaf304b245df5e1b1386156043b8a20c414..55ec45dd14a93a728206dd44679722abbb7ee3c2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -59,22 +59,9 @@ class LLVM_IntArithmeticOpWithOverflowFlag traits = []> : LLVM_ArithmeticOpBase], traits)> { - dag iofArg = (ins EnumProperty<"IntegerOverflowFlags">:$overflowFlags); + dag iofArg = (ins EnumProperty<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags); let arguments = !con(commonArgs, iofArg); - let builders = [ - OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs, - "IntegerOverflowFlags":$overflowFlags), [{ - $_state.getOrAddProperties().overflowFlags = overflowFlags; - build($_builder, $_state, type, lhs, rhs); - }]>, - OpBuilder<(ins "Value":$lhs, "Value":$rhs, - "IntegerOverflowFlags":$overflowFlags), [{ - $_state.getOrAddProperties().overflowFlags = overflowFlags; - build($_builder, $_state, lhs, rhs); - }]> - ]; - string mlirBuilder = [{ auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); moduleImport.setIntegerOverflowFlags(inst, op); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index 93733ccd4929ae9048392278740ff26134ce4e88..b9f03451b48c8f6f14e55c9fde62e47640d677f2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -198,8 +198,10 @@ public: uint64_t getPreferredAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params) const; - bool areCompatible(DataLayoutEntryListRef oldLayout, - DataLayoutEntryListRef newLayout) const; + bool areCompatible( + DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout, + DataLayoutSpecInterface newSpec, + const DataLayoutIdentifiedEntryMap &map) const; LogicalResult verifyEntries(DataLayoutEntryListRef entries, Location loc) const; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index ac61117c3d6e3612b76446e207e814c3ca8222b6..f20f036d6fe480583be1ddb0d778068f49df634e 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -311,7 +311,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [ def ReduceOp : LinalgStructuredBase_Op<"reduce", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - SameVariadicOperandSize, + AttrSizedOperandSegments, SingleBlockImplicitTerminator<"YieldOp">]> { let summary = "Reduce operator"; let description = [{ diff --git a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt index df07b8d5a63d96eb03077f40ea37abd16592ebac..255af4c486cb102a4fc0007fb59b163d96726b77 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Ptr/IR/CMakeLists.txt @@ -5,3 +5,15 @@ set(LLVM_TARGET_DEFINITIONS PtrOps.td) mlir_tablegen(PtrOpsAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=ptr) mlir_tablegen(PtrOpsAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=ptr) add_public_tablegen_target(MLIRPtrOpsAttributesIncGen) + +set(LLVM_TARGET_DEFINITIONS MemorySpaceInterfaces.td) +mlir_tablegen(MemorySpaceInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(MemorySpaceInterfaces.cpp.inc -gen-op-interface-defs) +mlir_tablegen(MemorySpaceAttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(MemorySpaceAttrInterfaces.cpp.inc -gen-attr-interface-defs) +add_public_tablegen_target(MLIRPtrMemorySpaceInterfacesIncGen) + +set(LLVM_TARGET_DEFINITIONS PtrOps.td) +mlir_tablegen(PtrOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(PtrOpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRPtrOpsEnumsGen) \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h new file mode 100644 index 0000000000000000000000000000000000000000..3714c1caa36701271513791870ba9f6b934aadcc --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h @@ -0,0 +1,32 @@ +//===-- MemorySpaceInterfaces.h - ptr memory space interfaces ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the ptr dialect memory space interfaces. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTR_IR_MEMORYSPACEINTERFACES_H +#define MLIR_DIALECT_PTR_IR_MEMORYSPACEINTERFACES_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +class Operation; +namespace ptr { +enum class AtomicBinOp : uint64_t; +enum class AtomicOrdering : uint64_t; +} // namespace ptr +} // namespace mlir + +#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.h.inc" + +#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h.inc" + +#endif // MLIR_DIALECT_PTR_IR_MEMORYSPACEINTERFACES_H \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td new file mode 100644 index 0000000000000000000000000000000000000000..cb7775c862a9863635dcbb22f6a4d19e7c96c998 --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td @@ -0,0 +1,117 @@ +//===-- MemorySpaceInterfaces.td - Memory space interfaces ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines memory space attribute interfaces. +// +//===----------------------------------------------------------------------===// + +#ifndef PTR_MEMORYSPACEINTERFACES +#define PTR_MEMORYSPACEINTERFACES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Memory space attribute interface. +//===----------------------------------------------------------------------===// + +def MemorySpaceAttrInterface : AttrInterface<"MemorySpaceAttrInterface"> { + let description = [{ + This interface defines a common API for interacting with the memory model of + a memory space and the operations in the pointer dialect. + + Furthermore, this interface allows concepts such as read-only memory to be + adequately modeled and enforced. + }]; + let cppNamespace = "::mlir::ptr"; + let methods = [ + InterfaceMethod< + /*desc=*/ [{ + This method checks if it's valid to load a value from the memory space + with a specific type, alignment, and atomic ordering. + If `emitError` is non-null then the method is allowed to emit errors. + }], + /*returnType=*/ "::mlir::LogicalResult", + /*methodName=*/ "isValidLoad", + /*args=*/ (ins "::mlir::Type":$type, + "::mlir::ptr::AtomicOrdering":$ordering, + "::mlir::IntegerAttr":$alignment, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + >, + InterfaceMethod< + /*desc=*/ [{ + This method checks if it's valid to store a value in the memory space + with a specific type, alignment, and atomic ordering. + If `emitError` is non-null then the method is allowed to emit errors. + }], + /*returnType=*/ "::mlir::LogicalResult", + /*methodName=*/ "isValidStore", + /*args=*/ (ins "::mlir::Type":$type, + "::mlir::ptr::AtomicOrdering":$ordering, + "::mlir::IntegerAttr":$alignment, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + >, + InterfaceMethod< + /*desc=*/ [{ + This method checks if it's valid to perform an atomic operation in the + memory space with a specific type, alignment, and atomic ordering. + If `emitError` is non-null then the method is allowed to emit errors. + }], + /*returnType=*/ "::mlir::LogicalResult", + /*methodName=*/ "isValidAtomicOp", + /*args=*/ (ins "::mlir::ptr::AtomicBinOp":$op, + "::mlir::Type":$type, + "::mlir::ptr::AtomicOrdering":$ordering, + "::mlir::IntegerAttr":$alignment, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + >, + InterfaceMethod< + /*desc=*/ [{ + This method checks if it's valid to perform an atomic exchange operation + in the memory space with a specific type, alignment, and atomic + orderings. + If `emitError` is non-null then the method is allowed to emit errors. + }], + /*returnType=*/ "::mlir::LogicalResult", + /*methodName=*/ "isValidAtomicXchg", + /*args=*/ (ins "::mlir::Type":$type, + "::mlir::ptr::AtomicOrdering":$successOrdering, + "::mlir::ptr::AtomicOrdering":$failureOrdering, + "::mlir::IntegerAttr":$alignment, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + >, + InterfaceMethod< + /*desc=*/ [{ + This method checks if it's valid to perform an `addrspacecast` op + in the memory space. + If `emitError` is non-null then the method is allowed to emit errors. + }], + /*returnType=*/ "::mlir::LogicalResult", + /*methodName=*/ "isValidAddrSpaceCast", + /*args=*/ (ins "::mlir::Type":$tgt, + "::mlir::Type":$src, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + >, + InterfaceMethod< + /*desc=*/ [{ + This method checks if it's valid to perform a `ptrtoint` or `inttoptr` + op in the memory space. + The first type is expected to be integer-like, while the second must be a + ptr-like type. + If `emitError` is non-null then the method is allowed to emit errors. + }], + /*returnType=*/ "::mlir::LogicalResult", + /*methodName=*/ "isValidPtrIntCast", + /*args=*/ (ins "::mlir::Type":$intLikeTy, + "::mlir::Type":$ptrLikeTy, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + >, + ]; +} + +#endif // PTR_MEMORYSPACEINTERFACES \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td index e75038f300f1a76a2f3024e1518695df43700b33..24ee1851c9a916fea909b2ac4a842b31d677e61c 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td @@ -10,7 +10,9 @@ #define PTR_ATTRDEFS include "mlir/Dialect/Ptr/IR/PtrDialect.td" +include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td" include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" // All of the attributes will extend this class. class Ptr_Attr + ]> { + let summary = "Generic memory space"; + let description = [{ + The `generic_space` attribute defines a memory space attribute with the + following properties: + - Load and store operations are always valid, regardless of the type. + - Atomic operations are always valid, regardless of the type. + - Cast operations to `generic_space` are always valid. + + Example: + + ```mlir + #ptr.generic_space + ``` + }]; + let assemblyFormat = ""; +} + + //===----------------------------------------------------------------------===// // SpecAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h index 72e767764d98b2951b013a891eec373738ad25c7..dc0a3ffd4ae33ddf483b1713a4cb32eb6698fcf3 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h @@ -13,9 +13,16 @@ #ifndef MLIR_DIALECT_PTR_IR_PTRATTRS_H #define MLIR_DIALECT_PTR_IR_PTRATTRS_H +#include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "llvm/Support/TypeSize.h" + +#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc" +#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.h.inc" + #endif // MLIR_DIALECT_PTR_IR_PTRATTRS_H diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td index 14d72c3001d919d44a8b1e25d5b9f20f2f47a2ac..cd87d7474468ea07414f505a3a16f3a0e79968d6 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td @@ -37,6 +37,7 @@ class Ptr_Type traits = []> def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ MemRefElementTypeInterface, + PtrLikeTypeInterface, DeclareTypeInterfaceMethods ]> { @@ -53,16 +54,67 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ memory-space ::= attribute-value ``` }]; - let parameters = (ins OptionalParameter<"Attribute">:$memorySpace); - let assemblyFormat = "(`<` $memorySpace^ `>`)?"; + let parameters = (ins "MemorySpaceAttrInterface":$memorySpace); + let assemblyFormat = "`<` $memorySpace `>`"; let builders = [ - TypeBuilder<(ins CArg<"Attribute", "nullptr">:$memorySpace), [{ - return $_get($_ctxt, memorySpace); + TypeBuilderWithInferredContext<(ins + "MemorySpaceAttrInterface":$memorySpace), [{ + return $_get(memorySpace.getContext(), memorySpace); }]> ]; - let skipDefaultBuilders = 1; + let extraClassDeclaration = [{ + // `PtrLikeTypeInterface` interface methods. + /// Returns `Type()` as this pointer type is opaque. + Type getElementType() const { + return Type(); + } + /// Clones the pointer with specified memory space or returns failure + /// if an `elementType` was specified or if the memory space doesn't + /// implement `MemorySpaceAttrInterface`. + FailureOr clonePtrWith(Attribute memorySpace, + std::optional elementType) const { + if (elementType) + return failure(); + if (auto ms = memorySpace.dyn_cast()) + return llvm::cast(get(ms)); + return failure(); + } + /// `!ptr.ptr` types are seen as ptr-like objects with no metadata. + bool hasPtrMetadata() const { + return false; + } + }]; } +def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> { + let summary = "Pointer metadata type"; + let description = [{ + The `ptr_metadata` type represents an opaque-view of the metadata associated + with a `ptr-like` object type. + + Note: It's a verification error to construct a `ptr_metadata` type using a + `ptr-like` type with no metadata. + + Example: + + ```mlir + // The metadata associated with a `memref` type. + !ptr.ptr_metadata> + ``` + }]; + let parameters = (ins "PtrLikeTypeInterface":$type); + let assemblyFormat = "`<` $type `>`"; + let builders = [ + TypeBuilderWithInferredContext<(ins + "PtrLikeTypeInterface":$ptrLike), [{ + return $_get(ptrLike.getContext(), ptrLike); + }]> + ]; + let genVerifyDecl = 1; +} + + + //===----------------------------------------------------------------------===// // Base address operation definition. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td new file mode 100644 index 0000000000000000000000000000000000000000..472891dca5cde021aff4a4ea3c1e0b77be5faccd --- /dev/null +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td @@ -0,0 +1,80 @@ +//===-- PtrEnums.td - Ptr dialect enumerations -------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PTR_ENUMS +#define PTR_ENUMS + +include "mlir/IR/EnumAttr.td" + +//===----------------------------------------------------------------------===// +// Atomic binary op enum attribute. +//===----------------------------------------------------------------------===// + +def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0, "xchg">; +def AtomicBinOpAdd : I64EnumAttrCase<"add", 1, "add">; +def AtomicBinOpSub : I64EnumAttrCase<"sub", 2, "sub">; +def AtomicBinOpAnd : I64EnumAttrCase<"_and", 3, "_and">; +def AtomicBinOpNand : I64EnumAttrCase<"nand", 4, "nand">; +def AtomicBinOpOr : I64EnumAttrCase<"_or", 5, "_or">; +def AtomicBinOpXor : I64EnumAttrCase<"_xor", 6, "_xor">; +def AtomicBinOpMax : I64EnumAttrCase<"max", 7, "max">; +def AtomicBinOpMin : I64EnumAttrCase<"min", 8, "min">; +def AtomicBinOpUMax : I64EnumAttrCase<"umax", 9, "umax">; +def AtomicBinOpUMin : I64EnumAttrCase<"umin", 10, "umin">; +def AtomicBinOpFAdd : I64EnumAttrCase<"fadd", 11, "fadd">; +def AtomicBinOpFSub : I64EnumAttrCase<"fsub", 12, "fsub">; +def AtomicBinOpFMax : I64EnumAttrCase<"fmax", 13, "fmax">; +def AtomicBinOpFMin : I64EnumAttrCase<"fmin", 14, "fmin">; +def AtomicBinOpUIncWrap : I64EnumAttrCase<"uinc_wrap", 15, "uinc_wrap">; +def AtomicBinOpUDecWrap : I64EnumAttrCase<"udec_wrap", 16, "udec_wrap">; + +def AtomicBinOp : I64EnumAttr< + "AtomicBinOp", + "ptr.atomicrmw binary operations", + [AtomicBinOpXchg, AtomicBinOpAdd, AtomicBinOpSub, AtomicBinOpAnd, + AtomicBinOpNand, AtomicBinOpOr, AtomicBinOpXor, AtomicBinOpMax, + AtomicBinOpMin, AtomicBinOpUMax, AtomicBinOpUMin, AtomicBinOpFAdd, + AtomicBinOpFSub, AtomicBinOpFMax, AtomicBinOpFMin, AtomicBinOpUIncWrap, + AtomicBinOpUDecWrap]> { + let cppNamespace = "::mlir::ptr"; +} + +//===----------------------------------------------------------------------===// +// Atomic ordering enum attribute. +//===----------------------------------------------------------------------===// + +def AtomicOrderingNotAtomic : I64EnumAttrCase<"not_atomic", 0, "not_atomic">; +def AtomicOrderingUnordered : I64EnumAttrCase<"unordered", 1, "unordered">; +def AtomicOrderingMonotonic : I64EnumAttrCase<"monotonic", 2, "monotonic">; +def AtomicOrderingAcquire : I64EnumAttrCase<"acquire", 3, "acquire">; +def AtomicOrderingRelease : I64EnumAttrCase<"release", 4, "release">; +def AtomicOrderingAcqRel : I64EnumAttrCase<"acq_rel", 5, "acq_rel">; +def AtomicOrderingSeqCst : I64EnumAttrCase<"seq_cst", 6, "seq_cst">; + +def AtomicOrdering : I64EnumAttr< + "AtomicOrdering", + "Atomic ordering for LLVM's memory model", + [AtomicOrderingNotAtomic, AtomicOrderingUnordered, AtomicOrderingMonotonic, + AtomicOrderingAcquire, AtomicOrderingRelease, AtomicOrderingAcqRel, + AtomicOrderingSeqCst + ]> { + let cppNamespace = "::mlir::ptr"; +} + +//===----------------------------------------------------------------------===// +// Ptr add flags enum properties. +//===----------------------------------------------------------------------===// + +def Ptr_PtrAddFlags : I32EnumAttr<"PtrAddFlags", "Pointer add flags", [ + I32EnumAttrCase<"none", 0>, I32EnumAttrCase<"nusw", 1>, I32EnumAttrCase<"nuw", 2>, + I32EnumAttrCase<"inbounds", 3> + ]> { + let cppNamespace = "::mlir::ptr"; +} + +#endif // PTR_ENUMS \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h index 6a0c1429c6be923ba52a7a2c4f9dd021ee673133..8686cc7d316d4a0c7117b97af2cd42849b7f229b 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h @@ -18,6 +18,8 @@ #include "mlir/Dialect/Ptr/IR/PtrDialect.h" #include "mlir/Dialect/Ptr/IR/PtrTypes.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #define GET_OP_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index c63a0b220e501cf8cb3730935673767829aeccb1..313c9f8eb09ac5b0937517a1e1288e4d0c3f55ea 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -11,6 +11,7 @@ include "mlir/Dialect/Ptr/IR/PtrDialect.td" include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td" +include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td" include "mlir/IR/OpAsmInterface.td" #endif // PTR_OPS diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h index 264a97c80722a2fd928ecf8740e5258b7280deea..4fe1b5a1aa42304f908f320bb53baae5b77eb464 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_PTR_IR_PTRTYPES_H #define MLIR_DIALECT_PTR_IR_PTRTYPES_H +#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index b946fc8875860b9245011d8141ca8f3cb2063bc0..5caa0932c73eb579e8dec1deec1689761f1f463e 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -1379,4 +1379,57 @@ def YieldOp : TransformDialectOp<"yield", ]; } +def LowerToArmSMEOp : TransformDialectOp<"lower_to_arm_sme", + [FunctionalStyleTransformOpTrait, + DeclareOpInterfaceMethods, + TransformOpInterface, TransformEachOpTrait]> { + let description = [{Apply a list of passes to lower supported ops to + legalized arm_sme dialect ops and types.}]; + + let arguments = + (ins TransformHandleTypeInterface:$target, + DefaultValuedAttr:$fuse_outer_products + ); + + let assemblyFormat = "$target attr-dict `:` type($target)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::ModuleOp target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def LowerToLLVMNewOp : TransformDialectOp<"lower_to_llvm_new", + [FunctionalStyleTransformOpTrait, + DeclareOpInterfaceMethods, + TransformOpInterface, TransformEachOpTrait]> { + let description = [{Indicates that the entire module should be converted + to the LLVM dialect. This is expected to be the last transformation in + a sequence.}]; + + let arguments = + (ins TransformHandleTypeInterface:$target, + DefaultValuedAttr:$reassociate_fp_reductions, + DefaultValuedAttr:$enable_index_optimizations, + DefaultValuedAttr:$enable_arm_neon, + DefaultValuedAttr:$enable_arm_sve, + DefaultValuedAttr:$enable_amx, + DefaultValuedAttr:$enable_x86vector, + DefaultValuedAttr:$enable_async, + DefaultValuedAttr:$vscale_range); + + let assemblyFormat = "$target attr-dict `:` type($target)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::ModuleOp target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 820a18731ffdb03c33ac46262500971e3c554c6f..4a8f38c1269820624211077c0ec98970475d36eb 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -223,7 +223,12 @@ def ApplyLowerOuterProductPatternsOp : Op:$isSVE + ); + + let assemblyFormat = [{ + (`enableSVE` `=` $isSVE^)? attr-dict + }]; } def ApplyLowerGatherPatternsOp : Op { }]; } +//===----------------------------------------------------------------------===// +// PtrLikeTypeInterface +//===----------------------------------------------------------------------===// + +def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + A ptr-like type represents an object storing a memory address. This object + is constituted by: + - A memory address called the base pointer. This pointer is treated as a + bag of bits without any assumed structure. The bit-width of the base + pointer must be a compile-time constant. However, the bit-width may remain + opaque or unavailable during transformations that do not depend on the + base pointer. Finally, it is considered indivisible in the sense that as + a `PtrLikeTypeInterface` value, it has no metadata. + - Optional metadata about the pointer. For example, the size of the memory + region associated with the pointer. + + Furthermore, all ptr-like types have two properties: + - The memory space associated with the address held by the pointer. + - An optional element type. If the element type is not specified, the + pointer is considered opaque. + }]; + let methods = [ + InterfaceMethod<[{ + Returns the memory space of this ptr-like type. + }], + "::mlir::Attribute", "getMemorySpace">, + InterfaceMethod<[{ + Returns the element type of this ptr-like type. Note: this method can + return `::mlir::Type()`, in which case the pointer is considered opaque. + }], + "::mlir::Type", "getElementType">, + InterfaceMethod<[{ + Returns whether this ptr-like type has non-empty metadata. + }], + "bool", "hasPtrMetadata">, + InterfaceMethod<[{ + Returns a clone of this type with the given memory space and element type, + or `failure` if the type cannot be cloned with the specified arguments. + If the pointer is opaque and `elementType` is not `std::nullopt` the + method will return `failure`. + + If no `elementType` is provided and ptr is not opaque, the `elementType` + of this type is used. + }], + "::llvm::FailureOr<::mlir::PtrLikeTypeInterface>", "clonePtrWith", (ins + "::mlir::Attribute":$memorySpace, + "::std::optional<::mlir::Type>":$elementType + )> + ]; +} + //===----------------------------------------------------------------------===// // ShapedType //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 4250be90ba7fb0bb9558af5a7f6a133d9bf305a8..217bfd9eb4f4b9f3e3f9c04f94539a834e6df27b 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -141,7 +141,9 @@ public: /// Note: This class attaches the ShapedType trait to act as a mixin to /// provide many useful utility functions. This inheritance has no effect /// on derived memref types. -class BaseMemRefType : public Type, public ShapedType::Trait { +class BaseMemRefType : public Type, + public PtrLikeTypeInterface::Trait, + public ShapedType::Trait { public: using Type::Type; @@ -158,6 +160,13 @@ public: /// provided shape is `std::nullopt`, the current shape of the type is used. BaseMemRefType cloneWith(std::optional> shape, Type elementType) const; + + /// Clone this type with the given memory space and element type. If the + /// provided element type is `std::nullopt`, the current element type of the + /// type is used. + FailureOr + clonePtrWith(Attribute memorySpace, std::optional elementType) const; + // Make sure that base class overloads are visible. using ShapedType::Trait::clone; @@ -183,8 +192,16 @@ public: /// New `Attribute getMemorySpace()` method should be used instead. unsigned getMemorySpaceAsInt() const; + /// Returns that this ptr-like object has non-empty ptr metadata. + bool hasPtrMetadata() const { return true; } + /// Allow implicit conversion to ShapedType. operator ShapedType() const { return llvm::cast(*this); } + + /// Allow implicit conversion to PtrLikeTypeInterface. + operator PtrLikeTypeInterface() const { + return llvm::cast(*this); + } }; } // namespace mlir diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 0b3532dcc7d4f18660ed5b8e2e1c7eabc8a9603c..387e037f9eee788741a8770856076b09cf2742e4 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -423,6 +423,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> { //===----------------------------------------------------------------------===// def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ + PtrLikeTypeInterface, ShapedTypeInterface ], "BaseMemRefType"> { let summary = "Shaped reference to a region of memory"; @@ -951,6 +952,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> { //===----------------------------------------------------------------------===// def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [ + PtrLikeTypeInterface, ShapedTypeInterface ], "BaseMemRefType"> { let summary = "Shaped reference, with unknown rank, to a region of memory"; diff --git a/mlir/include/mlir/IR/ODSSupport.h b/mlir/include/mlir/IR/ODSSupport.h index 70e3f986431e26676a594aa0c408f24801059842..cafb0b58a759130e1ad16ef1a412ac410bc5d771 100644 --- a/mlir/include/mlir/IR/ODSSupport.h +++ b/mlir/include/mlir/IR/ODSSupport.h @@ -33,6 +33,37 @@ convertFromAttribute(int64_t &storage, Attribute attr, /// Convert the provided int64_t to an IntegerAttr attribute. Attribute convertToAttribute(MLIRContext *ctx, int64_t storage); +/// Convert an IntegerAttr attribute to an int32_t, or return an error if the +/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an +/// error message is also emitted. +LogicalResult +convertFromAttribute(int32_t &storage, Attribute attr, + function_ref emitError); + +/// Convert the provided int32_t to an IntegerAttr attribute. +Attribute convertToAttribute(MLIRContext *ctx, int32_t storage); + +/// Extract the string from `attr` into `storage`. If `attr` is not a +/// `StringAttr`, return failure and emit an error into the diagnostic from +/// `emitError`. +LogicalResult +convertFromAttribute(std::string &storage, Attribute attr, + function_ref emitError); + +/// Convert the given string into a StringAttr. Note that this takes a reference +/// to the storage of a string property, which is an std::string. +Attribute convertToAttribute(MLIRContext *ctx, const std::string &storage); + +/// Extract the boolean from `attr` into `storage`. If `attr` is not a +/// `BoolAttr`, return failure and emit an error into the diagnostic from +/// `emitError`. +LogicalResult +convertFromAttribute(bool &storage, Attribute attr, + function_ref emitError); + +/// Convert the given string into a BooleanAttr. +Attribute convertToAttribute(MLIRContext *ctx, bool storage); + /// Convert a DenseI64ArrayAttr to the provided storage. It is expected that the /// storage has the same size as the array. An error is returned if the /// attribute isn't a DenseI64ArrayAttr or it does not have the same size. If @@ -49,6 +80,21 @@ LogicalResult convertFromAttribute(MutableArrayRef storage, Attribute attr, function_ref emitError); +/// Convert a DenseI64ArrayAttr to the provided storage, which will be +/// cleared before writing. An error is returned and emitted to the optional +/// `emitError` function if the attribute isn't a DenseI64ArrayAttr. +LogicalResult +convertFromAttribute(SmallVectorImpl &storage, Attribute attr, + function_ref emitError); + +/// Convert a DenseI32ArrayAttr to the provided storage, which will be +/// cleared before writing. It is expected that the storage has the same size as +/// the array. An error is returned and emitted to the optional `emitError` +/// function if the attribute isn't a DenseI32ArrayAttr. +LogicalResult +convertFromAttribute(SmallVectorImpl &storage, Attribute attr, + function_ref emitError); + /// Convert the provided ArrayRef to a DenseI64ArrayAttr attribute. Attribute convertToAttribute(MLIRContext *ctx, ArrayRef storage); diff --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td index 0babdbbfa05bc2d8e6c5443cd4d90854dcf0a580..f55a5ab96f77fc4dcf5de96fb5c7e8ede2686611 100644 --- a/mlir/include/mlir/IR/Properties.td +++ b/mlir/include/mlir/IR/Properties.td @@ -29,7 +29,6 @@ class Property { // // Format: // - `$_storage` will contain the property in the storage type. - // - `$_ctxt` will contain an `MLIRContext *`. code convertFromStorage = "$_storage"; // The call expression to build a property storage from the interface type. @@ -40,24 +39,26 @@ class Property { code assignToStorage = "$_storage = $_value"; // The call expression to convert from the storage type to an attribute. + // The resulting attribute must be non-null in non-error cases. // // Format: // - `$_storage` is the storage type value. // - `$_ctxt` is a `MLIRContext *`. // - // The expression must result in an Attribute. + // The expression must return an `Attribute` and will be used as a function body. code convertToAttribute = [{ - convertToAttribute($_ctxt, $_storage) + return convertToAttribute($_ctxt, $_storage); }]; // The call expression to convert from an Attribute to the storage type. // // Format: - // - `$_storage` is the storage type value. + // - `$_storage` is a reference to a value of the storage type. // - `$_attr` is the attribute. // - `$_diag` is a callback to get a Diagnostic to emit error. // - // The expression must return a LogicalResult + // The expression must return a LogicalResult and will be used as a function body + // or in other similar contexts. code convertFromAttribute = [{ return convertFromAttribute($_storage, $_attr, $_diag); }]; @@ -68,18 +69,68 @@ class Property { // - `$_storage` is the variable to hash. // // The expression should define a llvm::hash_code. - code hashProperty = [{ - llvm::hash_value($_storage); + // If unspecified, defaults to `llvm::hash_value($_storage)`. + // The default is not specified in tablegen because many combinators, like + // ArrayProperty, can fall back to more efficient implementations of + // `hashProperty` when their underlying elements have trivial hashing. + code hashProperty = ""; + + // The body of the parser for a value of this property. + // Format: + // - `$_parser` is the OpAsmParser. + // - `$_storage` is the location into which the value is to be placed if it is + // present. + // - `$_ctxt` is a `MLIRContext *` + // + // This defines the body of a function (typically a lambda) that returns a + // ParseResult. There is an implicit `return success()` at the end of the parser + // code. + // + // When this code executes, `$_storage` will be initialized to the property's + // default value (if any, accounting for the storage type override). + code parser = [{ + auto value = ::mlir::FieldParser<}] # storageType # [{>::parse($_parser); + if (::mlir::failed(value)) + return ::mlir::failure(); + $_storage = std::move(*value); }]; + // The body of the parser for a value of this property as the anchor of an optional + // group. This should parse the property if possible and do nothing if a value of + // the relevant type is not next in the parse stream. + // You are not required to define this parser if it cannot be meaningfully + // implemented. + // This has the same context and substitutions as `parser` except that it is + // required to return an OptionalParseResult. + // + // If the optional parser doesn't parse anything, it should not set + // $_storage, since the parser doesn't know if the default value has been + // overwritten. + code optionalParser = ""; + + // The printer for a value of this property. + // Format: + // - `$_storage` is the storage data. + // - `$_printer` is the OpAsmPrinter instance. + // - `$_ctxt` is a `MLIRContext *` + // + // This may be called in an expression context, so variable declarations must + // be placed within a new scope. + // + // The printer for a property should always print a non-empty value - default value + // printing elision happens outside the context of this printing expression. + code printer = "$_printer << $_storage"; + // The call expression to emit the storage type to bytecode. // // Format: // - `$_storage` is the storage type value. // - `$_writer` is a `DialectBytecodeWriter`. // - `$_ctxt` is a `MLIRContext *`. + // + // This will become the body af a function returning void. code writeToMlirBytecode = [{ - writeToMlirBytecode($_writer, $_storage) + writeToMlirBytecode($_writer, $_storage); }]; // The call expression to read the storage type from bytecode. @@ -88,13 +139,31 @@ class Property { // - `$_storage` is the storage type value. // - `$_reader` is a `DialectBytecodeReader`. // - `$_ctxt` is a `MLIRContext *`. + // + // This will become the body of a function returning LogicalResult. + // There is an implicit `return success()` at the end of this function. + // + // When this code executes, `$_storage` will be initialized to the property's + // default value (if any, accounting for the storage type override). code readFromMlirBytecode = [{ if (::mlir::failed(readFromMlirBytecode($_reader, $_storage))) return ::mlir::failure(); }]; - // Default value for the property. - string defaultValue = ?; + // Base definition for the property. (Will be) used for `OptionalProperty` and + // such cases, analogously to `baseAttr`. + Property baseProperty = ?; + + // Default value for the property within its storage. This should be an expression + // of type `interfaceType` and should be comparable with other types of that + // interface typ with `==`. The empty string means there is no default value. + string defaultValue = ""; + + // If set, the default value the storage of the property should be initilized to. + // This is only needed when the storage and interface types of the property + // are distinct (ex. SmallVector for storage vs. ArrayRef for interfacing), as it + // will fall back to `defaultValue` when unspecified. + string storageTypeValueOverride = ""; } /// Implementation of the Property class's `readFromMlirBytecode` field using @@ -133,12 +202,16 @@ defvar writeMlirBytecodeWithConvertToAttribute = [{ // Primitive property kinds // Any kind of integer stored as properties. -class IntProperty : +class IntProperty : Property { - code writeToMlirBytecode = [{ + let summary = !if(!empty(desc), storageTypeParam, desc); + let optionalParser = [{ + return $_parser.parseOptionalInteger($_storage); + }]; + let writeToMlirBytecode = [{ $_writer.writeVarInt($_storage); }]; - code readFromMlirBytecode = [{ + let readFromMlirBytecode = [{ uint64_t val; if (failed($_reader.readVarInt(val))) return ::mlir::failure(); @@ -146,24 +219,472 @@ class IntProperty : }]; } -class ArrayProperty : - Property { - let interfaceType = "::llvm::ArrayRef<" # storageTypeParam # ">"; - let convertFromStorage = "$_storage"; - let assignToStorage = "::llvm::copy($_value, $_storage)"; -} +def I32Property : IntProperty<"int32_t">; +def I64Property : IntProperty<"int64_t">; -class EnumProperty : +class EnumProperty : Property { - code writeToMlirBytecode = [{ + // TODO: take advantage of EnumAttrInfo and the like to make this share nice + // parsing code with EnumAttr. + let writeToMlirBytecode = [{ $_writer.writeVarInt(static_cast($_storage)); }]; - code readFromMlirBytecode = [{ + let readFromMlirBytecode = [{ uint64_t val; if (failed($_reader.readVarInt(val))) return ::mlir::failure(); $_storage = static_cast<}] # storageTypeParam # [{>(val); }]; + let defaultValue = default; } -#endif // PROPERTIES +def StringProperty : Property<"std::string", "string"> { + let interfaceType = "::llvm::StringRef"; + let convertFromStorage = "::llvm::StringRef{$_storage}"; + let assignToStorage = "$_storage = $_value.str()"; + let optionalParser = [{ + if (::mlir::failed($_parser.parseOptionalString(&$_storage))) + return std::nullopt; + }]; + let printer = "$_printer.printString($_storage)"; + let readFromMlirBytecode = [{ + StringRef val; + if (::mlir::failed($_reader.readString(val))) + return ::mlir::failure(); + $_storage = val.str(); + }]; + let writeToMlirBytecode = [{ + $_writer.writeOwnedString($_storage); + }]; +} + +def BoolProperty : IntProperty<"bool", "boolean"> { + let printer = [{ $_printer << ($_storage ? "true" : "false") }]; + let readFromMlirBytecode = [{ + return $_reader.readBool($_storage); + }]; + let writeToMlirBytecode = [{ + $_writer.writeOwnedBool($_storage); + }]; +} + +def UnitProperty : Property<"bool", "unit property"> { + let summary = "unit property"; + let description = [{ + A property whose presence or abscence is used as a flag. + + This is stored as a boolean that defaults to false, and is named UnitProperty + by analogy with UnitAttr, which has the more comprehensive rationale and + explains the less typical syntax. + + Note that this attribute does have a syntax for the false case to allow for its + use in contexts where default values shouldn't be elided. + }]; + let defaultValue = "false"; + + let convertToAttribute = [{ + if ($_storage) + return ::mlir::UnitAttr::get($_ctxt); + else + return ::mlir::BoolAttr::get($_ctxt, false); + }]; + let convertFromAttribute = [{ + if (::llvm::isa<::mlir::UnitAttr>($_attr)) { + $_storage = true; + return ::mlir::success(); + } + if (auto boolAttr = ::llvm::dyn_cast<::mlir::BoolAttr>($_attr)) { + $_storage = boolAttr.getValue(); + return ::mlir::success(); + } + return ::mlir::failure(); + }]; + + let parser = [{ + ::llvm::StringRef keyword; + if (::mlir::failed($_parser.parseOptionalKeyword(&keyword, + {"unit", "unit_absent"}))) + return $_parser.emitError($_parser.getCurrentLocation(), + "expected 'unit' or 'unit_absent'"); + $_storage = (keyword == "unit"); + }]; + + let optionalParser = [{ + ::llvm::StringRef keyword; + if (::mlir::failed($_parser.parseOptionalKeyword(&keyword, + {"unit", "unit_absent"}))) + return std::nullopt; + $_storage = (keyword == "unit"); + }]; + + let printer = [{ + $_printer << ($_storage ? "unit" : "unit_absent") + }]; + + let writeToMlirBytecode = [{ + $_writer.writeOwnedBool($_storage); + }]; + let readFromMlirBytecode = [{ + if (::mlir::failed($_reader.readBool($_storage))) + return ::mlir::failure(); + }]; +} + +//===----------------------------------------------------------------------===// +// Primitive property combinators + +/// Create a variable named `name` of `prop`'s storage type that is initialized +/// to the correct default value, if there is one. +class _makePropStorage { + code ret = prop.storageType # " " # name + # !cond(!not(!empty(prop.storageTypeValueOverride)) : " = " # prop.storageTypeValueOverride, + !not(!empty(prop.defaultValue)) : " = " # prop.defaultValue, + true : "") # ";"; +} + +/// The generic class for arrays of some other property, which is stored as a +/// `SmallVector` of that property. This uses an `ArrayAttr` as its attribute form +/// though subclasses can override this, as is the case with IntArrayAttr below. +/// Those wishing to use a non-default number of SmallVector elements should +/// subclass `ArrayProperty`. +class ArrayProperty, string desc = ""> : + Property<"::llvm::SmallVector<" # elem.storageType # ">", desc> { + let summary = "array of " # elem.summary; + let interfaceType = "::llvm::ArrayRef<" # elem.storageType # ">"; + let convertFromStorage = "::llvm::ArrayRef<" # elem.storageType # ">{$_storage}"; + let assignToStorage = "$_storage.assign($_value.begin(), $_value.end())"; + + let convertFromAttribute = [{ + auto arrayAttr = ::llvm::dyn_cast_if_present<::mlir::ArrayAttr>($_attr); + if (!arrayAttr) + return $_diag() << "expected array attribute"; + for (::mlir::Attribute elemAttr : arrayAttr) { + }] # _makePropStorage.ret # [{ + auto elemRes = [&](Attribute propAttr, }] # elem.storageType # [{& propStorage) -> ::mlir::LogicalResult { + }] # !subst("$_attr", "propAttr", + !subst("$_storage", "propStorage", elem.convertFromAttribute)) # [{ + }(elemAttr, elemVal); + if (::mlir::failed(elemRes)) + return ::mlir::failure(); + $_storage.push_back(std::move(elemVal)); + } + return ::mlir::success(); + }]; + + let convertToAttribute = [{ + SmallVector elems; + for (const auto& elemVal : $_storage) { + auto elemAttr = [&](const }] # elem.storageType #[{& propStorage) -> ::mlir::Attribute { + }] # !subst("$_storage", "propStorage", elem.convertToAttribute) # [{ + }(elemVal); + elems.push_back(elemAttr); + } + return ::mlir::ArrayAttr::get($_ctxt, elems); + }]; + + defvar theParserBegin = [{ + auto& storage = $_storage; + auto parseElemFn = [&]() -> ::mlir::ParseResult { + }] # _makePropStorage.ret # [{ + auto elemParse = [&](}] # elem.storageType # [{& propStorage) -> ::mlir::ParseResult { + }] # !subst("$_storage", "propStorage", elem.parser) # [{ + return ::mlir::success(); + }(elemVal); + if (::mlir::failed(elemParse)) + return ::mlir::failure(); + storage.push_back(std::move(elemVal)); + return ::mlir::success(); + }; + }]; + let parser = theParserBegin # [{ + return $_parser.parseCommaSeparatedList( + ::mlir::OpAsmParser::Delimiter::Square, parseElemFn); + }]; + // Hack around the lack of a peek method + let optionalParser = theParserBegin # [{ + auto oldLoc = $_parser.getCurrentLocation(); + auto parseResult = $_parser.parseCommaSeparatedList( + ::mlir::OpAsmParser::Delimiter::OptionalSquare, parseElemFn); + if (::mlir::failed(parseResult)) + return ::mlir::failure(); + auto newLoc = $_parser.getCurrentLocation(); + if (oldLoc == newLoc) + return std::nullopt; + return ::mlir::success(); + }]; + + let printer = [{ [&](){ + $_printer << "["; + auto elemPrinter = [&](const }] # elem.storageType # [{& elemVal) { + }] # !subst("$_storage", "elemVal", elem.printer) #[{; + }; + ::llvm::interleaveComma($_storage, $_printer, elemPrinter); + $_printer << "]"; + }()}]; + + let readFromMlirBytecode = [{ + uint64_t length; + if (::mlir::failed($_reader.readVarInt(length))) + return ::mlir::failure(); + $_storage.reserve(length); + for (uint64_t i = 0; i < length; ++i) { + }]# _makePropStorage.ret # [{ + auto elemRead = [&](}] # elem.storageType # [{& propStorage) -> ::mlir::LogicalResult { + }] # !subst("$_storage", "propStorage", elem.readFromMlirBytecode) # [{; + return ::mlir::success(); + }(elemVal); + if (::mlir::failed(elemRead)) + return ::mlir::failure(); + $_storage.push_back(std::move(elemVal)); + } + }]; + + let writeToMlirBytecode = [{ + $_writer.writeVarInt($_storage.size()); + for (const auto& elemVal : $_storage) { + [&]() { + }] # !subst("$_storage", "elemVal", elem.writeToMlirBytecode) #[{; + }(); + } + }]; + + // There's no hash_value for SmallVector, so we construct the ArrayRef ourselves. + // In the non-trivial case, we define a mapped range to get internal hash + // codes. + let hashProperty = !if(!empty(elem.hashProperty), + [{::llvm::hash_value(::llvm::ArrayRef<}] # elem.storageType # [{>{$_storage})}], + [{[&]() -> ::llvm::hash_code { + auto getElemHash = [](const auto& propStorage) -> ::llvm::hash_code { + return }] # !subst("$_storage", "propStorage", elem.hashProperty) # [{; + }; + auto mapped = ::llvm::map_range($_storage, getElemHash); + return ::llvm::hash_combine_range(mapped.begin(), mapped.end()); + }() + }]); +} + +class IntArrayProperty : + ArrayProperty> { + // Bring back the trivial conversions we don't get in the general case. + let convertFromAttribute = [{ + return convertFromAttribute($_storage, $_attr, $_diag); + }]; + let convertToAttribute = [{ + return convertToAttribute($_ctxt, $_storage); + }]; +} + +/// Class for giving a property a default value. +/// This doesn't change anything about the property other than giving it a default +/// which can be used by ODS to elide printing. +class DefaultValuedProperty : Property { + let defaultValue = default; + let storageTypeValueOverride = storageDefault; + let baseProperty = p; + // Keep up to date with `Property` above. + let summary = p.summary; + let description = p.description; + let storageType = p.storageType; + let interfaceType = p.interfaceType; + let convertFromStorage = p.convertFromStorage; + let assignToStorage = p.assignToStorage; + let convertToAttribute = p.convertToAttribute; + let convertFromAttribute = p.convertFromAttribute; + let hashProperty = p.hashProperty; + let parser = p.parser; + let optionalParser = p.optionalParser; + let printer = p.printer; + let readFromMlirBytecode = p.readFromMlirBytecode; + let writeToMlirBytecode = p.writeToMlirBytecode; +} + +/// An optional property, stored as an std::optional +/// interfaced with as an std::optional.. +/// The syntax is `none` (or empty string if elided) for an absent value or +/// `some<[underlying property]>` when a value is set. +/// +/// As a special exception, if the underlying property has an optional parser and +/// no default value (ex. an integer property), the printer will skip the `some` +/// bracketing and delegate to the optional parser. In that case, the syntax is the +/// syntax of the underlying property, or the keyword `none` in the rare cases that +/// it is needed. This behavior can be disabled by setting `canDelegateParsing` to 0. +class OptionalProperty + : Property<"std::optional<" # p.storageType # ">", "optional " # p.summary> { + + // In the cases where the underlying attribute is plain old data that's passed by + // value, the conversion code is trivial. + defvar hasTrivialStorage = !and(!eq(p.convertFromStorage, "$_storage"), + !eq(p.assignToStorage, "$_storage = $_value"), + !eq(p.storageType, p.interfaceType)); + + defvar delegatesParsing = !and(!empty(p.defaultValue), + !not(!empty(p.optionalParser)), canDelegateParsing); + + let interfaceType = "std::optional<" # p.interfaceType # ">"; + let defaultValue = "std::nullopt"; + + let convertFromStorage = !if(hasTrivialStorage, + p.convertFromStorage, + [{($_storage.has_value() ? std::optional<}] # p.interfaceType # ">{" + # !subst("$_storage", "(*($_storage))", p.convertFromStorage) + # [{} : std::nullopt)}]); + let assignToStorage = !if(hasTrivialStorage, + p.assignToStorage, + [{[&]() { + if (!$_value.has_value()) { + $_storage = std::nullopt; + return; + } + }] # _makePropStorage.ret # [{ + [&](}] # p.storageType # [{& propStorage) { + }] # !subst("$_storage", "propStorage", + !subst("$_value", "(*($_value))", p.assignToStorage)) # [{; + }(presentVal); + $_storage = std::move(presentVal); + }()}]); + + let convertFromAttribute = [{ + auto arrayAttr = ::llvm::dyn_cast<::mlir::ArrayAttr>($_attr); + if (!arrayAttr) + return $_diag() << "expected optional properties to materialize as arrays"; + if (arrayAttr.size() > 1) + return $_diag() << "expected optional properties to become 0- or 1-element arrays"; + if (arrayAttr.empty()) { + $_storage = std::nullopt; + return ::mlir::success(); + } + ::mlir::Attribute presentAttr = arrayAttr[0]; + }] # _makePropStorage.ret # [{ + auto presentRes = [&](Attribute propAttr, }] # p.storageType # [{& propStorage) -> ::mlir::LogicalResult { + }] # !subst("$_storage", "propStorage", + !subst("$_attr", "propAttr", p.convertFromAttribute)) # [{ + }(presentAttr, presentVal); + if (::mlir::failed(presentRes)) + return ::mlir::failure(); + $_storage = std::move(presentVal); + return ::mlir::success(); + }]; + + let convertToAttribute = [{ + if (!$_storage.has_value()) { + return ::mlir::ArrayAttr::get($_ctxt, {}); + } + auto attr = [&]() -> ::mlir::Attribute { + }] # !subst("$_storage", "(*($_storage))", p.convertToAttribute) # [{ + }(); + return ::mlir::ArrayAttr::get($_ctxt, {attr}); + }]; + + defvar delegatedParserBegin = [{ + if (::mlir::succeeded($_parser.parseOptionalKeyword("none"))) { + $_storage = std::nullopt; + return ::mlir::success(); + } + }] #_makePropStorage.ret # [{ + auto delegParseResult = [&](}] # p.storageType # [{& propStorage) -> ::mlir::OptionalParseResult { + }] # !subst("$_storage", "propStorage", p.optionalParser) # [{ + return ::mlir::success(); + }(presentVal); + if (!delegParseResult.has_value()) { + }]; + + defvar delegatedParserEnd = [{ + } + if (delegParseResult.has_value() && ::mlir::failed(*delegParseResult)) + return ::mlir::failure(); + $_storage = std::move(presentVal); + return ::mlir::success(); + }]; + // If we're being explicitly called for our parser, we're expecting to have been + // printede into a context where the default value isn't elided. Therefore, + // not-present from the underlying parser is a failure. + defvar delegatedParser = delegatedParserBegin # [{ + return ::mlir::failure(); + }] # delegatedParserEnd; + defvar delegatedOptionalParser = delegatedParserBegin # [{ + return std::nullopt; + }] # delegatedParserEnd; + + defvar generalParserBegin = [{ + ::llvm::StringRef keyword; + if (::mlir::failed($_parser.parseOptionalKeyword(&keyword, {"none", "some"}))) { + }]; + defvar generalParserEnd = [{ + } + if (keyword == "none") { + $_storage = std::nullopt; + return ::mlir::success(); + } + if (::mlir::failed($_parser.parseLess())) + return ::mlir::failure(); + }] # _makePropStorage.ret # [{ + auto presentParse = [&](}] # p.storageType # [{& propStorage) -> ::mlir::ParseResult { + }] # !subst("$_storage", "propStorage", p.parser) # [{ + return ::mlir::success(); + }(presentVal); + if (presentParse || $_parser.parseGreater()) + return ::mlir::failure(); + $_storage = std::move(presentVal); + }]; + defvar generalParser = generalParserBegin # [{ + return $_parser.emitError($_parser.getCurrentLocation(), "expected 'none' or 'some'"); + }] # generalParserEnd; + defvar generalOptionalParser = generalParserBegin # [{ + return std::nullopt; + }] # generalParserEnd; + + let parser = !if(delegatesParsing, delegatedParser, generalParser); + let optionalParser = !if(delegatesParsing, + delegatedOptionalParser, generalOptionalParser); + + defvar delegatedPrinter = [{ + [&]() { + if (!$_storage.has_value()) { + $_printer << "none"; + return; + } + }] # !subst("$_storage", "(*($_storage))", p.printer) # [{; + }()}]; + defvar generalPrinter = [{ + [&]() { + if (!$_storage.has_value()) { + $_printer << "none"; + return; + } + $_printer << "some<"; + }] # !subst("$_storage", "(*($_storage))", p.printer) # [{; + $_printer << ">"; + }()}]; + let printer = !if(delegatesParsing, delegatedPrinter, generalPrinter); + + let readFromMlirBytecode = [{ + bool isPresent = false; + if (::mlir::failed($_reader.readBool(isPresent))) + return ::mlir::failure(); + if (!isPresent) { + $_storage = std::nullopt; + return ::mlir::success(); + } + }] # _makePropStorage.ret # [{ + auto presentResult = [&](}] # p.storageType # [{& propStorage) -> ::mlir::LogicalResult { + }] # !subst("$_storage", "propStorage", p.readFromMlirBytecode) # [{; + return ::mlir::success(); + }(presentVal); + if (::mlir::failed(presentResult)) + return ::mlir::failure(); + $_storage = std::move(presentVal); + }]; + let writeToMlirBytecode = [{ + $_writer.writeOwnedBool($_storage.has_value()); + if (!$_storage.has_value()) + return; + }] # !subst("$_storage", "(*($_storage))", p.writeToMlirBytecode); + + let hashProperty = !if(!empty(p.hashProperty), p.hashProperty, + [{ ::llvm::hash_value($_storage.has_value() ? std::optional<::llvm::hash_code>{}] # + !subst("$_storage", "(*($_storage))", p.hashProperty) #[{} : std::nullopt) }]); + assert !or(!not(delegatesParsing), !eq(defaultValue, "std::nullopt")), + "For delegated parsing to be used, the default value must be nullopt. " # + "To use a non-trivial default, set the canDelegateParsing argument to 0"; +} +#endif // PROPERTIES \ No newline at end of file diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 20a4ab6f18a286f454ef27cc0c95628628ec36e0..a9509a5ddabeaa579ae0686d83c83179ec9124a2 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -23,6 +23,7 @@ #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" @@ -35,6 +36,8 @@ #include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h" #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" #include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h" +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" +#include "mlir/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.h" #include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h" #include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" @@ -44,6 +47,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" + #include namespace mlir { @@ -56,6 +60,7 @@ namespace mlir { inline void registerAllExtensions(DialectRegistry ®istry) { // Register all conversions to LLVM extensions. arith::registerConvertArithToLLVMInterface(registry); + ptr::registerConvertPtrToLLVMInterface(registry); registerConvertComplexToLLVMInterface(registry); cf::registerConvertControlFlowToLLVMInterface(registry); func::registerAllExtensions(registry); @@ -82,6 +87,8 @@ inline void registerAllExtensions(DialectRegistry ®istry) { transform::registerLoopExtension(registry); transform::registerPDLExtension(registry); vector::registerTransformDialectExtension(registry); + arm_sve::registerTransformDialectExtension(registry); + arm_sme::registerTransformDialectExtension(registry); // Translation extensions need to be registered by calling // `registerAllToLLVMIRTranslations` (see All.h). diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h index ab65f92820a6a8ba05ebaa14ffec988a70fda097..e644301890090e428758be60527d282335a97209 100644 --- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h +++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h @@ -35,6 +35,8 @@ using DeviceIDTargetDeviceSpecPair = std::pair; using DeviceIDTargetDeviceSpecPairListRef = llvm::ArrayRef; +using DataLayoutIdentifiedEntryMap = + ::llvm::DenseMap<::mlir::StringAttr, ::mlir::DataLayoutEntryInterface>; class DataLayoutOpInterface; class DataLayoutSpecInterface; class ModuleOp; @@ -79,6 +81,10 @@ Attribute getDefaultEndianness(DataLayoutEntryInterface entry); /// DataLayoutInterface if specified, otherwise returns the default. Attribute getDefaultAllocaMemorySpace(DataLayoutEntryInterface entry); +/// Default handler for the default memory space request. Dispatches to the +/// DataLayoutInterface if specified, otherwise returns the default. +Attribute getDefaultMemorySpace(DataLayoutEntryInterface entry); + /// Default handler for program memory space request. Dispatches to the /// DataLayoutInterface if specified, otherwise returns the default. Attribute getDefaultProgramMemorySpace(DataLayoutEntryInterface entry); @@ -231,6 +237,9 @@ public: /// Returns the memory space used for AllocaOps. Attribute getAllocaMemorySpace() const; + /// Returns the default memory space used for memory operations. + Attribute getDefaultMemorySpace() const; + /// Returns the memory space used for program memory operations. Attribute getProgramMemorySpace() const; @@ -281,6 +290,7 @@ private: mutable std::optional allocaMemorySpace; mutable std::optional programMemorySpace; mutable std::optional globalMemorySpace; + mutable std::optional defaultMemorySpace; /// Cache for stack alignment. mutable std::optional stackAlignment; diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td index bc5080c9c6a55893e4b1ee9d0a7b5c63a676ced8..a286397c9a41d7d63b165182e49ea8af74d7b6a5 100644 --- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td +++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td @@ -136,6 +136,12 @@ def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> { /*methodName=*/"getStackAlignmentIdentifier", /*args=*/(ins "::mlir::MLIRContext *":$context) >, + InterfaceMethod< + /*description=*/"Returns the default memory space identifier.", + /*retTy=*/"::mlir::StringAttr", + /*methodName=*/"getDefaultMemorySpaceIdentifier", + /*args=*/(ins "::mlir::MLIRContext *":$context) + >, // Implementations may override this if they have an efficient lookup // mechanism. InterfaceMethod< @@ -465,6 +471,18 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> { return ::mlir::detail::getDefaultStackAlignment(entry); }] >, + StaticInterfaceMethod< + /*description=*/"Returns the memory space used by the ABI computed " + "using the relevant entries. The data layout object " + "can be used for recursive queries.", + /*retTy=*/"::mlir::Attribute", + /*methodName=*/"getDefaultMemorySpace", + /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::detail::getDefaultMemorySpace(entry); + }] + >, StaticInterfaceMethod< /*description=*/"Returns the value of the property, if the property is " "defined. Otherwise, it returns std::nullopt.", @@ -567,7 +585,9 @@ def DataLayoutTypeInterface : TypeInterface<"DataLayoutTypeInterface"> { /*retTy=*/"bool", /*methodName=*/"areCompatible", /*args=*/(ins "::mlir::DataLayoutEntryListRef":$oldLayout, - "::mlir::DataLayoutEntryListRef":$newLayout), + "::mlir::DataLayoutEntryListRef":$newLayout, + "::mlir::DataLayoutSpecInterface":$newSpec, + "const ::mlir::DataLayoutIdentifiedEntryMap&":$identified), /*methodBody=*/"", /*defaultImplementation=*/[{ return true; }] >, diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index cc5853c044e9753281f91b57da3d1dc7503ffbf0..768291a3a7267b221c229b72215667498f481b8e 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -384,7 +384,7 @@ private: SmallVector attributes; /// The properties of the op. - SmallVector properties; + SmallVector properties; /// The arguments of the op (operands and native attributes). SmallVector arguments; diff --git a/mlir/include/mlir/TableGen/Property.h b/mlir/include/mlir/TableGen/Property.h index d0d6f4940c7c0414884e4bb8571b12276046aae0..702e6756e6a95c65776a9ab517aa7593b79331c8 100644 --- a/mlir/include/mlir/TableGen/Property.h +++ b/mlir/include/mlir/TableGen/Property.h @@ -35,12 +35,20 @@ class Property { public: explicit Property(const llvm::Record *record); explicit Property(const llvm::DefInit *init); - Property(StringRef storageType, StringRef interfaceType, - StringRef convertFromStorageCall, StringRef assignToStorageCall, - StringRef convertToAttributeCall, StringRef convertFromAttributeCall, + Property(StringRef summary, StringRef description, StringRef storageType, + StringRef interfaceType, StringRef convertFromStorageCall, + StringRef assignToStorageCall, StringRef convertToAttributeCall, + StringRef convertFromAttributeCall, StringRef parserCall, + StringRef optionalParserCall, StringRef printerCall, StringRef readFromMlirBytecodeCall, StringRef writeToMlirBytecodeCall, StringRef hashPropertyCall, - StringRef defaultValue); + StringRef defaultValue, StringRef storageTypeValueOverride); + + // Returns the summary (for error messages) of this property's type. + StringRef getSummary() const { return summary; } + + // Returns the description of this property. + StringRef getDescription() const { return description; } // Returns the storage type. StringRef getStorageType() const { return storageType; } @@ -66,6 +74,19 @@ public: return convertFromAttributeCall; } + // Returns the method call which parses this property from textual MLIR. + StringRef getParserCall() const { return parserCall; } + + // Returns true if this property has defined an optional parser. + bool hasOptionalParser() const { return !optionalParserCall.empty(); } + + // Returns the method call which optionally parses this property from textual + // MLIR. + StringRef getOptionalParserCall() const { return optionalParserCall; } + + // Returns the method call which prints this property to textual MLIR. + StringRef getPrinterCall() const { return printerCall; } + // Returns the method call which reads this property from // bytecode and assign it to the storage. StringRef getReadFromMlirBytecodeCall() const { @@ -87,6 +108,24 @@ public: // Returns the default value for this Property. StringRef getDefaultValue() const { return defaultValue; } + // Returns whether this Property has a default storage-type value that is + // distinct from its default interface-type value. + bool hasStorageTypeValueOverride() const { + return !storageTypeValueOverride.empty(); + } + + StringRef getStorageTypeValueOverride() const { + return storageTypeValueOverride; + } + + // Returns this property's TableGen def-name. + StringRef getPropertyDefName() const; + + // Returns the base-level property that this Property constraint is based on + // or the Property itself otherwise. (Note: there are currently no + // property constraints, this function is added for future-proofing) + Property getBaseProperty() const; + // Returns the TableGen definition this Property was constructed from. const llvm::Record &getDef() const { return *def; } @@ -95,16 +134,22 @@ private: const llvm::Record *def; // Elements describing a Property, in general fetched from the record. + StringRef summary; + StringRef description; StringRef storageType; StringRef interfaceType; StringRef convertFromStorageCall; StringRef assignToStorageCall; StringRef convertToAttributeCall; StringRef convertFromAttributeCall; + StringRef parserCall; + StringRef optionalParserCall; + StringRef printerCall; StringRef readFromMlirBytecodeCall; StringRef writeToMlirBytecodeCall; StringRef hashPropertyCall; StringRef defaultValue; + StringRef storageTypeValueOverride; }; // A struct wrapping an op property and its name together diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 80c8b84d9ae89a577e10b5e3e774c30007d0f2c8..9dcbba9eda5b331f5433df33758a530e3b489678 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -50,6 +50,7 @@ add_subdirectory(ReconcileUnrealizedCasts) add_subdirectory(SCFToControlFlow) add_subdirectory(SCFToEmitC) add_subdirectory(SCFToGPU) +add_subdirectory(PtrToLLVM) add_subdirectory(SCFToOpenMP) add_subdirectory(SCFToSPIRV) add_subdirectory(ShapeToStandard) diff --git a/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d8c60d7ad0d19e13976cb3578f32707dfb6a07f0 --- /dev/null +++ b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_conversion_library(MLIRPtrToLLVM + PtrToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/PtrToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRPtrDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + ) \ No newline at end of file diff --git a/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp new file mode 100644 index 0000000000000000000000000000000000000000..756a441b8ec6add7bb529e59d348ec4bb316690a --- /dev/null +++ b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp @@ -0,0 +1,77 @@ +//===- PtrToLLVM.cpp - Ptr to LLVM dialect conversion ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h" + +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Ptr/IR/PtrOps.h" +#include "mlir/IR/TypeUtilities.h" +#include + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ConvertToLLVMPatternInterface implementation +//===----------------------------------------------------------------------===// + +namespace { +/// Implement the interface to convert Ptr to LLVM. +struct PtrToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect(); + } + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &converter, + RewritePatternSet &patterns) const final { + ptr::populatePtrToLLVMConversionPatterns(converter, patterns); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// API +//===----------------------------------------------------------------------===// + +void mlir::ptr::populatePtrToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + converter.addTypeAttributeConversion( + [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace) + -> TypeConverter::AttributeConversionResult { + if (type.getMemorySpace() != memorySpace) + return TypeConverter::AttributeConversionResult::na(); + return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0); + }); + + // Add type conversions. + converter.addConversion([&](ptr::PtrType type) -> Type { + std::optional maybeAttr = + converter.convertTypeAttribute(type, type.getMemorySpace()); + auto memSpace = + maybeAttr ? dyn_cast_or_null(*maybeAttr) : IntegerAttr(); + if (!memSpace) + return {}; + return LLVM::LLVMPointerType::get(type.getContext(), + memSpace.getValue().getSExtValue()); + }); +} + +void mlir::ptr::registerConvertPtrToLLVMInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) { + dialect->addInterfaces(); + }); +} \ No newline at end of file diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 55143d5939ba2579d1944f755644e9db67a1d96c..56b1fd2a882e06c048fc66dcf1143fbec1385bde 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -66,6 +66,8 @@ void LowerVectorToLLVMPass::runOnOperation() { populateVectorToVectorCanonicalizationPatterns(patterns); populateVectorBitCastLoweringPatterns(patterns); populateVectorBroadcastLoweringPatterns(patterns); + populateVectorContractLoweringPatterns( + patterns, VectorTransformsOptions().enableArmSVE(armSVE)); populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions()); populateVectorMaskOpLoweringPatterns(patterns); populateVectorShapeCastLoweringPatterns(patterns); diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index 54be644a7101135295c425397d4e0d26846df7a6..b328752886e02579c4c2ac100380fd0269459bbd 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -335,6 +335,553 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern { } }; +struct F8E5M2ExtFOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtFOp op, + PatternRewriter &rewriter) const final { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto operand = op.getOperand(); + Type operandTy = operand.getType(); + Type resultTy = op.getType(); + Type operandETy = getElementTypeOrSelf(operandTy); + Type resultETy = getElementTypeOrSelf(resultTy); + + // Match only f8E5M2 → f32 for now + if (!llvm::isa(operandETy) || !resultETy.isF32()) { + return rewriter.notifyMatchFailure(op, "not a ext of f8E5M2 to f32."); + } + + // Integer and float shaped types matching the input shape + Type i8Ty = b.getI8Type(); + Type i32Ty = b.getI32Type(); + Type f32Ty = b.getF32Type(); + if (auto shapedTy = dyn_cast(operandTy)) { + i8Ty = shapedTy.clone(i8Ty); + i32Ty = shapedTy.clone(i32Ty); + f32Ty = shapedTy.clone(f32Ty); + } + + // Bitcast fp8 to raw uint8 + Value bits = b.create(i8Ty, operand); + // Zero-extend to 32 bits + Value bits32 = b.create(i32Ty, bits); + + // Extract sign (bit 7) → move to f32 sign position (bit 31) + Value sign = b.create( + bits32, createConst(op.getLoc(), i32Ty, 7, rewriter)); + sign = b.create( + sign, createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + + // Extract exponent (bits 2–6) → move to f32 exponent position (bits 23–30) + Value e5m2_exponent = b.create( + bits32, createConst(op.getLoc(), i32Ty, 2, rewriter)); + e5m2_exponent = b.create( + e5m2_exponent, createConst(op.getLoc(), i32Ty, 0x1F, rewriter)); + + // Extract mantissa (bits 0–1) + Value e5m2_mantissa = b.create( + bits32, + createConst(op.getLoc(), i32Ty, 0x3, rewriter)); // 0b11 mask for 2 bits + + // Bias exponent: f8E5M2 has a bias of 15, so we need to subtract 15 + Value exponent = b.create( + e5m2_exponent, createConst(op.getLoc(), i32Ty, 15, rewriter)); + Value float_exponent = b.create( + exponent, createConst(op.getLoc(), i32Ty, 127, rewriter)); + + // Special case handling for NaNs, Infs, subnormals + // Subnormal handling + // if (e5m2_mantissa >= 0x2) + Value isSubnormal = + b.create(arith::CmpIPredicate::sge, e5m2_mantissa, + createConst(op.getLoc(), i32Ty, 0x2, rewriter)); + // result = sign << 31 | (float_exponent) << 23 | (e5m2_mantissa & 0x1) << + // (23 - 1); + Value subnormalResult = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + float_exponent, createConst(op.getLoc(), i32Ty, 23, rewriter)), + b.create( + b.create( + e5m2_mantissa, + createConst(op.getLoc(), i32Ty, 0x1, rewriter)), + createConst(op.getLoc(), i32Ty, 22, rewriter)))); + + // if (e5m2_mantissa == 0x1) + Value isSubnormal2 = + b.create(arith::CmpIPredicate::eq, e5m2_mantissa, + createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + // result = sign << 31 | (float_exponent - 1) << 23; + Value subnormalResult2 = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + float_exponent, createConst(op.getLoc(), i32Ty, 1, rewriter)), + createConst(op.getLoc(), i32Ty, 23, rewriter))); + + // Is normal if (e5m2_exponent > 0) + Value isNormal = + b.create(arith::CmpIPredicate::sgt, e5m2_exponent, + createConst(op.getLoc(), i32Ty, 0, rewriter)); + + // else nan + Value NaN = createConst(op.getLoc(), i32Ty, 0x7FC00000, rewriter); + + // Combine sign | exponent | mantissa + Value normalResult = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + float_exponent, createConst(op.getLoc(), i32Ty, 23, rewriter)), + b.create( + e5m2_mantissa, createConst(op.getLoc(), i32Ty, 21, rewriter)))); + + // Select the appropriate result based on the conditions + Value result = b.create( + isNormal, normalResult, + b.create( + isSubnormal, subnormalResult, + b.create(isSubnormal2, subnormalResult2, NaN))); + + // Bitcast to f32 + result = b.create(f32Ty, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct F8E5M2TruncFOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::TruncFOp op, + PatternRewriter &rewriter) const final { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value operand = op.getOperand(); + Type operandTy = operand.getType(); + Type operandETy = getElementTypeOrSelf(operandTy); + Type resultTy = op.getType(); + Type resultETy = getElementTypeOrSelf(resultTy); + + if (!resultETy.isFloat8E5M2()) { + return rewriter.notifyMatchFailure(op, "not a truncf to fp8e5m2"); + } + + if (op.getRoundingmodeAttr()) { + return rewriter.notifyMatchFailure( + op, "only applicable to default rounding mode."); + } + + Type i8Ty = b.getI8Type(); + Type i32Ty = b.getI32Type(); + Type f32Ty = b.getF32Type(); + + if (auto shapedTy = mlir::dyn_cast(operandTy)) { + i8Ty = shapedTy.clone(i8Ty); + i32Ty = shapedTy.clone(i32Ty); + f32Ty = shapedTy.clone(f32Ty); + } + + // Normalize to f32 + if (operandETy.getIntOrFloatBitWidth() < 32) { + operand = b.create(f32Ty, operand, op.getFastmathAttr()); + } else if (operandETy.getIntOrFloatBitWidth() > 32) { + operand = b.create( + f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr()); + } + + // Bitcast f32 to i32 for bit manipulations + Value bits = b.create(i32Ty, operand); + + // Extract sign bit (bit 31) + Value sign = b.create( + bits, createConst(op.getLoc(), i32Ty, 31, rewriter)); + sign = b.create( + sign, createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + + // Extract exponent bits (bits 30:23) + Value exponent = b.create( + bits, createConst(op.getLoc(), i32Ty, 23, rewriter)); + exponent = b.create( + exponent, createConst(op.getLoc(), i32Ty, 0xFF, rewriter)); + + // Compute unbiased exponent (exponent - 127) + Value exponentBias = createConst(op.getLoc(), i32Ty, 127, rewriter); + Value unbiasedExp = b.create(exponent, exponentBias); + + // Extract mantissa bits (bits 22:0) + Value mantissa = b.create( + bits, createConst(op.getLoc(), i32Ty, 0x7FFFFF, rewriter)); + + // Add fp8 bias (15) + Value fp8Bias = createConst(op.getLoc(), i32Ty, 15, rewriter); + Value fp8Exp = b.create(unbiasedExp, fp8Bias); + + // Prepare mantissa for rounding: + // We need to reduce mantissa from 23 bits → 2 bits mantissa in fp8. + // To round to nearest, shift mantissa right by 21 (23 - 2) + Value mantissaShift = createConst(op.getLoc(), i32Ty, 21, rewriter); + Value mantissaTruncated = b.create(mantissa, mantissaShift); + + Value e5m2_mantissa = b.create( + mantissaTruncated, + createConst(op.getLoc(), i32Ty, 0x3, rewriter)); // 0b11 mask for 2 bits + + // Compose final fp8 bits: sign (bit7), expFinal (bits 6:2), mantissaFinal + // (bits 1:0) + Value signShifted = b.create( + sign, createConst(op.getLoc(), i32Ty, 7, rewriter)); + Value expShifted = b.create( + fp8Exp, createConst(op.getLoc(), i32Ty, 2, rewriter)); + Value resultInt = b.create(signShifted, expShifted); + resultInt = b.create(resultInt, e5m2_mantissa); + + // Subnormal cases + // if (e5m2_exponent > 31) + Value isSubnormal = + b.create(arith::CmpIPredicate::sgt, fp8Exp, + createConst(op.getLoc(), i32Ty, 31, rewriter)); + // return sign << 7 | 0x7C; + Value subnormalResult = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 7, rewriter)), + createConst(op.getLoc(), i32Ty, 0x7C, rewriter) // 0b01111100 + ); + // if ((e5m2_exponent >= -1) && (e5m2_exponent <= 0)) + Value isSubnormal2 = b.create( + b.create(arith::CmpIPredicate::sge, fp8Exp, + createConst(op.getLoc(), i32Ty, -1, rewriter)), + b.create(arith::CmpIPredicate::sle, fp8Exp, + createConst(op.getLoc(), i32Ty, 0, rewriter))); + // uint8_t shift_bits = (2 + e5m2_exponent); + // uint8_t e5m2_mantissa = (mantissa >> (24 - shift_bits)) & (0x3 >> (0 - + // e5m2_exponent)); return sign << 7 | 0x00 | e5m2_mantissa; + + Value shiftBits = b.create( + createConst(op.getLoc(), i32Ty, 2, rewriter), fp8Exp); + Value mantissaShift2 = b.create( + createConst(op.getLoc(), i32Ty, 24, rewriter), shiftBits); + Value e5m2_mantissa2 = b.create( + b.create(mantissa, mantissaShift2), + b.create( + createConst(op.getLoc(), i32Ty, 0x3, + rewriter), // 0b11 mask for 2 bits + b.create( + createConst(op.getLoc(), i32Ty, 0, rewriter), fp8Exp))); + Value subnormalResult2 = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 7, rewriter)), + b.create(createConst(op.getLoc(), i32Ty, 0x00, rewriter), + e5m2_mantissa2)); + + // if (e5m2_exponent < -1) + Value isZero = + b.create(arith::CmpIPredicate::slt, fp8Exp, + createConst(op.getLoc(), i32Ty, -1, rewriter)); + // return sign << 7 | 0x00; + Value zeroResult = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 7, rewriter)), + createConst(op.getLoc(), i32Ty, 0x00, rewriter)); + + // Select the appropiate result based on the conditions + Value finalResult = b.create( + isSubnormal, subnormalResult, + b.create( + isSubnormal2, subnormalResult2, + b.create(isZero, zeroResult, resultInt))); + + // Truncate to i8 and bitcast to fp8e5m2 + Value resultI8 = b.create(i8Ty, finalResult); + Value result = b.create(resultTy, resultI8); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct F8E4M3FNExtFOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtFOp op, + PatternRewriter &rewriter) const final { + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto operand = op.getOperand(); + Type operandTy = operand.getType(); + Type resultTy = op.getType(); + Type operandETy = getElementTypeOrSelf(operandTy); + Type resultETy = getElementTypeOrSelf(resultTy); + + // Match only f8E4M3 → f32 for now + if (!llvm::isa(operandETy) || !resultETy.isF32()) { + return rewriter.notifyMatchFailure(op, "not a ext of f8E4M3 to f32."); + } + + // Integer and float shaped types matching the input shape + Type i8Ty = b.getI8Type(); + Type i32Ty = b.getI32Type(); + Type f32Ty = b.getF32Type(); + if (auto shapedTy = dyn_cast(operandTy)) { + i8Ty = shapedTy.clone(i8Ty); + i32Ty = shapedTy.clone(i32Ty); + f32Ty = shapedTy.clone(f32Ty); + } + + // Bitcast fp8 to raw uint8 + Value bits = b.create(i8Ty, operand); + // Zero-extend to 32 bits + Value bits32 = b.create(i32Ty, bits); + + // Extract sign + Value sign = b.create( + bits32, createConst(op.getLoc(), i32Ty, 7, rewriter)); + sign = b.create( + sign, createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + + // extract exponent + Value e4m3_exponent = b.create( + bits32, createConst(op.getLoc(), i32Ty, 3, rewriter)); + e4m3_exponent = b.create( + e4m3_exponent, createConst(op.getLoc(), i32Ty, 0xF, rewriter)); + + // extract mantissa + Value rounding_bias = createConst(op.getLoc(), i32Ty, 0x80000, rewriter); + Value mantissa = b.create(bits32, rounding_bias); + Value e4m3_mantissa = b.create( + mantissa, createConst(op.getLoc(), i32Ty, 0x7, rewriter)); + + // bias exponent + Value exponent = b.create( + e4m3_exponent, createConst(op.getLoc(), i32Ty, 7, rewriter)); + Value float_exponent = b.create( + exponent, createConst(op.getLoc(), i32Ty, 127, rewriter)); + + // put everything together (normal number) e4m3_exponent > 0 + Value isNormal = + b.create(arith::CmpIPredicate::sgt, e4m3_exponent, + createConst(op.getLoc(), i32Ty, 0, rewriter)); + + Value result = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + float_exponent, createConst(op.getLoc(), i32Ty, 23, rewriter)), + b.create( + e4m3_mantissa, createConst(op.getLoc(), i32Ty, 20, rewriter)))); + + // sub-normal numbers handling (e4m3_matissa >= 0x4) + Value isSubnormal1 = + b.create(arith::CmpIPredicate::sge, e4m3_mantissa, + createConst(op.getLoc(), i32Ty, 0x4, rewriter)); + + Value resultSubnormal1 = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + float_exponent, createConst(op.getLoc(), i32Ty, 23, rewriter)), + b.create( + createConst(op.getLoc(), i32Ty, 0x3, rewriter), + b.create( + e4m3_mantissa, + createConst(op.getLoc(), i32Ty, 21, rewriter))))); + + // else if e4m3_mantissa > 0x1 + Value isSubnormal2 = + b.create(arith::CmpIPredicate::sgt, e4m3_mantissa, + createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + + Value resultSubormal2 = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + b.create( + float_exponent, + createConst(op.getLoc(), i32Ty, 1, rewriter)), + createConst(op.getLoc(), i32Ty, 23, rewriter)), + b.create( + createConst(op.getLoc(), i32Ty, 0x1, rewriter), + b.create( + e4m3_mantissa, + createConst(op.getLoc(), i32Ty, 22, rewriter))))); + + // else if e4m3_mantissa == 0x1 + Value isSubnormal3 = + b.create(arith::CmpIPredicate::eq, e4m3_mantissa, + createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + + Value resultSubnormal3 = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + float_exponent, createConst(op.getLoc(), i32Ty, 2, rewriter)), + createConst(op.getLoc(), i32Ty, 23, rewriter))); + + // else Zero + Value resultZero = b.create( + sign, createConst(op.getLoc(), i32Ty, 31, rewriter)); + + // Compute final result + result = b.create( + isNormal, result, + b.create( + isSubnormal1, resultSubnormal1, + b.create( + isSubnormal2, resultSubormal2, + b.create(isSubnormal3, resultSubnormal3, + resultZero)))); + + // Bitcast to f32 + result = b.create(f32Ty, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct F32ToF8E4M3FNTruncFOpConverter + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::TruncFOp op, + PatternRewriter &rewriter) const final { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto operand = op.getOperand(); + Type operandTy = operand.getType(); + Type resultTy = op.getType(); + Type operandETy = getElementTypeOrSelf(operandTy); + Type resultETy = getElementTypeOrSelf(resultTy); + + // Match only f32 → f8E4M3 + if (!operandETy.isF32() || !llvm::isa(resultETy)) { + return rewriter.notifyMatchFailure(op, "not a trunc of f32 to f8E4M3."); + } + + // Integer and float shaped types matching the input shape + Type i8Ty = b.getI8Type(); + Type i32Ty = b.getI32Type(); + if (auto shapedTy = dyn_cast(operandTy)) { + i8Ty = shapedTy.clone(i8Ty); + i32Ty = shapedTy.clone(i32Ty); + } + + // Bitcast f32 to raw uint32 + Value bits32 = b.create(i32Ty, operand); + + // Constants + Value bias127 = createConst(op.getLoc(), i32Ty, 127, rewriter); + Value bias7 = createConst(op.getLoc(), i32Ty, 7, rewriter); + + // Extract sign + Value sign = b.create( + bits32, createConst(op.getLoc(), i32Ty, 31, rewriter)); + sign = b.create( + sign, createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + + // Extract exponent + Value exponent = b.create( + bits32, createConst(op.getLoc(), i32Ty, 23, rewriter)); + exponent = b.create( + exponent, createConst(op.getLoc(), i32Ty, 0xFF, rewriter)); + exponent = b.create(exponent, bias127); + + // Extract the mantissa + Value mantissa = b.create( + bits32, createConst(op.getLoc(), i32Ty, 0x7FFFFF, rewriter)); + + // For normal numbers, add the implicit leading 1 in the mantissa + mantissa = b.create( + mantissa, createConst(op.getLoc(), i32Ty, 0x800000, rewriter)); + + // Apply the bias for e4m3 (bias of 7) + Value e4m3_exponent = b.create(exponent, bias7); + + // if e4m3_exponent > 15 + Value isOverflow = + b.create(arith::CmpIPredicate::sgt, e4m3_exponent, + createConst(op.getLoc(), i32Ty, 15, rewriter)); + + // Clamp to max finite value + Value maxFinite = + createConst(op.getLoc(), i32Ty, 0x7F, rewriter); // 0b01111111 in f8 + + // if ((e4m3_exponent > -3) && (e4m3_exponent <= 0)) + Value isSubnormal = + b.create(arith::CmpIPredicate::sge, e4m3_exponent, + createConst(op.getLoc(), i32Ty, -3, rewriter)); + isSubnormal = b.create( + isSubnormal, + b.create(arith::CmpIPredicate::sle, e4m3_exponent, + createConst(op.getLoc(), i32Ty, 0, rewriter))); + + Value shift_bits = b.create( + e4m3_exponent, createConst(op.getLoc(), i32Ty, 3, rewriter)); + Value e4m3_mantissa_subnormal = b.create( + mantissa, + b.create(createConst(op.getLoc(), i32Ty, 24, rewriter), + shift_bits)); + e4m3_mantissa_subnormal = b.create( + e4m3_mantissa_subnormal, + b.create( + createConst(op.getLoc(), i32Ty, 0x7, rewriter), + b.create( + createConst(op.getLoc(), i32Ty, 0, rewriter), e4m3_exponent))); + + Value resultSubnormal = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 7, rewriter)), + b.create( + createConst(op.getLoc(), i32Ty, 0x00, + rewriter), // Exponent is 0 for subnormals + e4m3_mantissa_subnormal)); + + // else if e4m3_exponent <= -3 + Value isZero = + b.create(arith::CmpIPredicate::sle, e4m3_exponent, + createConst(op.getLoc(), i32Ty, -3, rewriter)); + + Value resultZero = + createConst(op.getLoc(), i32Ty, 0x00, rewriter); // 0b00000000 + + // For normal numbers, normalize mantissa to fit into 3 bits (e4m3 has 3 + // bits for mantissa) + Value e4m3_mantissa = b.create( + mantissa, createConst(op.getLoc(), i32Ty, 20, rewriter)); + e4m3_mantissa = b.create( + e4m3_mantissa, createConst(op.getLoc(), i32Ty, 0x7, rewriter)); + + // Pack the sign, exponent, and mantissa into an 8-bit value (normal + // numbers) + Value result = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 7, rewriter)), + b.create( + b.create( + e4m3_exponent, createConst(op.getLoc(), i32Ty, 3, rewriter)), + e4m3_mantissa)); + + // compute final result (if no codition is met, result is normal) + result = b.create( + isOverflow, maxFinite, + b.create( + isSubnormal, resultSubnormal, + b.create(isZero, resultZero, result))); + + // Truncate to i8 and bitcast to f8e4m3 + result = b.create(i8Ty, result); + result = b.create(resultTy, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; + struct ArithExpandOpsPass : public arith::impl::ArithExpandOpsPassBase { using ArithExpandOpsPassBase::ArithExpandOpsPassBase; @@ -363,21 +910,42 @@ struct ArithExpandOpsPass if (includeBf16) { arith::populateExpandBFloat16Patterns(patterns); - target.addDynamicallyLegalOp( - [](arith::ExtFOp op) { - Type inETy = getElementTypeOrSelf(op.getOperand().getType()); - Type outETy = getElementTypeOrSelf(op.getType()); - return !(inETy.isBF16() && outETy.isF32()); - }); - - target.addDynamicallyLegalOp( - [](arith::TruncFOp op) { - Type inETy = getElementTypeOrSelf(op.getOperand().getType()); - Type outETy = getElementTypeOrSelf(op.getType()); - return !(inETy.isF32() && outETy.isBF16()); - }); + } + if (includeF8E5M2){ + arith::populateExpandF8E5M2Patterns(patterns); + } + if (includeF8E4M3FN){ + arith::populateExpandF8E4M3FNPatterns(patterns); } + target.addDynamicallyLegalOp( + [=](arith::ExtFOp op) { + Type inETy = getElementTypeOrSelf(op.getOperand().getType()); + Type outETy = getElementTypeOrSelf(op.getType()); + bool legalTypes = true; + if (includeBf16) + legalTypes &= !(inETy.isBF16() && outETy.isF32()); + if (includeF8E5M2) + legalTypes &= !inETy.isFloat8E5M2(); + if (includeF8E4M3FN) + legalTypes &= !inETy.isFloat8E4M3FN(); + return legalTypes; + }); + + target.addDynamicallyLegalOp( + [=](arith::TruncFOp op) { + Type inETy = getElementTypeOrSelf(op.getOperand().getType()); + Type outETy = getElementTypeOrSelf(op.getType()); + bool legalTypes = true; + if (includeBf16) + legalTypes &= !(inETy.isF32() && outETy.isBF16()); + if (includeF8E5M2) + legalTypes &= !outETy.isFloat8E5M2(); + if (includeF8E4M3FN) + legalTypes &= !outETy.isFloat8E4M3FN(); + return legalTypes; + }); + // clang-format on if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -399,6 +967,16 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) { patterns.getContext()); } +void mlir::arith::populateExpandF8E5M2Patterns(RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + +void mlir::arith::populateExpandF8E4M3FNPatterns(RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { populateCeilFloorDivExpandOpsPatterns(patterns); // clang-format off diff --git a/mlir/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/CMakeLists.txt index 9f57627c321fb0c74b3e4a404e3c36bd435f64a7..cb1e9d01821a2cf352b79c28c44da4ddd33dd3e9 100644 --- a/mlir/lib/Dialect/ArmSME/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSME/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc27c84307e7b1b74ec77dadce0b77ce25fd5e3a --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.cpp @@ -0,0 +1,46 @@ +//===- ArmSMEVectorTransformOps.cpp - Implementation transform ops -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.h" + +#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" + +using namespace mlir; + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.cpp.inc" + + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class ArmSMEVectorTransformDialectExtension + : public transform::TransformDialectExtension< + ArmSMEVectorTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + ArmSMEVectorTransformDialectExtension) + + ArmSMEVectorTransformDialectExtension() { + declareGeneratedDialect(); + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/ArmSME/TransformOps/ArmSMEVectorTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +void mlir::arm_sme::registerTransformDialectExtension( + DialectRegistry ®istry) { + registry.addExtensions(); +} diff --git a/mlir/lib/Dialect/ArmSME/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a47013c6d2c6c2e22a2537befe4752ee8ec5d6d4 --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/TransformOps/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_dialect_library(MLIRArmSMEVectorTransformOps + ArmSMEVectorTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/TransformOps + + DEPENDS + MLIRArmSMEVectorTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRVectorDialect + MLIRTransformDialect + MLIRArmSMEDialect + MLIRArmSMETransforms + ) diff --git a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt index 9f57627c321fb0c74b3e4a404e3c36bd435f64a7..cb1e9d01821a2cf352b79c28c44da4ddd33dd3e9 100644 --- a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9a1973683531e883531daa43362ecd5f9dece6b6 --- /dev/null +++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp @@ -0,0 +1,46 @@ +//===- ArmSVEVectorTransformOps.cpp - Implementation transform ops -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" + +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" + +using namespace mlir; + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc" + + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class ArmSVEVectorTransformDialectExtension + : public transform::TransformDialectExtension< + ArmSVEVectorTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + ArmSVEVectorTransformDialectExtension) + + ArmSVEVectorTransformDialectExtension() { + declareGeneratedDialect(); + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +void mlir::arm_sve::registerTransformDialectExtension( + DialectRegistry ®istry) { + registry.addExtensions(); +} diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..235fe8f5ca0ce2c23b173d1a9189dd574c8275f6 --- /dev/null +++ b/mlir/lib/Dialect/ArmSVE/TransformOps/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_dialect_library(MLIRArmSVEVectorTransformOps + ArmSVEVectorTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE/TransformOps + + DEPENDS + MLIRArmSVEVectorTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRVectorDialect + MLIRTransformDialect + MLIRArmSVEDialect + MLIRArmSVETransforms + ) diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp index 420c605d1a19b2a53501f9328d28b83a108c885d..987de03a4685e31e50c9d247fdc742fcd3d3212b 100644 --- a/mlir/lib/Dialect/DLTI/DLTI.cpp +++ b/mlir/lib/Dialect/DLTI/DLTI.cpp @@ -170,25 +170,9 @@ combineOneSpec(DataLayoutSpecInterface spec, DenseMap newEntriesForID; spec.bucketEntriesByType(newEntriesForType, newEntriesForID); - // Try overwriting the old entries with the new ones. - for (auto &kvp : newEntriesForType) { - if (!entriesForType.count(kvp.first)) { - entriesForType[kvp.first] = std::move(kvp.second); - continue; - } - - Type typeSample = kvp.second.front().getKey().get(); - assert(&typeSample.getDialect() != - typeSample.getContext()->getLoadedDialect() && - "unexpected data layout entry for built-in type"); - - auto interface = llvm::cast(typeSample); - if (!interface.areCompatible(entriesForType.lookup(kvp.first), kvp.second)) - return failure(); - - overwriteDuplicateEntries(entriesForType[kvp.first], kvp.second); - } - + + // Combine non-Type DL entries first so they are visible to the + // `type.areCompatible` method, allowing to query global properties. for (const auto &kvp : newEntriesForID) { StringAttr id = kvp.second.getKey().get(); Dialect *dialect = id.getReferencedDialect(); @@ -197,7 +181,7 @@ combineOneSpec(DataLayoutSpecInterface spec, continue; } - // Attempt to combine the enties using the dialect interface. If the + // Attempt to combine the entries using the dialect interface. If the // dialect is not loaded for some reason, use the default combinator // that conservatively accepts identical entries only. entriesForID[id] = @@ -208,6 +192,27 @@ combineOneSpec(DataLayoutSpecInterface spec, if (!entriesForID[id]) return failure(); } + // Try overwriting the old entries with the new ones. + for (auto &kvp : newEntriesForType) { + if (!entriesForType.count(kvp.first)) { + entriesForType[kvp.first] = std::move(kvp.second); + continue; + } + + Type typeSample = cast(kvp.second.front().getKey()); + assert(&typeSample.getDialect() != + typeSample.getContext()->getLoadedDialect() && + "unexpected data layout entry for built-in type"); + + auto interface = cast(typeSample); + // TODO: Revisit this method and call once + // https://github.com/llvm/llvm-project/issues/130321 gets resolved. + if (!interface.areCompatible(entriesForType.lookup(kvp.first), kvp.second, + spec, entriesForID)) + return failure(); + + overwriteDuplicateEntries(entriesForType[kvp.first], kvp.second); + } return success(); } @@ -244,6 +249,12 @@ DataLayoutSpecAttr::getEndiannessIdentifier(MLIRContext *context) const { return Builder(context).getStringAttr(DLTIDialect::kDataLayoutEndiannessKey); } +StringAttr DataLayoutSpecAttr::getDefaultMemorySpaceIdentifier( + MLIRContext *context) const { + return Builder(context).getStringAttr( + DLTIDialect::kDataLayoutDefaultMemorySpaceKey); +} + StringAttr DataLayoutSpecAttr::getAllocaMemorySpaceIdentifier(MLIRContext *context) const { return Builder(context).getStringAttr( @@ -417,7 +428,9 @@ public: << DLTIDialect::kDataLayoutEndiannessBig << "' or '" << DLTIDialect::kDataLayoutEndiannessLittle << "'"; } - if (entryName == DLTIDialect::kDataLayoutAllocaMemorySpaceKey || + + if (entryName == DLTIDialect::kDataLayoutDefaultMemorySpaceKey || + entryName == DLTIDialect::kDataLayoutAllocaMemorySpaceKey || entryName == DLTIDialect::kDataLayoutProgramMemorySpaceKey || entryName == DLTIDialect::kDataLayoutGlobalMemorySpaceKey || entryName == DLTIDialect::kDataLayoutStackAlignmentKey) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index cf3f38b710130753b8f37aaf204e02f58eead164..a3f1bef9b28ce7d145b2541448d105932ef5d6e3 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -349,8 +349,10 @@ LLVMPointerType::getIndexBitwidth(const DataLayout &dataLayout, return dataLayout.getTypeIndexBitwidth(get(getContext())); } -bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout, - DataLayoutEntryListRef newLayout) const { +bool LLVMPointerType::areCompatible( + DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout, + DataLayoutSpecInterface newSpec, + const DataLayoutIdentifiedEntryMap &map) const { for (DataLayoutEntryInterface newEntry : newLayout) { if (!newEntry.isTypeEntry()) continue; @@ -596,8 +598,10 @@ static uint64_t extractStructSpecValue(Attribute attr, StructDLEntryPos pos) { .getValues()[static_cast(pos)]; } -bool LLVMStructType::areCompatible(DataLayoutEntryListRef oldLayout, - DataLayoutEntryListRef newLayout) const { +bool LLVMStructType::areCompatible( + DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout, + DataLayoutSpecInterface newSpec, + const DataLayoutIdentifiedEntryMap &map) const{ for (DataLayoutEntryInterface newEntry : newLayout) { if (!newEntry.isTypeEntry()) continue; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index d1db90bbe2d207672b830dc269d38171fa80c13a..7df8f93ac892dfa017079f522b46db16841854b3 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1251,11 +1251,12 @@ LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl &) { static ParseResult parseDstStyleOp( OpAsmParser &parser, OperationState &result, function_ref parseAttrsFn = - nullptr) { + nullptr, + bool addOperandSegmentSizes = false) { // Parse `ins` and `outs`. SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes, - /*addOperandSegmentSizes=*/false)) + addOperandSegmentSizes)) return failure(); // Add result types. @@ -1596,9 +1597,12 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { } if (parseDstStyleOp( - parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { + parser, result, + [&](OpAsmParser &parser, NamedAttrList &attributes) { return parseDenseI64ArrayAttr(parser, attributes, "dimensions"); - })) + }, + /*addOperandSegmentSizes=*/true)) + return failure(); if (payloadOpName.has_value()) { @@ -1633,7 +1637,9 @@ void ReduceOp::print(OpAsmPrinter &p) { printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); - p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); + p.printOptionalAttrDict( + (*this)->getAttrs(), + {getDimensionsAttrName(), getOperandSegmentSizesAttrName()}); if (!payloadOp) { // Print region if the payload op was not detected. p.increaseIndent(); diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt index 9cf3643c73d3ed680fb9d11d3ac1845350568265..8c186594b6ad543f119bfd33826f77be4e73493a 100644 --- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt @@ -7,10 +7,13 @@ add_mlir_dialect_library( DEPENDS MLIRPtrOpsAttributesIncGen MLIRPtrOpsIncGen + MLIRPtrOpsEnumsGen + MLIRPtrMemorySpaceInterfacesIncGen LINK_LIBS PUBLIC MLIRIR MLIRDataLayoutInterfaces MLIRMemorySlotInterfaces + MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp index f8ce820d0bcbd0a47bf09a958f092c9da43970ea..1770e4febf099211b616633e0602a5bee99c9668 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Ptr/IR/PtrAttrs.h" +#include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -18,6 +19,51 @@ using namespace mlir::ptr; constexpr const static unsigned kBitsInByte = 8; +//===----------------------------------------------------------------------===// +// GenericSpaceAttr +//===----------------------------------------------------------------------===// + +LogicalResult GenericSpaceAttr::isValidLoad( + Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment, + function_ref emitError) const { + return success(); +} + +LogicalResult GenericSpaceAttr::isValidStore( + Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment, + function_ref emitError) const { + return success(); +} + +LogicalResult GenericSpaceAttr::isValidAtomicOp( + ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering, + IntegerAttr alignment, function_ref emitError) const { + return success(); +} + +LogicalResult GenericSpaceAttr::isValidAtomicXchg( + Type type, ptr::AtomicOrdering successOrdering, + ptr::AtomicOrdering failureOrdering, IntegerAttr alignment, + function_ref emitError) const { + return success(); +} + +LogicalResult GenericSpaceAttr::isValidAddrSpaceCast( + Type tgt, Type src, function_ref emitError) const { + // TODO: update this method once the `addrspace_cast` op is added to the + // dialect. + assert(false && "unimplemented, see TODO in the source."); + return failure(); +} + +LogicalResult GenericSpaceAttr::isValidPtrIntCast( + Type intLikeTy, Type ptrLikeTy, + function_ref emitError) const { + // TODO: update this method once the int-cast ops are added to the dialect. + assert(false && "unimplemented, see TODO in the source."); + return failure(); +} + //===----------------------------------------------------------------------===// // SpecAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index 7830ffe893dfd156077854cf83b1985071ffcdd9..061b3feb4d666f4df0a5a9fb348da01ad8ef8de4 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -12,7 +12,9 @@ #include "mlir/Dialect/Ptr/IR/PtrOps.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/TypeSwitch.h" @@ -48,6 +50,12 @@ void PtrDialect::initialize() { #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc" +#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc" + +#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc" + +#include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc" + #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc" diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp index 2866d4eb10feb1287e3bf4f821ad307752b0568b..101330073d2d59e3997112f52b8b505bee568578 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp @@ -23,13 +23,12 @@ using namespace mlir::ptr; constexpr const static unsigned kDefaultPointerSizeBits = 64; constexpr const static unsigned kBitsInByte = 8; -constexpr const static unsigned kDefaultPointerAlignment = 8; - -static Attribute getDefaultMemorySpace(PtrType ptr) { return nullptr; } +constexpr const static unsigned kDefaultPointerAlignmentBits = 8; /// Searches the data layout for the pointer spec, returns nullptr if it is not /// found. -static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type) { +static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type, + MemorySpaceAttrInterface defaultMemorySpace) { for (DataLayoutEntryInterface entry : params) { if (!entry.isTypeEntry()) continue; @@ -39,22 +38,26 @@ static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type) { return spec; } } - // If not found, and this is the pointer to the default memory space, assume - // 64-bit pointers. - if (type.getMemorySpace() == getDefaultMemorySpace(type)) + // If not found, and this is the pointer to the default memory space or if + // `defaultMemorySpace` is null, assume 64-bit pointers. `defaultMemorySpace` + // might be null if the data layout doesn't define the default memory space. + if (type.getMemorySpace() == defaultMemorySpace || + defaultMemorySpace == nullptr) return SpecAttr::get(type.getContext(), kDefaultPointerSizeBits, - kDefaultPointerAlignment, kDefaultPointerAlignment, - kDefaultPointerSizeBits); + kDefaultPointerAlignmentBits, + kDefaultPointerAlignmentBits, kDefaultPointerSizeBits); return nullptr; } bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, - DataLayoutEntryListRef newLayout) const { + DataLayoutEntryListRef newLayout, + DataLayoutSpecInterface newSpec, + const DataLayoutIdentifiedEntryMap &map) const { for (DataLayoutEntryInterface newEntry : newLayout) { if (!newEntry.isTypeEntry()) continue; uint32_t size = kDefaultPointerSizeBits; - uint32_t abi = kDefaultPointerAlignment; + uint32_t abi = kDefaultPointerAlignmentBits; auto newType = llvm::cast(newEntry.getKey().get()); const auto *it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { @@ -65,10 +68,12 @@ bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, return false; }); if (it == oldLayout.end()) { + Attribute defaultMemorySpace = mlir::detail::getDefaultMemorySpace( + map.lookup(newSpec.getDefaultMemorySpaceIdentifier(getContext()))); it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { if (auto type = llvm::dyn_cast_if_present(entry.getKey())) { auto ptrTy = llvm::cast(type); - return ptrTy.getMemorySpace() == getDefaultMemorySpace(ptrTy); + return ptrTy.getMemorySpace() == defaultMemorySpace; } return false; }); @@ -90,43 +95,47 @@ bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - if (SpecAttr spec = getPointerSpec(params, *this)) + auto defaultMemorySpace = llvm::cast_if_present( + dataLayout.getDefaultMemorySpace()); + if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace)) return spec.getAbi() / kBitsInByte; - return dataLayout.getTypeABIAlignment( - get(getContext(), getDefaultMemorySpace(*this))); + return dataLayout.getTypeABIAlignment(get(defaultMemorySpace)); } std::optional PtrType::getIndexBitwidth(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - if (SpecAttr spec = getPointerSpec(params, *this)) { + auto defaultMemorySpace = llvm::cast_if_present( + dataLayout.getDefaultMemorySpace()); + if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace)) { return spec.getIndex() == SpecAttr::kOptionalSpecValue ? spec.getSize() : spec.getIndex(); } - return dataLayout.getTypeIndexBitwidth( - get(getContext(), getDefaultMemorySpace(*this))); + return dataLayout.getTypeIndexBitwidth(get(defaultMemorySpace)); } llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - if (SpecAttr spec = getPointerSpec(params, *this)) + auto defaultMemorySpace = llvm::cast_if_present( + dataLayout.getDefaultMemorySpace()); + if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace)) return llvm::TypeSize::getFixed(spec.getSize()); // For other memory spaces, use the size of the pointer to the default memory // space. - return dataLayout.getTypeSizeInBits( - get(getContext(), getDefaultMemorySpace(*this))); + return dataLayout.getTypeSizeInBits(get(defaultMemorySpace)); } uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params) const { - if (SpecAttr spec = getPointerSpec(params, *this)) + auto defaultMemorySpace = llvm::cast_if_present( + dataLayout.getDefaultMemorySpace()); + if (SpecAttr spec = getPointerSpec(params, *this, defaultMemorySpace)) return spec.getPreferred() / kBitsInByte; - return dataLayout.getTypePreferredAlignment( - get(getContext(), getDefaultMemorySpace(*this))); + return dataLayout.getTypePreferredAlignment(get(defaultMemorySpace)); } LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries, @@ -142,3 +151,15 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries, } return success(); } + +//===----------------------------------------------------------------------===// +// Pointer metadata +//===----------------------------------------------------------------------===// + +LogicalResult +PtrMetadataType::verify(function_ref emitError, + PtrLikeTypeInterface type) { + if (!type.hasPtrMetadata()) + return emitError() << "the ptr-like type has no metadata"; + return success(); +} diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt index 5b4989f328e6909d7d9145d3ff8c4b49b3afc182..3f598a69096972541c5a0cbb1447b6c34ae6a01b 100644 --- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt @@ -17,6 +17,23 @@ add_mlir_dialect_library(MLIRTransformDialect MLIRRewrite MLIRSideEffectInterfaces MLIRTransforms + # Quite Odd that we need this here + MLIRArmSMETestPasses + MLIRArithTransforms + MLIRAsyncToLLVM + MLIRLinalgTransforms + MLIRMemRefTransforms + MLIRAffineTransforms + MLIRArithToLLVM + MLIROpenMPToLLVM + MLIRIndexToLLVM + MLIRComplexToLLVM + MLIRMathToLLVM + MLIRAffineToStandard + MLIRReconcileUnrealizedCasts + MLIRAsyncTransforms + MLIRVectorToLLVM + MLIRVectorToLLVMPass MLIRTransformDialectInterfaces MLIRTransformDialectUtils ) diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index c4238080533bef11c66aea0fdc1ee8665ff11fde..ba06950ea22503bb738842d1ea03616beb8a417c 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dominance.h" @@ -41,6 +42,11 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" + +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/Dialect/ArmSME/Transforms/Passes.h" +#include "mlir/InitAllPasses.h" + #include #define DEBUG_TYPE "transform-dialect" @@ -2859,3 +2865,131 @@ void transform::YieldOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getOperandsMutable(), effects); } + +//===----------------------------------------------------------------------===// +// LowerToArmSMEOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerToArmSMEOp::applyToOne( + transform::TransformRewriter &rewriter, ModuleOp target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + PassManager pm(getContext()); + // createVectorLegalizationPass requires ModuleOp level pass. + // Legalize vector operations so they can be converted to ArmSME. + pm.addPass(arm_sme::createVectorLegalizationPass()); + + // Sprinkle some cleanups. + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + + // Passes that convert operations on vectors to ArmSME operations. + pm.addPass(createArithToArmSMEConversionPass()); + pm.addPass(createConvertVectorToArmSMEPass()); + + // TODO: Leverage FMOPA 2Way for half precision? + // Fuse outer products. + if (getFuseOuterProducts()) + pm.addPass(arm_sme::createOuterProductFusionPass()); + + // Convert operations on high-level vectors to loops. + pm.addPass(createConvertArmSMEToSCFPass()); + // Convert Vector to SCF (with full unroll enabled). + pm.addNestedPass(arm_sme::createEnableArmStreamingPass( + arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA, + /*onlyIfRequiredByOps=*/true)); + + if (failed(pm.run(target))) + return DiagnosedSilenceableFailure::definiteFailure(); + return DiagnosedSilenceableFailure::success(); +} + +void transform::LowerToArmSMEOp::getEffects( + SmallVectorImpl &effects) { + transform::modifiesPayload(effects); + transform::onlyReadsHandle(getTargetMutable(), effects); +} + +//===---------------------------------------------------------------------===// +// LowerToLLVMNewOp +//===---------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerToLLVMNewOp::applyToOne( + transform::TransformRewriter &rewriter, ModuleOp target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + // TODO: it is feasible to scope lowering at arbitrary level and introduce + // unrealized casts, but there needs to be the final module-wise cleanup in + // the end. Keep module-level for now. + MLIRContext *ctx = getContext(); + PassManager pm(ctx); + + // Lower multi dimensionOps to scf + pm.addNestedPass(createConvertVectorToSCFPass()); + pm.addNestedPass(createConvertLinalgToLoopsPass()); + // Lower Async + if (getEnableAsync()) { + pm.addPass(createAsyncToAsyncRuntimePass()); + pm.addPass(createAsyncRuntimeRefCountingPass()); + pm.addPass(createAsyncRuntimeRefCountingOptPass()); + } + pm.addPass(createCanonicalizerPass()); + pm.addPass(memref::createExpandStridedMetadataPass()); + // The expansion may create affine expressions. Get rid of them. + pm.addPass(createLowerAffinePass()); + pm.addPass(createConvertSCFToCFPass()); + if (ctx->getLoadedDialect()) { + pm.addNestedPass(createConvertArmSMEToLLVMPass()); + } + pm.addPass(createConvertComplexToLLVMPass()); + pm.addPass(createConvertVectorToLLVMPass(ConvertVectorToLLVMPassOptions{ + /* reassociateFPReductions = */ getReassociateFpReductions(), + /* force32BitVectorIndices */ getEnableIndexOptimizations(), + /* amx = */ getEnableAmx(), + /* armNeon = */ getEnableArmNeon(), + /* armSVE = */ getEnableArmSve(), + /* x86Vector = */ getEnableX86vector()})); + pm.addNestedPass(createConvertMathToLLVMPass()); + pm.addNestedPass(arith::createArithExpandOpsPass()); + pm.addPass(createFinalizeMemRefToLLVMConversionPass()); + if (getEnableAsync()) + pm.addPass(createConvertAsyncToLLVMPass()); + pm.addPass(createConvertOpenMPToLLVMPass()); + pm.addPass(createConvertFuncToLLVMPass()); + pm.addPass(createConvertControlFlowToLLVMPass()); + pm.addPass(createArithToLLVMConversionPass()); + pm.addPass(createConvertIndexToLLVMPass()); + pm.addPass(createReconcileUnrealizedCastsPass()); + if (failed(pm.run(target))) + return DiagnosedSilenceableFailure::definiteFailure(); + + llvm::SmallVector attrs; + + if (getVscaleRange() > 0) { + attrs.push_back(mlir::ArrayAttr::get( + ctx, {mlir::StringAttr::get(ctx, "vscale_range"), + mlir::StringAttr::get(ctx, llvm::Twine(getVscaleRange()))})); + + target->walk([&](LLVM::LLVMFuncOp funcOp) { + if (!funcOp.getBody().empty()) + funcOp->setAttr("passthrough", mlir::ArrayAttr::get(ctx, attrs)); + }); + } + + // Make all arguments noalias for now. + // FIXME: this is a terrible hack! + target->walk([](LLVM::LLVMFuncOp funcOp) { + for (int64_t i = 0; i < funcOp.getNumArguments(); ++i) { + if (!isa(funcOp.getFunctionType().getParamType(i))) + continue; + funcOp.setArgAttr(i, "llvm.noalias", UnitAttr::get(funcOp.getContext())); + } + }); + return DiagnosedSilenceableFailure::success(); +} + +void transform::LowerToLLVMNewOp::getEffects( + SmallVectorImpl &effects) { + transform::modifiesPayload(effects); + transform::onlyReadsHandle(getTargetMutable(), effects); +} \ No newline at end of file diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 2e9aa88011825bc2d4f9d7550b8687370bf7a80d..4ed4aa57565f87df814a4cf0aa5d8bbf13d6003c 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -129,7 +129,9 @@ void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns( void transform::ApplyLowerOuterProductPatternsOp::populatePatterns( RewritePatternSet &patterns) { - populateVectorOuterProductLoweringPatterns(patterns); + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.enableArmSVE(getIsSVE()); + populateVectorOuterProductLoweringPatterns(patterns, vectorTransformOptions); } void transform::ApplyLowerGatherPatternsOp::populatePatterns( diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index 723b2f62d65d4fb6f607d2492ede77b3fe90bc67..22bb8fa4e631611b2771d79f17fd9c609eb203ed 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -40,6 +40,7 @@ add_mlir_dialect_library(MLIRVectorTransforms MLIRGPUDialect MLIRIR MLIRLinalgDialect + MLIRLLVMDialect MLIRMemRefDialect MLIRMemRefUtils MLIRSCFDialect diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 3a799ce8e0bce3301bc8b1a4733ea5e1163149ad..748ac6fe96eb20f5376d0d1b1332696b52691ec4 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -1200,8 +1201,15 @@ FailureOr ContractionOpLowering::lowerReduction( /// %x = vector.insert %.., %..[N-1] /// class OuterProductOpLowering : public OpRewritePattern { +private: + vector::VectorTransformsOptions vectorTransformOptions; + public: using OpRewritePattern::OpRewritePattern; + OuterProductOpLowering(vector::VectorTransformsOptions vectorTransformOptions, + MLIRContext *context, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + vectorTransformOptions(vectorTransformOptions) {} LogicalResult matchAndRewrite(vector::OuterProductOp op, PatternRewriter &rewriter) const override { @@ -1246,19 +1254,51 @@ public: loc, resType, rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { Value x = rewriter.create(loc, op.getLhs(), d); - Value a = rewriter.create(loc, rhsType, x); - Value r = nullptr; - if (acc) + Value r = nullptr; + if (acc) r = rewriter.create(loc, acc, d); - Value extrMask; - if (mask) - extrMask = rewriter.create(loc, mask, d); - - std::optional m = createContractArithOp( - loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); - if (!m.has_value()) - return failure(); - result = rewriter.create(loc, *m, result, d); + if (vectorTransformOptions.armSve) { + long sizeOfScalableVector = + 128 / + mlir::LLVM::getPrimitiveTypeSizeInBits(resType.getElementType()); + assert(d <= sizeOfScalableVector && "Unsupported index for SVE fmla"); + Type vectype = VectorType::get({sizeOfScalableVector}, + rhsType.getElementType(), 1); + auto udef = rewriter.create(loc, vectype); + auto ScalableLhs = rewriter.create( + loc, op.getLhs(), udef, 0); + auto i0 = rewriter.create( + loc, rewriter.getZeroAttr(rewriter.getI64Type())); + StringAttr dupq = rewriter.getStringAttr("llvm.aarch64.sve.dupq.lane"); + auto lhsdup = rewriter.create( + loc, TypeRange{vectype}, dupq, ValueRange{ScalableLhs, i0}); + auto broadcastedLHS = lhsdup.getResult(0); + auto rhs = rewriter.create(loc, op.getRhs(), + udef, 0); + r = rewriter.create(loc, r, udef, 0); + auto idx = rewriter.create( + loc, rewriter.getI32IntegerAttr(d)); + LLVM::CallIntrinsicOp fma; + StringAttr fmla = rewriter.getStringAttr("llvm.aarch64.sve.fmla.lane"); + fma = rewriter.create( + loc, TypeRange{vectype}, fmla, + ValueRange{r, rhs, broadcastedLHS, idx}); + auto mm = rewriter.create( + loc, rhsType, fma.getResult(0), 0); + result = + rewriter.create(loc, mm, result, d); + } else { + Value a = rewriter.create(loc, rhsType, x); + Value extrMask; + if (mask) + extrMask = rewriter.create(loc, mask, d); + + std::optional m = createContractArithOp( + loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); + if (!m.has_value()) + return failure(); + result = rewriter.create(loc, *m, result, d); + } } rewriter.replaceOp(rootOp, result); @@ -1378,13 +1418,14 @@ void mlir::vector::populateVectorContractLoweringPatterns( RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit, bool disableOuterProductLowering) { if (!disableOuterProductLowering) - patterns.add(patterns.getContext(), benefit); + patterns.add(options, patterns.getContext(), benefit); patterns.add( options, patterns.getContext(), benefit); } void mlir::vector::populateVectorOuterProductLoweringPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); + RewritePatternSet &patterns, VectorTransformsOptions options, + PatternBenefit benefit) { + patterns.add(options, patterns.getContext(), benefit); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index faa944937e007a868803d82edb0889df9689c67d..39627cba244d3b3c2c5c895047e72701503e2315 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -386,6 +386,20 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional> shape, return builder; } +FailureOr +BaseMemRefType::clonePtrWith(Attribute memorySpace, + std::optional elementType) const { + Type eTy = elementType ? *elementType : getElementType(); + if (llvm::dyn_cast(*this)) + return ::llvm::cast( + UnrankedMemRefType::get(eTy, memorySpace)); + + MemRefType::Builder builder(llvm::cast(*this)); + builder.setElementType(eTy); + builder.setMemorySpace(memorySpace); + return ::llvm::cast(static_cast(builder)); +} + MemRefType BaseMemRefType::clone(::llvm::ArrayRef shape, Type elementType) const { return ::llvm::cast(cloneWith(shape, elementType)); diff --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp index 6e968d62e61c7f4bf6831552c83c79421d7c717d..a55cdd834b39cad7e7c2e01ba5db593e1d738a0a 100644 --- a/mlir/lib/IR/ODSSupport.cpp +++ b/mlir/lib/IR/ODSSupport.cpp @@ -33,6 +33,50 @@ Attribute mlir::convertToAttribute(MLIRContext *ctx, int64_t storage) { return IntegerAttr::get(IntegerType::get(ctx, 64), storage); } +LogicalResult +mlir::convertFromAttribute(int32_t &storage, Attribute attr, + function_ref emitError) { + auto valueAttr = dyn_cast(attr); + if (!valueAttr) { + emitError() << "expected IntegerAttr for key `value`"; + return failure(); + } + storage = valueAttr.getValue().getSExtValue(); + return success(); +} +Attribute mlir::convertToAttribute(MLIRContext *ctx, int32_t storage) { + return IntegerAttr::get(IntegerType::get(ctx, 32), storage); +} + +LogicalResult +mlir::convertFromAttribute(std::string &storage, Attribute attr, + function_ref emitError) { + auto valueAttr = dyn_cast(attr); + if (!valueAttr) + return emitError() + << "expected string property to come from string attribute"; + storage = valueAttr.getValue().str(); + return success(); +} +Attribute mlir::convertToAttribute(MLIRContext *ctx, + const std::string &storage) { + return StringAttr::get(ctx, storage); +} + +LogicalResult +mlir::convertFromAttribute(bool &storage, Attribute attr, + function_ref emitError) { + auto valueAttr = dyn_cast(attr); + if (!valueAttr) + return emitError() + << "expected string property to come from string attribute"; + storage = valueAttr.getValue(); + return success(); +} +Attribute mlir::convertToAttribute(MLIRContext *ctx, bool storage) { + return BoolAttr::get(ctx, storage); +} + template LogicalResult convertDenseArrayFromAttr(MutableArrayRef storage, Attribute attr, @@ -64,7 +108,34 @@ mlir::convertFromAttribute(MutableArrayRef storage, Attribute attr, "DenseI32ArrayAttr"); } +template +LogicalResult +convertDenseArrayFromAttr(SmallVectorImpl &storage, Attribute attr, + function_ref emitError, + StringRef denseArrayTyStr) { + auto valueAttr = dyn_cast(attr); + if (!valueAttr) { + emitError() << "expected " << denseArrayTyStr << " for key `value`"; + return failure(); + } + storage.resize_for_overwrite(valueAttr.size()); + llvm::copy(valueAttr.asArrayRef(), storage.begin()); + return success(); +} +LogicalResult +mlir::convertFromAttribute(SmallVectorImpl &storage, Attribute attr, + function_ref emitError) { + return convertDenseArrayFromAttr(storage, attr, emitError, + "DenseI64ArrayAttr"); +} +LogicalResult +mlir::convertFromAttribute(SmallVectorImpl &storage, Attribute attr, + function_ref emitError) { + return convertDenseArrayFromAttr(storage, attr, emitError, + "DenseI32ArrayAttr"); +} + Attribute mlir::convertToAttribute(MLIRContext *ctx, ArrayRef storage) { return DenseI64ArrayAttr::get(ctx, storage); -} +} \ No newline at end of file diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp index 2634245a4b7b1e3b7c41bd8e2b80673be59c3830..9b3885cc539d3eb4e65d568b31222096d106685a 100644 --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -246,6 +246,16 @@ Attribute mlir::detail::getDefaultEndianness(DataLayoutEntryInterface entry) { return entry.getValue(); } +// Returns the default memory space if specified in the given entry. If the +// entry is empty the default memory space represented by an empty attribute is +// returned. +Attribute mlir::detail::getDefaultMemorySpace(DataLayoutEntryInterface entry) { + if (!entry) + return Attribute(); + + return entry.getValue(); +} + // Returns the memory space used for alloca operations if specified in the // given entry. If the entry is empty the default memory space represented by // an empty attribute is returned. @@ -596,6 +606,23 @@ mlir::Attribute mlir::DataLayout::getEndianness() const { return *endianness; } +mlir::Attribute mlir::DataLayout::getDefaultMemorySpace() const { + checkValid(); + if (defaultMemorySpace) + return *defaultMemorySpace; + DataLayoutEntryInterface entry; + if (originalLayout) + entry = originalLayout.getSpecForIdentifier( + originalLayout.getDefaultMemorySpaceIdentifier( + originalLayout.getContext())); + if (auto iface = dyn_cast_or_null(scope)) + defaultMemorySpace = iface.getDefaultMemorySpace(entry); + else + defaultMemorySpace = detail::getDefaultMemorySpace(entry); + return *defaultMemorySpace; +} + + mlir::Attribute mlir::DataLayout::getAllocaMemorySpace() const { checkValid(); if (allocaMemorySpace) diff --git a/mlir/lib/TableGen/Property.cpp b/mlir/lib/TableGen/Property.cpp index e61d2fd2480fd5338eb33cf460fb7baa6aadbb2e..9f4b9ce1a294470d36d16b10c24a31c7c3cf1175 100644 --- a/mlir/lib/TableGen/Property.cpp +++ b/mlir/lib/TableGen/Property.cpp @@ -33,16 +33,23 @@ static StringRef getValueAsString(const Init *init) { } Property::Property(const Record *def) - : Property(getValueAsString(def->getValueInit("storageType")), - getValueAsString(def->getValueInit("interfaceType")), - getValueAsString(def->getValueInit("convertFromStorage")), - getValueAsString(def->getValueInit("assignToStorage")), - getValueAsString(def->getValueInit("convertToAttribute")), - getValueAsString(def->getValueInit("convertFromAttribute")), - getValueAsString(def->getValueInit("readFromMlirBytecode")), - getValueAsString(def->getValueInit("writeToMlirBytecode")), - getValueAsString(def->getValueInit("hashProperty")), - getValueAsString(def->getValueInit("defaultValue"))) { + : Property( + getValueAsString(def->getValueInit("summary")), + getValueAsString(def->getValueInit("description")), + getValueAsString(def->getValueInit("storageType")), + getValueAsString(def->getValueInit("interfaceType")), + getValueAsString(def->getValueInit("convertFromStorage")), + getValueAsString(def->getValueInit("assignToStorage")), + getValueAsString(def->getValueInit("convertToAttribute")), + getValueAsString(def->getValueInit("convertFromAttribute")), + getValueAsString(def->getValueInit("parser")), + getValueAsString(def->getValueInit("optionalParser")), + getValueAsString(def->getValueInit("printer")), + getValueAsString(def->getValueInit("readFromMlirBytecode")), + getValueAsString(def->getValueInit("writeToMlirBytecode")), + getValueAsString(def->getValueInit("hashProperty")), + getValueAsString(def->getValueInit("defaultValue")), + getValueAsString(def->getValueInit("storageTypeValueOverride"))) { this->def = def; assert((def->isSubClassOf("Property") || def->isSubClassOf("Attr")) && "must be subclass of TableGen 'Property' class"); @@ -50,22 +57,44 @@ Property::Property(const Record *def) Property::Property(const DefInit *init) : Property(init->getDef()) {} -Property::Property(StringRef storageType, StringRef interfaceType, +Property::Property(StringRef summary, StringRef description, + StringRef storageType, StringRef interfaceType, StringRef convertFromStorageCall, StringRef assignToStorageCall, StringRef convertToAttributeCall, - StringRef convertFromAttributeCall, + StringRef convertFromAttributeCall, StringRef parserCall, + StringRef optionalParserCall, StringRef printerCall, StringRef readFromMlirBytecodeCall, StringRef writeToMlirBytecodeCall, - StringRef hashPropertyCall, StringRef defaultValue) - : storageType(storageType), interfaceType(interfaceType), + StringRef hashPropertyCall, StringRef defaultValue, + StringRef storageTypeValueOverride) + : summary(summary), description(description), storageType(storageType), + interfaceType(interfaceType), convertFromStorageCall(convertFromStorageCall), assignToStorageCall(assignToStorageCall), convertToAttributeCall(convertToAttributeCall), convertFromAttributeCall(convertFromAttributeCall), + parserCall(parserCall), optionalParserCall(optionalParserCall), + printerCall(printerCall), readFromMlirBytecodeCall(readFromMlirBytecodeCall), writeToMlirBytecodeCall(writeToMlirBytecodeCall), - hashPropertyCall(hashPropertyCall), defaultValue(defaultValue) { + hashPropertyCall(hashPropertyCall), defaultValue(defaultValue), + storageTypeValueOverride(storageTypeValueOverride) { if (storageType.empty()) storageType = "Property"; } + +StringRef Property::getPropertyDefName() const { + if (def->isAnonymous()) { + return getBaseProperty().def->getName(); + } + return def->getName(); +} + +Property Property::getBaseProperty() const { + if (const auto *defInit = + llvm::dyn_cast(def->getValueInit("baseProperty"))) { + return Property(defInit).getBaseProperty(); + } + return *this; +} \ No newline at end of file diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index fc3fb0b5334c144d16aa462b996f039ab88571c8..666844d16fcb17de8e7d4445082ef6b6f358dccf 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1224,6 +1224,19 @@ static LogicalResult checkedAddLLVMFnAttribute(Location loc, llvmFunc->addFnAttr(key, value); return success(); } + if (kind == llvm::Attribute::VScaleRange) { + llvm::AttrBuilder attr_builder(llvmFunc->getContext()); + int result; + if (!value.getAsInteger(/*Radix=*/0, result)) + attr_builder.addVScaleRangeAttr(result, std::nullopt); + else + return emitError(loc) + << "LLVM attribute 'vscale_range' expects an integer value"; + + llvmFunc->addFnAttrs(attr_builder); + return success(); + } + if (llvm::Attribute::isIntAttrKind(kind)) { if (value.empty()) diff --git a/mlir/lib/Tools/mlir-opt/CMakeLists.txt b/mlir/lib/Tools/mlir-opt/CMakeLists.txt index f24d4c60174eeca225e81ca3d99369b916f1989f..dbc5ef973c566f8ec9e630a30e55111ed2753598 100644 --- a/mlir/lib/Tools/mlir-opt/CMakeLists.txt +++ b/mlir/lib/Tools/mlir-opt/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_library(MLIROptLib ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-opt LINK_LIBS PUBLIC + MLIRArmSVEDialect MLIRBytecodeWriter MLIRDebug MLIRObservers diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 055256903a152289b89324e2684b4c63fc87abe6..3e7a0cca31c772443907ed5b37b1dacf72a762b8 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -72,15 +72,50 @@ using namespace mlir::dataflow; namespace { +// Set of structures below to be filled with operations and arguments to erase. +// This is done to separate analysis and tree modification phases, +// otherwise analysis is operating on half-deleted tree which is incorrect. + +struct FunctionToCleanUp { + FunctionOpInterface funcOp; + BitVector nonLiveArgs; + BitVector nonLiveRets; +}; + +struct OperationToCleanup { + Operation *op; + BitVector nonLive; +}; + +struct BlockArgsToCleanup { + Block *b; + BitVector nonLiveArgs; +}; + +struct SuccessorOperandsToCleanup { + BranchOpInterface branch; + unsigned successorIndex; + BitVector nonLiveOperands; +}; + +struct RDVFinalCleanupList { + SmallVector operations; + SmallVector values; + SmallVector functions; + SmallVector operands; + SmallVector results; + SmallVector blocks; + SmallVector successorOperands; +}; + // Some helper functions... /// Return true iff at least one value in `values` is live, given the liveness /// information in `la`. -static bool hasLive(ValueRange values, RunLivenessAnalysis &la) { +static bool hasLive(ValueRange values, const DenseSet &nonLiveSet, + RunLivenessAnalysis &la) { for (Value value : values) { - // If there is a null value, it implies that it was dropped during the - // execution of this pass, implying that it was non-live. - if (!value) + if (nonLiveSet.contains(value)) continue; const Liveness *liveness = la.getLiveness(value); @@ -92,11 +127,12 @@ static bool hasLive(ValueRange values, RunLivenessAnalysis &la) { /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the /// i-th value in `values` is live, given the liveness information in `la`. -static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) { +static BitVector markLives(ValueRange values, const DenseSet &nonLiveSet, + RunLivenessAnalysis &la) { BitVector lives(values.size(), true); for (auto [index, value] : llvm::enumerate(values)) { - if (!value) { + if (nonLiveSet.contains(value)) { lives.reset(index); continue; } @@ -115,6 +151,18 @@ static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) { return lives; } +/// Collects values marked as "non-live" in the provided range and inserts them +/// into the nonLiveSet. A value is considered "non-live" if the corresponding +/// index in the `nonLive` bit vector is set. +static void collectNonLiveValues(DenseSet &nonLiveSet, ValueRange range, + const BitVector &nonLive) { + for (auto [index, result] : llvm::enumerate(range)) { + if (!nonLive[index]) + continue; + nonLiveSet.insert(result); + } +} + /// Drop the uses of the i-th result of `op` and then erase it iff toErase[i] /// is 1. static void dropUsesAndEraseResults(Operation *op, BitVector toErase) { @@ -165,52 +213,59 @@ static SmallVector operandsToOpOperands(OperandRange operands) { return opOperands; } -/// Clean a simple op `op`, given the liveness analysis information in `la`. -/// Here, cleaning means: -/// (1) Dropping all its uses, AND -/// (2) Erasing it -/// iff it has no memory effects and none of its results are live. +/// Process a simple operation `op` using the liveness analysis `la`. +/// If the operation has no memory effects and none of its results are live: +/// 1. Add the operation to a list for future removal, and +/// 2. Mark all its results as non-live values /// -/// It is assumed that `op` is simple. Here, a simple op is one which isn't a -/// symbol op, a symbol-user op, a region branch op, a branch op, a region -/// branch terminator op, or return-like. -static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) { - if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la)) +/// The operation `op` is assumed to be simple. A simple operation is one that +/// is NOT: +/// - Function-like +/// - Call-like +/// - A region branch operation +/// - A branch operation +/// - A region branch terminator +/// - Return-like +static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, + DenseSet &nonLiveSet, + RDVFinalCleanupList &cl) { + if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) return; - op->dropAllUses(); - op->erase(); + cl.operations.push_back(op); + collectNonLiveValues(nonLiveSet, op->getResults(), + BitVector(op->getNumResults(), true)); } -/// Clean a function-like op `funcOp`, given the liveness information in `la` -/// and the IR in `module`. Here, cleaning means: -/// (1) Dropping the uses of its unnecessary (non-live) arguments, -/// (2) Erasing these arguments, -/// (3) Erasing their corresponding operands from its callers, -/// (4) Erasing its unnecessary terminator operands (return values that are -/// non-live across all callers), -/// (5) Dropping the uses of these return values from its callers, AND -/// (6) Erasing these return values -/// iff it is not public. -static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module, - RunLivenessAnalysis &la) { - if (funcOp.isPublic()) +/// Process a function-like operation `funcOp` using the liveness analysis `la` +/// and the IR in `module`. If it is not public or external: +/// (1) Adding its non-live arguments to a list for future removal. +/// (2) Marking their corresponding operands in its callers for removal. +/// (3) Identifying and enqueueing unnecessary terminator operands +/// (return values that are non-live across all callers) for removal. +/// (4) Enqueueing the non-live arguments and return values for removal. +/// (5) Collecting the uses of these return values in its callers for future +/// removal. +/// (6) Marking all its results as non-live values. +static void processFuncOp(FunctionOpInterface funcOp, Operation *module, + RunLivenessAnalysis &la, DenseSet &nonLiveSet, + RDVFinalCleanupList &cl) { + if (funcOp.isPublic() || funcOp.isExternal()) return; // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`. SmallVector arguments(funcOp.getArguments()); - BitVector nonLiveArgs = markLives(arguments, la); + BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la); nonLiveArgs = nonLiveArgs.flip(); // Do (1). for (auto [index, arg] : llvm::enumerate(arguments)) - if (arg && nonLiveArgs[index]) - arg.dropAllUses(); + if (arg && nonLiveArgs[index]) { + cl.values.push_back(arg); + nonLiveSet.insert(arg); + } // Do (2). - funcOp.eraseArguments(nonLiveArgs); - - // Do (3). SymbolTable::UseRange uses = *funcOp.getSymbolUses(module); for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); @@ -222,9 +277,10 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module, operandsToOpOperands(cast(callOp).getArgOperands()); for (int index : nonLiveArgs.set_bits()) nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber()); - callOp->eraseOperands(nonLiveCallOperands); + cl.operands.push_back({callOp, nonLiveCallOperands}); } + // Do (3). // Get the list of unnecessary terminator operands (return values that are // non-live across all callers) in `nonLiveRets`. There is a very important // subtlety here. Unnecessary terminator operands are NOT the operands of the @@ -253,62 +309,74 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module, for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); assert(isa(callOp) && "expected a call-like user"); - BitVector liveCallRets = markLives(callOp->getResults(), la); + BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la); nonLiveRets &= liveCallRets.flip(); } - // Do (4). // Note that in the absence of control flow ops forcing the control to go from // the entry (first) block to the other blocks, the control never reaches any // block other than the entry block, because every block has a terminator. for (Block &block : funcOp.getBlocks()) { Operation *returnOp = block.getTerminator(); if (returnOp && returnOp->getNumOperands() == numReturns) - returnOp->eraseOperands(nonLiveRets); + cl.operands.push_back({returnOp, nonLiveRets}); } - funcOp.eraseResults(nonLiveRets); + + // Do (4). + cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets}); // Do (5) and (6). for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); assert(isa(callOp) && "expected a call-like user"); - dropUsesAndEraseResults(callOp, nonLiveRets); + cl.results.push_back({callOp, nonLiveRets}); + collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets); } } -/// Clean a region branch op `regionBranchOp`, given the liveness information in -/// `la`. Here, cleaning means: -/// (1') Dropping all its uses, AND -/// (2') Erasing it -/// if it has no memory effects and none of its results are live, AND -/// (1) Erasing its unnecessary operands (operands that are forwarded to -/// unneccesary results and arguments), -/// (2) Cleaning each of its regions, -/// (3) Dropping the uses of its unnecessary results (results that are -/// forwarded from unnecessary operands and terminator operands), AND -/// (4) Erasing these results -/// otherwise. -/// Note that here, cleaning a region means: -/// (2.a) Dropping the uses of its unnecessary arguments (arguments that are -/// forwarded from unneccesary operands and terminator operands), -/// (2.b) Erasing these arguments, AND -/// (2.c) Erasing its unnecessary terminator operands (terminator operands -/// that are forwarded to unneccesary results and arguments). -/// It is important to note that values in this op flow from operands and -/// terminator operands (successor operands) to arguments and results (successor -/// inputs). -static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp, - RunLivenessAnalysis &la) { +/// Process a region branch operation `regionBranchOp` using the liveness +/// information in `la`. The processing involves two scenarios: +/// +/// Scenario 1: If the operation has no memory effects and none of its results +/// are live: +/// (1') Enqueue all its uses for deletion. +/// (2') Enqueue the branch itself for deletion. +/// +/// Scenario 2: Otherwise: +/// (1) Collect its unnecessary operands (operands forwarded to unnecessary +/// results or arguments). +/// (2) Process each of its regions. +/// (3) Collect the uses of its unnecessary results (results forwarded from +/// unnecessary operands +/// or terminator operands). +/// (4) Add these results to the deletion list. +/// +/// Processing a region includes: +/// (a) Collecting the uses of its unnecessary arguments (arguments forwarded +/// from unnecessary operands +/// or terminator operands). +/// (b) Collecting these unnecessary arguments. +/// (c) Collecting its unnecessary terminator operands (terminator operands +/// forwarded to unnecessary results +/// or arguments). +/// +/// Value Flow Note: In this operation, values flow as follows: +/// - From operands and terminator operands (successor operands) +/// - To arguments and results (successor inputs). +static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, + RunLivenessAnalysis &la, + DenseSet &nonLiveSet, + RDVFinalCleanupList &cl) { // Mark live results of `regionBranchOp` in `liveResults`. auto markLiveResults = [&](BitVector &liveResults) { - liveResults = markLives(regionBranchOp->getResults(), la); + liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la); }; // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`. auto markLiveArgs = [&](DenseMap &liveArgs) { for (Region ®ion : regionBranchOp->getRegions()) { SmallVector arguments(region.front().getArguments()); - BitVector regionLiveArgs = markLives(arguments, la); + BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la); liveArgs[®ion] = regionLiveArgs; } }; @@ -491,18 +559,19 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp, } }; - // Do (1') and (2'). This is the only case where the entire `regionBranchOp` + // Scenario 1. This is the only case where the entire `regionBranchOp` // is removed. It will not happen in any other scenario. Note that in this // case, a non-forwarded operand of `regionBranchOp` could be live/non-live. // It could never be live because of this op but its liveness could have been // attributed to something else. + // Do (1') and (2'). if (isMemoryEffectFree(regionBranchOp.getOperation()) && - !hasLive(regionBranchOp->getResults(), la)) { - regionBranchOp->dropAllUses(); - regionBranchOp->erase(); + !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) { + cl.operations.push_back(regionBranchOp.getOperation()); return; } + // Scenario 2. // At this point, we know that every non-forwarded operand of `regionBranchOp` // is live. @@ -538,29 +607,127 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp, terminatorOperandsToKeep); // Do (1). - regionBranchOp->eraseOperands(operandsToKeep.flip()); + cl.operands.push_back({regionBranchOp, operandsToKeep.flip()}); // Do (2.a) and (2.b). for (Region ®ion : regionBranchOp->getRegions()) { assert(!region.empty() && "expected a non-empty region in an op " "implementing `RegionBranchOpInterface`"); - for (auto [index, arg] : llvm::enumerate(region.front().getArguments())) { - if (argsToKeep[®ion][index]) - continue; - if (arg) - arg.dropAllUses(); - } - region.front().eraseArguments(argsToKeep[®ion].flip()); + BitVector argsToRemove = argsToKeep[®ion].flip(); + cl.blocks.push_back({®ion.front(), argsToRemove}); + collectNonLiveValues(nonLiveSet, region.front().getArguments(), + argsToRemove); } // Do (2.c). for (Region ®ion : regionBranchOp->getRegions()) { Operation *terminator = region.front().getTerminator(); - terminator->eraseOperands(terminatorOperandsToKeep[terminator].flip()); + cl.operands.push_back( + {terminator, terminatorOperandsToKeep[terminator].flip()}); } // Do (3) and (4). - dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip()); + BitVector resultsToRemove = resultsToKeep.flip(); + collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(), + resultsToRemove); + cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove}); +} + +/// Steps to process a `BranchOpInterface` operation: +/// Iterate through each successor block of `branchOp`. +/// (1) For each successor block, gather all operands from all successors. +/// (2) Fetch their associated liveness analysis data and collect for future +/// removal. +/// (3) Identify and collect the dead operands from the successor block +/// as well as their corresponding arguments. + +static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, + DenseSet &nonLiveSet, + RDVFinalCleanupList &cl) { + unsigned numSuccessors = branchOp->getNumSuccessors(); + + for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { + Block *successorBlock = branchOp->getSuccessor(succIdx); + + // Do (1) + SuccessorOperands successorOperands = + branchOp.getSuccessorOperands(succIdx); + SmallVector operandValues; + for (unsigned operandIdx = 0; operandIdx < successorOperands.size(); + ++operandIdx) { + operandValues.push_back(successorOperands[operandIdx]); + } + + // Do (2) + BitVector successorNonLive = + markLives(operandValues, nonLiveSet, la).flip(); + collectNonLiveValues(nonLiveSet, successorBlock->getArguments(), + successorNonLive); + + // Do (3) + cl.blocks.push_back({successorBlock, successorNonLive}); + cl.successorOperands.push_back({branchOp, succIdx, successorNonLive}); + } +} + +/// Removes dead values collected in RDVFinalCleanupList. +/// To be run once when all dead values have been collected. +static void cleanUpDeadVals(RDVFinalCleanupList &list) { + // 1. Operations + for (auto &op : list.operations) { + op->dropAllUses(); + op->erase(); + } + + // 2. Values + for (auto &v : list.values) { + v.dropAllUses(); + } + + // 3. Functions + for (auto &f : list.functions) { + f.funcOp.eraseArguments(f.nonLiveArgs); + f.funcOp.eraseResults(f.nonLiveRets); + } + + // 4. Operands + for (auto &o : list.operands) { + o.op->eraseOperands(o.nonLive); + } + + // 5. Results + for (auto &r : list.results) { + dropUsesAndEraseResults(r.op, r.nonLive); + } + + // 6. Blocks + for (auto &b : list.blocks) { + // blocks that are accessed via multiple codepaths processed once + if (b.b->getNumArguments() != b.nonLiveArgs.size()) + continue; + // it iterates backwards because erase invalidates all successor indexes + for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { + if (!b.nonLiveArgs[i]) + continue; + b.b->getArgument(i).dropAllUses(); + b.b->eraseArgument(i); + } + } + + // 7. Successor Operands + for (auto &op : list.successorOperands) { + SuccessorOperands successorOperands = + op.branch.getSuccessorOperands(op.successorIndex); + // blocks that are accessed via multiple codepaths processed once + if (successorOperands.size() != op.nonLiveOperands.size()) + continue; + // it iterates backwards because erase invalidates all successor indexes + for (int i = successorOperands.size() - 1; i >= 0; --i) { + if (!op.nonLiveOperands[i]) + continue; + successorOperands.erase(i); + } + } } struct RemoveDeadValues : public impl::RemoveDeadValuesBase { @@ -572,28 +739,21 @@ void RemoveDeadValues::runOnOperation() { auto &la = getAnalysis(); Operation *module = getOperation(); - // The removal of non-live values is performed iff there are no branch ops, - // all symbol ops present in the IR are function-like, and all symbol user ops - // present in the IR are call-like. - WalkResult acceptableIR = module->walk([&](Operation *op) { - if (isa(op) || - (isa(op) && !isa(op)) || - (isa(op) && !isa(op))) { - op->emitError() << "cannot optimize an IR with non-function symbol ops, " - "non-call symbol user ops or branch ops\n"; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); + // Tracks values eligible for erasure - complements liveness analysis to + // identify "droppable" values. + DenseSet deadVals; - if (acceptableIR.wasInterrupted()) - return; + // Maintains a list of Ops, values, branches, etc., slated for cleanup at the + // end of this pass. + RDVFinalCleanupList finalCleanupList; module->walk([&](Operation *op) { if (auto funcOp = dyn_cast(op)) { - cleanFuncOp(funcOp, module, la); + processFuncOp(funcOp, module, la, deadVals, finalCleanupList); } else if (auto regionBranchOp = dyn_cast(op)) { - cleanRegionBranchOp(regionBranchOp, la); + processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList); + } else if (auto branchOp = dyn_cast(op)) { + processBranchOp(branchOp, la, deadVals, finalCleanupList); } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) { // Nothing to do here because this is a terminator op and it should be // honored with respect to its parent @@ -601,9 +761,11 @@ void RemoveDeadValues::runOnOperation() { // Nothing to do because this op is associated with a function op and gets // cleaned when the latter is cleaned. } else { - cleanSimpleOp(op, la); + processSimpleOp(op, la, deadVals, finalCleanupList); } }); + + cleanUpDeadVals(finalCleanupList); } std::unique_ptr mlir::createRemoveDeadValuesPass() { diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index c310954b906e4e500cec11a90c5115dcd32659d6..20a74a235b1551030a4c75155a68d85cb1ff347d 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -400,16 +400,16 @@ func.func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: v // CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>> // CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[T5:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T4]] : i64] : vector<2xf32> +// CHECK: %[[T8:.*]] = llvm.extractvalue %[[T7]][0] : !llvm.array<2 x vector<3xf32>> // CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]] // CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]] -// CHECK: %[[T8:.*]] = llvm.extractvalue %[[T7]][0] : !llvm.array<2 x vector<3xf32>> // CHECK: %[[T9:.*]] = llvm.intr.fmuladd(%[[T6]], %[[B]], %[[T8]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32> // CHECK: %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][0] : !llvm.array<2 x vector<3xf32>> // CHECK: %[[T12:.*]] = llvm.mlir.constant(1 : i64) : i64 // CHECK: %[[T13:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T12]] : i64] : vector<2xf32> +// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T7]][1] : !llvm.array<2 x vector<3xf32>> // CHECK: %[[T14Insert:.*]] = llvm.insertelement %[[T13]] // CHECK: %[[T14:.*]] = llvm.shufflevector %[[T14Insert]] -// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T7]][1] : !llvm.array<2 x vector<3xf32>> // CHECK: %[[T17:.*]] = llvm.intr.fmuladd(%[[T14]], %[[B]], %[[T16]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32> // CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T11]][1] : !llvm.array<2 x vector<3xf32>> // CHECK: %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32> diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir index 174eb468cc004108d9e6cb428e54502c324600f0..b4d70abf0843a9d08d99d01f0ee2db344aa90037 100644 --- a/mlir/test/Dialect/Arith/expand-ops.mlir +++ b/mlir/test/Dialect/Arith/expand-ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s +// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e5m2=true include-f8e4m3fn=true" -split-input-file | FileCheck %s // Test ceil divide with signed integer // CHECK-LABEL: func @ceildivi @@ -310,3 +310,43 @@ func.func @minui(%a: i32, %b: i32) -> i32 { // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: return %[[RESULT]] : i32 + +// ----- + +func.func @extf_vector_f8E5M2_to_f32(%arg0 : vector<4xf8E5M2>) -> vector<4xf32> { + %0 = arith.extf %arg0 : vector<4xf8E5M2> to vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: @extf_vector_f8E5M2_to_f32 +// CHECK-NOT: arith.extf + +// ----- + +func.func @truncf_vector_f32_to_f8E5M2(%arg0 : vector<4xf32>) -> vector<4xf8E5M2> { + %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E5M2> + return %0 : vector<4xf8E5M2> +} + +// CHECK-LABEL: @truncf_vector_f32_to_f8E5M2 +// CHECK-NOT: arith.truncf + +// ----- + +func.func @extf_vector_f8E4M3FN_to_f32(%arg0 : vector<4xf8E4M3FN>) -> vector<4xf32> { + %0 = arith.extf %arg0 : vector<4xf8E4M3FN> to vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: @extf_vector_f8E4M3FN_to_f32 +// CHECK-NOT: arith.extf + +// ----- + +func.func @truncf_vector_f32_to_f8E4M3FN(%arg0 : vector<4xf32>) -> vector<4xf8E4M3FN> { + %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E4M3FN> + return %0 : vector<4xf8E4M3FN> +} + +// CHECK-LABEL: @truncf_vector_f32_to_f8E4M3FN +// CHECK-NOT: arith.truncf diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 146e9780b8ebbe552f1a0969ed6f7e06c16c04cf..802de7c335d9b17e59cc93685f7095ec77460282 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -485,6 +485,48 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>, // ----- +func.func @reduce_asymmetric(%input: tensor<16x32x64xi32>, %input2: tensor<16x32x64xi32>, + %init: tensor<16x64xi32>) -> tensor<16x64xi32> { + %reduce = linalg.reduce + ins(%input, %input2:tensor<16x32x64xi32>, tensor<16x32x64xi32>) + outs(%init:tensor<16x64xi32>) + dimensions = [1] + (%in: i32, %in2: i32, %out: i32) { + %0 = arith.muli %in, %in2: i32 + %1 = arith.addi %out, %0: i32 + linalg.yield %1: i32 + } + func.return %reduce : tensor<16x64xi32> +} +// CHECK-LABEL: func @reduce_asymmetric +// CHECK: linalg.reduce ins(%{{.*}}, %{{.*}}: tensor<16x32x64xi32>, tensor<16x32x64xi32>) +// CHECK-NOT: operandSegmentSize +// CHECK-SAME: outs(%{{.*}}: tensor<16x64xi32>) +// CHECK-SAME: dimensions = [1] + +// ----- + +func.func @reduce_asymmetric_memref(%input: memref<16x32x64xi32>, %input2: memref<16x32x64xi32>, + %init: memref<16x64xi32>) { + linalg.reduce + ins(%input, %input2:memref<16x32x64xi32>, memref<16x32x64xi32>) + outs(%init:memref<16x64xi32>) + dimensions = [1] + (%in: i32, %in2: i32, %out: i32) { + %0 = arith.muli %in, %in2: i32 + %1 = arith.addi %out, %0: i32 + linalg.yield %1: i32 + } + func.return +} +// CHECK-LABEL: func @reduce_asymmetric_memref +// CHECK: linalg.reduce ins(%{{.*}}, %{{.*}}: memref<16x32x64xi32>, memref<16x32x64xi32>) +// CHECK-NOT: operandSegmentSize +// CHECK-SAME: outs(%{{.*}}: memref<16x64xi32>) +// CHECK-SAME: dimensions = [1] + +// ----- + func.func @transpose(%input: tensor<16x32x64xf32>, %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> { %transpose = linalg.transpose diff --git a/mlir/test/Dialect/Vector/lower-vectors-sve-enabled.mlir b/mlir/test/Dialect/Vector/lower-vectors-sve-enabled.mlir new file mode 100644 index 0000000000000000000000000000000000000000..10154601402a9aa2de67795677bb0aea9badc4f5 --- /dev/null +++ b/mlir/test/Dialect/Vector/lower-vectors-sve-enabled.mlir @@ -0,0 +1,77 @@ +// RUN: mlir-opt %s --transform-interpreter --canonicalize | FileCheck %s + +// CHECK-LABEL: func @outerproduct4x8 +// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<8xf32>, %[[C:.*]]: vector<4x8xf32> +// CHECK-DAG: %[[C0:.*]] = vector.extract %[[C]][0] : vector<8xf32> from vector<4x8xf32 +// CHECK-DAG: %[[A0:.*]] = vector.scalable.insert %[[A]]{{.*}} into vector<[4]xf32> +// CHECK-DAG: %[[A1:.*]] = llvm.call_intrinsic "llvm.aarch64.sve.dupq.lane"(%[[A0]], %c0_i64) +// CHECK-DAG: %[[C1:.*]] = vector.scalable.insert %[[C0]],{{.*}} into vector<[4]xf32> +// CHECK-DAG: %[[B0:.*]] = vector.scalable.insert %[[B]],{{.*}} into vector<[4]xf32> +// CHECK: %[[FMLA0:.*]] = llvm.call_intrinsic "llvm.aarch64.sve.fmla.lane"(%[[C1]], %[[B0]], %[[A1]], %c0_i32){{.*}} vector<[4]xf32> +// CHECK-NEXT: %[[R0:.*]] = vector.scalable.extract %[[FMLA0]][0] : vector<8xf32> from vector<[4]xf32> +// CHECK-NEXT: %[[I0:.*]] = vector.insert %[[R0]], %cst [0] : vector<8xf32> into vector<4x8xf32> +// CHECK-DAG: %[[A2:.*]] = vector.scalable.insert %[[A]]{{.*}} into vector<[4]xf32> +// CHECK-DAG: %[[A3:.*]] = llvm.call_intrinsic "llvm.aarch64.sve.dupq.lane"(%[[A2]], %c0_i64) +// CHECK-DAG: %[[C2:.*]] = vector.extract %[[C]][1] : vector<8xf32> from vector<4x8xf32> +// CHECK-DAG: %[[C3:.*]] = vector.scalable.insert %[[C2]],{{.*}} into vector<[4]xf32> +// CHECK-DAG: %[[B1:.*]] = vector.scalable.insert %[[B]],{{.*}} into vector<[4]xf32> +// CHECK: %[[FMLA1:.*]] = llvm.call_intrinsic "llvm.aarch64.sve.fmla.lane"(%[[C3]], %[[B1]], %[[A3]], %c1_i32){{.*}} vector<[4]xf32> +// CHECK-NEXT: %[[R1:.*]] = vector.scalable.extract %[[FMLA1]][0] : vector<8xf32> from vector<[4]xf32> +// CHECK-NEXT: %[[I1:.*]] = vector.insert %[[R1]], %[[I0]] [1] : vector<8xf32> into vector<4x8xf32> +// CHECK-DAG: %[[A4:.*]] = vector.scalable.insert %[[A]]{{.*}} into vector<[4]xf32> +// CHECK-DAG: %[[A5:.*]] = llvm.call_intrinsic "llvm.aarch64.sve.dupq.lane"(%[[A4]], %c0_i64) +// CHECK-DAG: %[[C4:.*]] = vector.extract %[[C]][2] : vector<8xf32> from vector<4x8xf32> +// CHECK-DAG: %[[C5:.*]] = vector.scalable.insert %[[C4]],{{.*}} into vector<[4]xf32> +// CHECK-DAG: %[[B2:.*]] = vector.scalable.insert %[[B]],{{.*}} into vector<[4]xf32> +// CHECK: %[[FMLA2:.*]] = llvm.call_intrinsic "llvm.aarch64.sve.fmla.lane"(%[[C5]], %[[B2]], %[[A5]], %c2_i32){{.*}} vector<[4]xf32> +// CHECK-NEXT: %[[R2:.*]] = vector.scalable.extract %[[FMLA2]][0] : vector<8xf32> from vector<[4]xf32> +// CHECK-NEXT: %[[I2:.*]] = vector.insert %[[R2]], %[[I1]] [2] : vector<8xf32> into vector<4x8xf32> +// CHECK-DAG: %[[A6:.*]] = vector.scalable.insert %[[A]]{{.*}} into vector<[4]xf32> +// CHECK-DAG: %[[A7:.*]] = llvm.call_intrinsic "llvm.aarch64.sve.dupq.lane"(%[[A6]], %c0_i64) +// CHECK-DAG: %[[C6:.*]] = vector.extract %[[C]][3] : vector<8xf32> from vector<4x8xf32> +// CHECK-DAG: %[[C7:.*]] = vector.scalable.insert %[[C6]],{{.*}} into vector<[4]xf32> +// CHECK-DAG: %[[B3:.*]] = vector.scalable.insert %[[B]],{{.*}} into vector<[4]xf32> +// CHECK: %[[FMLA3:.*]] = llvm.call_intrinsic "llvm.aarch64.sve.fmla.lane"(%[[C7]], %[[B3]], %[[A7]], %c3_i32){{.*}} vector<[4]xf32> +// CHECK-NEXT: %[[R3:.*]] = vector.scalable.extract %[[FMLA3]][0] : vector<8xf32> from vector<[4]xf32> +// CHECK-NEXT: %[[I3:.*]] = vector.insert %[[R3]], %[[I2]] [3] : vector<8xf32> into vector<4x8xf32> +// CHECK: return %[[I3]] +func.func @outerproduct4x8(%arg0 : vector<4xf32>, %arg1 : vector<8xf32>, %arg2 : vector<4x8xf32>) ->vector<4x8xf32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind} : vector<4xf32>, vector<8xf32> + return %0 : vector<4x8xf32> +} + +// CHECK-LABEL: func @outerproduct2x8 +// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<8xf32>, %[[C:.*]]: vector<2x8xf32> +// CHECK-DAG: %[[C0:.*]] = vector.extract %[[C]][0] : vector<8xf32> from vector<2x8xf32> +// CHECK-DAG: %[[A0:.*]] = vector.scalable.insert %[[A]]{{.*}} into vector<[4]xf32> +// CHECK-DAG: %[[A1:.*]] = llvm.call_intrinsic "llvm.aarch64.sve.dupq.lane"(%[[A0]], %c0_i64) +// CHECK-DAG: %[[C1:.*]] = vector.scalable.insert %[[C0]],{{.*}} into vector<[4]xf32> +// CHECK-DAG: %[[B0:.*]] = vector.scalable.insert %[[B]],{{.*}} into vector<[4]xf32> +// CHECK: %[[FMLA0:.*]] = llvm.call_intrinsic "llvm.aarch64.sve.fmla.lane"(%[[C1]], %[[B0]], %[[A1]], %c0_i32){{.*}} vector<[4]xf32> +// CHECK-NEXT: %[[R0:.*]] = vector.scalable.extract %[[FMLA0]][0] : vector<8xf32> from vector<[4]xf32> +// CHECK-NEXT: %[[I0:.*]] = vector.insert %[[R0]], %cst [0] : vector<8xf32> into vector<2x8xf32> +// CHECK-DAG: %[[A2:.*]] = vector.scalable.insert %[[A]]{{.*}} into vector<[4]xf32> +// CHECK-DAG: %[[A3:.*]] = llvm.call_intrinsic "llvm.aarch64.sve.dupq.lane"(%[[A2]], %c0_i64) +// CHECK-DAG: %[[C2:.*]] = vector.extract %[[C]][1] : vector<8xf32> from vector<2x8xf32> +// CHECK-DAG: %[[C3:.*]] = vector.scalable.insert %[[C2]],{{.*}} into vector<[4]xf32> +// CHECK-DAG: %[[B1:.*]] = vector.scalable.insert %[[B]],{{.*}} into vector<[4]xf32> +// CHECK: %[[FMLA1:.*]] = llvm.call_intrinsic "llvm.aarch64.sve.fmla.lane"(%[[C3]], %[[B1]], %[[A3]], %c1_i32){{.*}} vector<[4]xf32> +// CHECK-NEXT: %[[R1:.*]] = vector.scalable.extract %[[FMLA1]][0] : vector<8xf32> from vector<[4]xf32> +// CHECK-NEXT: %[[I1:.*]] = vector.insert %[[R1]], %[[I0]] [1] : vector<8xf32> into vector<2x8xf32> +// CHECK: return %[[I1]] +func.func @outerproduct2x8(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>, %arg2 : vector<2x8xf32>) ->vector<2x8xf32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind} : vector<2xf32>, vector<8xf32> + return %0 : vector<2x8xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %f = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + + transform.apply_patterns to %f { + transform.apply_patterns.vector.lower_outerproduct enableSVE = true + } : !transform.any_op + transform.yield + } +} \ No newline at end of file diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir index 059d955f773131ea63a432a2eb18b4201e508287..e5b600f54c0d8398b0e5cb851ce252c5e5946cae 100644 --- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir @@ -26,13 +26,13 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>, // CHECK-SAME: %[[C:.*2]]: vector<2x3xf32> // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32> +// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> // CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32> // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32> // CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32> +// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32> // CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32> // CHECK: return %[[T9]] : vector<2x3xf32> @@ -69,14 +69,14 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>, // CHECK-SAME: %[[C:.*2]]: vector<2x3xi32> // CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32> +// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> // CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> // CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32> // CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32> -// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32> // CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32> +// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32> // CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32> // CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32> // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32> diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul2.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul2.mlir new file mode 100644 index 0000000000000000000000000000000000000000..8ac11c9d350de56a72878eccea760cd27845ae7e --- /dev/null +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul2.mlir @@ -0,0 +1,62 @@ +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + + %tiled_linalg_op, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [[4], [4], 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + + transform.structured.vectorize %tiled_linalg_op vector_sizes [[4], [4], 1] : !transform.any_op + + %1 = transform.bufferization.one_shot_bufferize %arg0 {bufferize_function_boundaries = true} : (!transform.any_op) -> !transform.any_op + + %2 = transform.structured.match ops{["func.func"]} in %1 : (!transform.any_op) -> !transform.any_op + + %3 = transform.apply_registered_pass "convert-linalg-to-loops" to %2 : (!transform.any_op) -> !transform.op<"func.func"> + + transform.apply_patterns to %3 { + transform.apply_patterns.vector.lower_masked_transfers + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.reduction_to_contract + } : !transform.op<"func.func"> + + transform.apply_patterns to %3 { + transform.apply_patterns.vector.cast_away_vector_leading_one_dim + transform.apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers + transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" + transform.apply_patterns.vector.lower_masks + transform.apply_patterns.canonicalization + } : !transform.op<"func.func"> + + %5 = transform.structured.match interface{LoopLikeInterface} in %1 : (!transform.any_op) -> !transform.any_op + + transform.apply_licm to %5 : !transform.any_op + + transform.loop.hoist_loop_invariant_subsets %5 : !transform.any_op + + transform.yield + } + transform.named_sequence @arm_sme_lowering_schedule(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.lower_to_arm_sme %arg0 : !transform.any_op + transform.yield %arg0 : !transform.any_op + } + transform.named_sequence @lower_to_llvm_schedule(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.lower_to_llvm_new %arg0 {enable_arm_sve = true, enable_index_optimizations = true, vscale_range = 0 : i64} : !transform.any_op + transform.yield %arg0 : !transform.any_op + } + transform.named_sequence @__transform_main_next(%arg0: !transform.any_op) { + %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {deduplicate} : (!transform.any_op) -> !transform.any_op + %2 = transform.include @arm_sme_lowering_schedule failures(propagate) (%1) : (!transform.any_op) -> !transform.any_op + %3 = transform.apply_registered_pass "cse" to %2 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %3 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + } : !transform.any_op + %4 = transform.structured.match interface{LoopLikeInterface} in %3 : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %4 : !transform.any_op + %5 = transform.structured.match ops{["func.func"]} in %3 : (!transform.any_op) -> !transform.any_op + %6 = transform.structured.hoist_redundant_vector_transfers %5 : (!transform.any_op) -> !transform.any_op + %7 = transform.structured.hoist_redundant_vector_broadcasts %6 : (!transform.any_op) -> !transform.any_op + %8 = transform.apply_registered_pass "canonicalize" to %7 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 69426fdb6208329e9b1e1ffecd371d1a56eafc19..fe7bcbc7c490b6f7b2eb9d1496f5d0a4047ffa09 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -1,34 +1,104 @@ // RUN: mlir-opt %s -remove-dead-values -split-input-file -verify-diagnostics | FileCheck %s -// The IR remains untouched because of the presence of a non-function-like -// symbol op (module @dont_touch_unacceptable_ir). +// The IR is updated regardless of memref.global private constant // -// expected-error @+1 {{cannot optimize an IR with non-function symbol ops, non-call symbol user ops or branch ops}} -module @dont_touch_unacceptable_ir { - func.func @has_cleanable_simple_op(%arg0 : i32) { - %non_live = arith.addi %arg0, %arg0 : i32 - return +module { + // CHECK: memref.global "private" constant @__constant_4xi32 : memref<4xi32> = dense<[1, 2, 3, 4]> {alignment = 16 : i64} + memref.global "private" constant @__constant_4xi32 : memref<4xi32> = dense<[1, 2, 3, 4]> {alignment = 16 : i64} + func.func @main(%arg0: i32) -> i32 { + %0 = tensor.empty() : tensor<10xbf16> + // CHECK-NOT: memref.get_global + %1 = memref.get_global @__constant_4xi32 : memref<4xi32> + // CHECK-NOT: tensor.empty + return %arg0 : i32 + } +} + +// ----- + +// Dead values are removed from the IR even if the module has a name +// +module @named_module_acceptable { + func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { + %0 = tensor.empty() : tensor<10xbf16> + // CHECK-NOT: tensor.empty + return %arg0 : tensor<10xf32> } } // ----- -// The IR remains untouched because of the presence of a branch op `cf.cond_br`. +// The IR contains both conditional and unconditional branches with a loop +// in which the last cf.cond_br is referncing the first cf.br // -func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) { +func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0: i1) { %non_live = arith.constant 0 : i32 - // expected-error @+1 {{cannot optimize an IR with non-function symbol ops, non-call symbol user ops or branch ops}} - cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32) -^bb1(%non_live_0 : i32): - cf.br ^bb3 -^bb2(%non_live_1 : i32): - cf.br ^bb3 -^bb3: + // CHECK-NOT: arith.constant + cf.br ^bb1(%non_live : i32) + // CHECK: cf.br ^[[BB1:bb[0-9]+]] +^bb1(%non_live_1 : i32): + // CHECK: ^[[BB1]]: + %non_live_5 = arith.constant 1 : i32 + cf.br ^bb3(%non_live_1, %non_live_5 : i32, i32) + // CHECK: cf.br ^[[BB3:bb[0-9]+]] + // CHECK-NOT: i32 +^bb3(%non_live_2 : i32, %non_live_6 : i32): + // CHECK: ^[[BB3]]: + cf.cond_br %arg0, ^bb1(%non_live_2 : i32), ^bb4(%non_live_2 : i32) + // CHECK: cf.cond_br %arg0, ^[[BB1]], ^[[BB4:bb[0-9]+]] +^bb4(%non_live_4 : i32): + // CHECK: ^[[BB4]]: return } // ----- +// Checking that iter_args are properly handled +// +func.func @cleanable_loop_iter_args_value(%arg0: index) -> index { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %non_live = arith.constant 0 : index + // CHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) { + %result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) { + // CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index + %new_live = arith.addi %live_arg, %i : index + // CHECK: scf.yield [[SUM:%.+]] + scf.yield %new_live, %non_live_arg : index, index + } + // CHECK: return [[RESULT]] : index + return %result : index +} + +// ----- + +// Checking that the arguments of linalg.generic are properly handled +// All code below is removed as a result of the pass +// +#map = affine_map<(d0, d1, d2) -> (0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +module { + func.func @main() { + %cst_3 = arith.constant dense<54> : tensor<1x25x13xi32> + %cst_7 = arith.constant dense<11> : tensor<1x25x13xi32> + // CHECK-NOT: arith.constant + %0 = tensor.empty() : tensor<1x25x13xi32> + // CHECK-NOT: tensor + %1 = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_3, %cst_7 : tensor<1x25x13xi32>, tensor<1x25x13xi32>) outs(%0 : tensor<1x25x13xi32>) { + // CHECK-NOT: linalg.generic + ^bb0(%in: i32, %in_15: i32, %out: i32): + %29 = arith.xori %in, %in_15 : i32 + // CHECK-NOT: arith.xori + linalg.yield %29 : i32 + // CHECK-NOT: linalg.yield + } -> tensor<1x25x13xi32> + return + } +} + +// ----- + // Note that this cleanup cannot be done by the `canonicalize` pass. // // CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() { @@ -357,3 +427,13 @@ func.func @kernel(%arg0: memref<18xf32>) { // CHECK: gpu.launch blocks // CHECK: memref.store // CHECK-NEXT: gpu.terminator + +// ----- + +// CHECK: func.func private @no_block_func_declaration() +func.func private @no_block_func_declaration() -> () + +// ----- + +// CHECK: llvm.func @no_block_external_func() +llvm.func @no_block_external_func() attributes {sym_visibility = "private"} diff --git a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp index 56f309f150ca5d37c06b7947e3c089f7832a1d0f..a4f0fc6b2ff764c3d7916bf17affc3ca2de69d03 100644 --- a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp +++ b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp @@ -42,6 +42,7 @@ struct TestDataLayoutQuery uint64_t preferred = layout.getTypePreferredAlignment(op.getType()); uint64_t index = layout.getTypeIndexBitwidth(op.getType()).value_or(0); Attribute endianness = layout.getEndianness(); + Attribute defaultMemorySpace = layout.getDefaultMemorySpace(); Attribute allocaMemorySpace = layout.getAllocaMemorySpace(); Attribute programMemorySpace = layout.getProgramMemorySpace(); Attribute globalMemorySpace = layout.getGlobalMemorySpace(); @@ -68,6 +69,10 @@ struct TestDataLayoutQuery builder.getNamedAttr("endianness", endianness == Attribute() ? builder.getStringAttr("") : endianness), + builder.getNamedAttr("default_memory_space", + defaultMemorySpace == Attribute() + ? builder.getUI32IntegerAttr(0) + : defaultMemorySpace), builder.getNamedAttr("alloca_memory_space", allocaMemorySpace == Attribute() ? builder.getUI32IntegerAttr(0) diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp index 6e75dd393228107c2a9c3817846a7258fb98a454..9ed1b3a47be3690b8ce61feadd4e9153e35b880f 100644 --- a/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp +++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp @@ -297,11 +297,17 @@ void test::printSwitchCases(OpAsmPrinter &p, Operation *op, // CustomUsingPropertyInCustom //===----------------------------------------------------------------------===// -bool test::parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) { - return parser.parseLSquare() || parser.parseInteger(value[0]) || - parser.parseComma() || parser.parseInteger(value[1]) || - parser.parseComma() || parser.parseInteger(value[2]) || - parser.parseRSquare(); +bool test::parseUsingPropertyInCustom(OpAsmParser &parser, + SmallVector &value) { + auto elemParser = [&]() { + int64_t v = 0; + if (failed(parser.parseInteger(v))) + return failure(); + value.push_back(v); + return success(); + }; + return failed(parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square, + elemParser)); } void test::printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op, diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.h b/mlir/test/lib/Dialect/Test/TestFormatUtils.h index 7e9cd834278e347c458a050d30005acf32c23d5f..6d4df7d82ffa543539ec86a7f69eded1928d8602 100644 --- a/mlir/test/lib/Dialect/Test/TestFormatUtils.h +++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.h @@ -160,7 +160,8 @@ void printSwitchCases(mlir::OpAsmPrinter &p, mlir::Operation *op, // CustomUsingPropertyInCustom //===----------------------------------------------------------------------===// -bool parseUsingPropertyInCustom(mlir::OpAsmParser &parser, int64_t value[3]); +bool parseUsingPropertyInCustom(mlir::OpAsmParser &parser, + llvm::SmallVector &value); void printUsingPropertyInCustom(mlir::OpAsmPrinter &printer, mlir::Operation *op, diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 9450764fcb1d5b67d2aac6c737dd0d3319fffc5d..70579cf5f3e1a540bbf91355d24e5ddf4452c3e5 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2947,11 +2947,18 @@ def TestVersionedOpC : TEST_Op<"versionedC"> { // Op with a properties struct defined inline. def TestOpWithProperties : TEST_Op<"with_properties"> { - let assemblyFormat = "prop-dict attr-dict"; + let assemblyFormat = [{ + `a` `=` $a `,` + `b` `=` $b `,` + `c` `=` $c `,` + `flag` `=` $flag `,` + `array` `=` $array attr-dict}]; let arguments = (ins - IntProperty<"int64_t">:$a, + I64Property:$a, StrAttr:$b, // Attributes can directly be used here. - ArrayProperty<"int64_t", 4>:$array // example of an array + StringProperty:$c, + BoolProperty:$flag, + IntArrayProperty<"int64_t">:$array // example of an array ); } @@ -2974,7 +2981,7 @@ def TestOpWithPropertiesAndInferredType // Demonstrate how to wrap an existing C++ class named MyPropStruct. def MyStructProperty : Property<"MyPropStruct"> { - let convertToAttribute = "$_storage.asAttribute($_ctxt)"; + let convertToAttribute = "return $_storage.asAttribute($_ctxt);"; let convertFromAttribute = "return MyPropStruct::setFromAttr($_storage, $_attr, $_diag);"; let hashProperty = "$_storage.hash();"; } @@ -2988,14 +2995,14 @@ def TestOpWithWrappedProperties : TEST_Op<"with_wrapped_properties"> { def TestOpUsingPropertyInCustom : TEST_Op<"using_property_in_custom"> { let assemblyFormat = "custom($prop) attr-dict"; - let arguments = (ins ArrayProperty<"int64_t", 3>:$prop); + let arguments = (ins IntArrayProperty<"int64_t">:$prop); } def TestOpUsingPropertyInCustomAndOther : TEST_Op<"using_property_in_custom_and_other"> { let assemblyFormat = "custom($prop) prop-dict attr-dict"; let arguments = (ins - ArrayProperty<"int64_t", 3>:$prop, + IntArrayProperty<"int64_t">:$prop, IntProperty<"int64_t">:$other ); } @@ -3021,7 +3028,7 @@ def TestOpUsingIntPropertyWithWorseBytecode def PropertiesWithCustomPrint : Property<"PropertiesWithCustomPrint"> { let convertToAttribute = [{ - getPropertiesAsAttribute($_ctxt, $_storage) + return getPropertiesAsAttribute($_ctxt, $_storage); }]; let convertFromAttribute = [{ return setPropertiesFromAttribute($_storage, $_attr, $_diag); @@ -3085,7 +3092,7 @@ def TestOpWithNiceProperties : TEST_Op<"with_nice_properties"> { def VersionedProperties : Property<"VersionedProperties"> { let convertToAttribute = [{ - getPropertiesAsAttribute($_ctxt, $_storage) + return getPropertiesAsAttribute($_ctxt, $_storage); }]; let convertFromAttribute = [{ return setPropertiesFromAttribute($_storage, $_attr, $_diag); @@ -3131,13 +3138,65 @@ def TestOpWithVersionedProperties : TEST_Op<"with_versioned_properties"> { } def TestOpWithDefaultValuedProperties : TEST_Op<"with_default_valued_properties"> { - let assemblyFormat = "prop-dict attr-dict"; - let arguments = (ins DefaultValuedAttr:$a); + let assemblyFormat = [{ + ($a^) : (`na`)? + ($b^)? + ($c^)? + ($unit^)? + attr-dict + }]; + let arguments = (ins DefaultValuedAttr:$a, + DefaultValuedProperty:$b, + DefaultValuedProperty, "-1">:$c, + UnitProperty:$unit); } def TestOpWithOptionalProperties : TEST_Op<"with_optional_properties"> { - let assemblyFormat = "prop-dict attr-dict"; - let arguments = (ins OptionalAttr:$a, OptionalAttr:$b); +let assemblyFormat = [{ + (`anAttr` `=` $anAttr^)? + (`simple` `=` $simple^)? + (`nonTrivialStorage` `=` $nonTrivialStorage^)? + (`hasDefault` `=` $hasDefault^)? + (`nested` `=` $nested^)? + (`longSyntax` `=` $longSyntax^)? + (`hasUnit` $hasUnit^)? + (`maybeUnit` `=` $maybeUnit^)? + attr-dict + }]; + let arguments = (ins + OptionalAttr:$anAttr, + OptionalProperty:$simple, + OptionalProperty:$nonTrivialStorage, + // Confirm that properties with default values now default to nullopt and have + // the long syntax. + OptionalProperty>:$hasDefault, + OptionalProperty>:$nested, + OptionalProperty:$longSyntax, + UnitProperty:$hasUnit, + OptionalProperty:$maybeUnit); +} + +def TestOpWithArrayProperties : TEST_Op<"with_array_properties"> { + let assemblyFormat = [{ + `ints` `=` $ints + `strings` `=` $strings + `nested` `=` $nested + `opt` `=` $opt + `explicitOptions` `=` $explicitOptions + `explicitUnits` `=` $explicitUnits + ($hasDefault^ `thats_has_default`)? + attr-dict + }]; + let arguments = (ins + ArrayProperty:$ints, + ArrayProperty:$strings, + ArrayProperty>:$nested, + OptionalProperty>:$opt, + ArrayProperty>:$explicitOptions, + ArrayProperty:$explicitUnits, + DefaultValuedProperty, + "::llvm::ArrayRef{}", "::llvm::SmallVector{}">:$hasDefault + ); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td index 3129085058fd965ee0bc7769be796fa496af55c1..795b9da955632c1381e38df6af9e96e9b1905f68 100644 --- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td +++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td @@ -86,6 +86,17 @@ def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> { }]; } +// Ops related to OIList primitive +def OIListTrivialProperties : TEST_Op<"oilist_with_keywords_only_properties"> { + let arguments = (ins UnitProperty:$keyword, UnitProperty:$otherKeyword, + UnitProperty:$diffNameUnitPropertyKeyword); + let assemblyFormat = [{ + oilist( `keyword` $keyword + | `otherKeyword` $otherKeyword + | `thirdKeyword` $diffNameUnitPropertyKeyword) attr-dict + }]; +} + def OIListSimple : TEST_Op<"oilist_with_simple_args", [AttrSizedOperandSegments]> { let arguments = (ins Optional:$arg0, Optional:$arg1, @@ -392,6 +403,17 @@ def FormatOptionalUnitAttrNoElide let assemblyFormat = "($is_optional^)? attr-dict"; } +def FormatOptionalUnitProperty : TEST_Op<"format_optional_unit_property"> { + let arguments = (ins UnitProperty:$is_optional); + let assemblyFormat = "(`is_optional` $is_optional^)? attr-dict"; +} + +def FormatOptionalUnitPropertyNoElide + : TEST_Op<"format_optional_unit_property_no_elide"> { + let arguments = (ins UnitProperty:$is_optional); + let assemblyFormat = "($is_optional^)? attr-dict"; +} + def FormatOptionalEnumAttr : TEST_Op<"format_optional_enum_attr"> { let arguments = (ins OptionalAttr:$attr); let assemblyFormat = "($attr^)? attr-dict"; diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 1593b6d7d7534bd35ef6bf417d8c5e96cff26b37..213c2e69329d7cdfcd9e631552f91e452e73a882 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -284,7 +284,9 @@ TestTypeWithLayoutType::getIndexBitwidth(const DataLayout &dataLayout, } bool TestTypeWithLayoutType::areCompatible( - DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const { + DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout, + DataLayoutSpecInterface newSpec, + const DataLayoutIdentifiedEntryMap &map) const { unsigned old = extractKind(oldLayout, "alignment"); return old == 1 || extractKind(newLayout, "alignment") <= old; } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 0fc750c7bbc88736247acb9ed71a5ce7daa1e4bd..9a172989205857cc5fc2bd52de526643dd259df6 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -155,6 +155,36 @@ static const char *const valueRangeReturnCode = R"( std::next({0}, valueRange.first + valueRange.second)}; )"; +/// Parse operand/result segment_size property. +/// {0}: Number of elements in the segment array +static const char *const parseTextualSegmentSizeFormat = R"( + size_t i = 0; + auto parseElem = [&]() -> ::mlir::ParseResult { + if (i >= {0}) + return $_parser.emitError($_parser.getCurrentLocation(), + "expected `]` after {0} segment sizes"); + if (failed($_parser.parseInteger($_storage[i]))) + return ::mlir::failure(); + i += 1; + return ::mlir::success(); + }; + if (failed($_parser.parseCommaSeparatedList( + ::mlir::AsmParser::Delimeter::Square, parseElem))) + return failure(); + if (i < {0}) + return $_parser.emitError($_parser.getCurrentLocation(), + "expected {0} segment sizes, found only ") << i; + return success(); +)"; + +static const char *const printTextualSegmentSize = R"( + [&]() { + $_printer << '['; + ::llvm::interleaveComma($_storage, $_printer); + $_printer << ']'; + }() +)"; + /// Read operand/result segment_size from bytecode. static const char *const readBytecodeSegmentSizeNative = R"( if ($_reader.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6) @@ -422,8 +452,10 @@ private: // Property std::optional operandSegmentsSize; std::string operandSegmentsSizeStorage; + std::string operandSegmentsSizeParser; std::optional resultSegmentsSize; std::string resultSegmentsSizeStorage; + std::string resultSegmentsSizeParser; // Indices to store the position in the emission order of the operand/result // segment sizes attribute if emitted as part of the properties for legacy @@ -448,31 +480,40 @@ void OpOrAdaptorHelper::computeAttrMetadata() { {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}}); } - auto makeProperty = [&](StringRef storageType) { + auto makeProperty = [&](StringRef storageType, StringRef parserCall) { return Property( + /*summary=*/"", + /*description=*/"", /*storageType=*/storageType, /*interfaceType=*/"::llvm::ArrayRef", /*convertFromStorageCall=*/"$_storage", /*assignToStorageCall=*/ "::llvm::copy($_value, $_storage.begin())", /*convertToAttributeCall=*/ - "::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage)", + "return ::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage);", /*convertFromAttributeCall=*/ "return convertFromAttribute($_storage, $_attr, $_diag);", + /*parserCall=*/parserCall, + /*optionalParserCall=*/"", + /*printerCall=*/printTextualSegmentSize, /*readFromMlirBytecodeCall=*/readBytecodeSegmentSizeNative, /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSizeNative, /*hashPropertyCall=*/ "::llvm::hash_combine_range(std::begin($_storage), " "std::end($_storage));", - /*StringRef defaultValue=*/""); + /*StringRef defaultValue=*/"", + /*storageTypeValueOverride=*/""); }; // Include key attributes from several traits as implicitly registered. if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { if (op.getDialect().usePropertiesForAttributes()) { operandSegmentsSizeStorage = llvm::formatv("std::array", op.getNumOperands()); - operandSegmentsSize = {"operandSegmentSizes", - makeProperty(operandSegmentsSizeStorage)}; + operandSegmentsSizeParser = + llvm::formatv(parseTextualSegmentSizeFormat, op.getNumOperands()); + operandSegmentsSize = { + "operandSegmentSizes", + makeProperty(operandSegmentsSizeStorage, operandSegmentsSizeParser)}; } else { attrMetadata.insert( {operandSegmentAttrName, AttributeMetadata{operandSegmentAttrName, @@ -484,8 +525,11 @@ void OpOrAdaptorHelper::computeAttrMetadata() { if (op.getDialect().usePropertiesForAttributes()) { resultSegmentsSizeStorage = llvm::formatv("std::array", op.getNumResults()); - resultSegmentsSize = {"resultSegmentSizes", - makeProperty(resultSegmentsSizeStorage)}; + resultSegmentsSizeParser = + llvm::formatv(parseTextualSegmentSizeFormat, op.getNumResults()); + resultSegmentsSize = { + "resultSegmentSizes", + makeProperty(resultSegmentsSizeStorage, resultSegmentsSizeParser)}; } else { attrMetadata.insert( {resultSegmentAttrName, @@ -572,6 +616,12 @@ private: void genPropertiesSupportForBytecode(ArrayRef attrOrProperties); + // Generates getters for the properties. + void genPropGetters(); + + // Generates seters for the properties. + void genPropSetters(); + // Generates getters for the attributes. void genAttrGetters(); @@ -1041,6 +1091,8 @@ OpEmitter::OpEmitter(const Operator &op, genNamedRegionGetters(); genNamedSuccessorGetters(); genPropertiesSupport(); + genPropGetters(); + genPropSetters(); genAttrGetters(); genAttrSetters(); genOptionalAttrRemovers(); @@ -1198,6 +1250,16 @@ void OpEmitter::genAttrNameGetters() { } } +// Emit the getter for a named property. +// It is templated to be shared between the Op and the adaptor class. +template +static void emitPropGetter(OpClassOrAdaptor &opClass, const Operator &op, + StringRef name, const Property &prop) { + auto *method = opClass.addInlineMethod(prop.getInterfaceType(), name); + ERROR_IF_PRUNED(method, name, op); + method->body() << formatv(" return getProperties().{0}();", name); +} + // Emit the getter for an attribute with the return type specified. // It is templated to be shared between the Op and the adaptor class. template @@ -1313,7 +1375,7 @@ void OpEmitter::genPropertiesSupport() { )decl"; const char *propFromAttrFmt = R"decl( auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr, - ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{ + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{ {0} }; {2}; @@ -1358,7 +1420,10 @@ void OpEmitter::genPropertiesSupport() { .addSubst("_storage", propertyStorage) .addSubst("_diag", propertyDiag)), name, getAttr); - if (prop.hasDefaultValue()) { + if (prop.hasStorageTypeValueOverride()) { + setPropMethod << formatv(attrGetDefaultFmt, name, + prop.getStorageTypeValueOverride()); + } else if (prop.hasDefaultValue()) { setPropMethod << formatv(attrGetDefaultFmt, name, prop.getDefaultValue()); } else { @@ -1409,8 +1474,10 @@ void OpEmitter::genPropertiesSupport() { const char *propToAttrFmt = R"decl( { const auto &propStorage = prop.{0}; - attrs.push_back(odsBuilder.getNamedAttr("{0}", - {1})); + auto attr = [&]() -> ::mlir::Attribute {{ + {1} + }(); + attrs.push_back(odsBuilder.getNamedAttr("{0}", attr)); } )decl"; for (const auto &attrOrProp : attrOrProperties) { @@ -1458,9 +1525,12 @@ void OpEmitter::genPropertiesSupport() { StringRef name = namedProperty->name; auto &prop = namedProperty->prop; FmtContext fctx; - hashMethod << formatv(propHashFmt, name, - tgfmt(prop.getHashPropertyCall(), - &fctx.addSubst("_storage", propertyStorage))); + if (!prop.getHashPropertyCall().empty()) { + hashMethod << formatv( + propHashFmt, name, + tgfmt(prop.getHashPropertyCall(), + &fctx.addSubst("_storage", propertyStorage))); + } } } hashMethod << " return llvm::hash_combine("; @@ -1468,8 +1538,13 @@ void OpEmitter::genPropertiesSupport() { attrOrProperties, hashMethod, [&](const ConstArgument &attrOrProp) { if (const auto *namedProperty = llvm::dyn_cast_if_present(attrOrProp)) { - hashMethod << "\n hash_" << namedProperty->name << "(prop." - << namedProperty->name << ")"; + if (!namedProperty->prop.getHashPropertyCall().empty()) { + hashMethod << "\n hash_" << namedProperty->name << "(prop." + << namedProperty->name << ")"; + } else { + hashMethod << "\n ::llvm::hash_value(prop." + << namedProperty->name << ")"; + } return; } const auto *namedAttr = @@ -1524,8 +1599,9 @@ void OpEmitter::genPropertiesSupport() { "\"{0}\") return ", resultSegmentAttrName); } - getInherentAttrMethod << tgfmt(prop.getConvertToAttributeCall(), &fctx) - << ";\n"; + getInherentAttrMethod << "[&]() -> ::mlir::Attribute { " + << tgfmt(prop.getConvertToAttributeCall(), &fctx) + << " }();\n"; if (name == operandSegmentAttrName) { setInherentAttrMethod @@ -1549,13 +1625,15 @@ void OpEmitter::genPropertiesSupport() { )decl", name); if (name == operandSegmentAttrName) { - populateInherentAttrsMethod - << formatv(" attrs.append(\"{0}\", {1});\n", operandSegmentAttrName, - tgfmt(prop.getConvertToAttributeCall(), &fctx)); + populateInherentAttrsMethod << formatv( + " attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n", + operandSegmentAttrName, + tgfmt(prop.getConvertToAttributeCall(), &fctx)); } else { - populateInherentAttrsMethod - << formatv(" attrs.append(\"{0}\", {1});\n", resultSegmentAttrName, - tgfmt(prop.getConvertToAttributeCall(), &fctx)); + populateInherentAttrsMethod << formatv( + " attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n", + resultSegmentAttrName, + tgfmt(prop.getConvertToAttributeCall(), &fctx)); } } getInherentAttrMethod << " return std::nullopt;\n"; @@ -1701,6 +1779,26 @@ void OpEmitter::genPropertiesSupportForBytecode( readPropertiesMethod << " return ::mlir::success();"; } +void OpEmitter::genPropGetters() { + for (const NamedProperty &prop : op.getProperties()) { + std::string name = op.getGetterName(prop.name); + emitPropGetter(opClass, op, name, prop.prop); + } +} + +void OpEmitter::genPropSetters() { + for (const NamedProperty &prop : op.getProperties()) { + std::string name = op.getSetterName(prop.name); + std::string argName = "new" + convertToCamelFromSnakeCase( + prop.name, /*capitalizeFirst=*/true); + auto *method = opClass.addInlineMethod( + "void", name, MethodParameter(prop.prop.getInterfaceType(), argName)); + if (!method) + return; + method->body() << formatv(" getProperties().{0}({1});", name, argName); + } +} + void OpEmitter::genAttrGetters() { FmtContext fctx; fctx.withBuilder("::mlir::Builder((*this)->getContext())"); @@ -2957,6 +3055,12 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, } // Add parameters for all arguments (operands and attributes). + // Track "attr-like" (property and attribute) optional values separate from + // attributes themselves so that the disambiguation code can look at the first + // attribute specifically when determining where to trim the optional-value + // list to avoid ambiguity while preserving the ability of all-property ops to + // use default parameters. + int defaultValuedAttrLikeStartIndex = op.getNumArgs(); int defaultValuedAttrStartIndex = op.getNumArgs(); // Successors and variadic regions go at the end of the parameter list, so no // default arguments are possible. @@ -2967,6 +3071,15 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, for (int i = op.getNumArgs() - 1; i >= 0; --i) { auto *namedAttr = llvm::dyn_cast_if_present(op.getArg(i)); + auto *namedProperty = + llvm::dyn_cast_if_present(op.getArg(i)); + if (namedProperty) { + Property prop = namedProperty->prop; + if (!prop.hasDefaultValue()) + break; + defaultValuedAttrLikeStartIndex = i; + continue; + } if (!namedAttr) break; @@ -2986,6 +3099,7 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, if (retType == "::llvm::APInt" || retType == "::llvm::APFloat") break; + defaultValuedAttrLikeStartIndex = i; defaultValuedAttrStartIndex = i; } } @@ -3001,8 +3115,10 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, if ((attrParamKind == AttrParamKind::WrappedAttr && canUseUnwrappedRawValue(attr)) || (attrParamKind == AttrParamKind::UnwrappedValue && - !canUseUnwrappedRawValue(attr))) + !canUseUnwrappedRawValue(attr))) { ++defaultValuedAttrStartIndex; + defaultValuedAttrLikeStartIndex = defaultValuedAttrStartIndex; + } } /// Collect any inferred attributes. @@ -3029,8 +3145,16 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, operand->isOptional()); continue; } - if (llvm::isa_and_present(arg)) { - // TODO + if (auto *propArg = llvm::dyn_cast_if_present(arg)) { + const Property &prop = propArg->prop; + StringRef type = prop.getInterfaceType(); + std::string defaultValue; + if (prop.hasDefaultValue() && i >= defaultValuedAttrLikeStartIndex) { + defaultValue = prop.getDefaultValue(); + } + bool isOptional = prop.hasDefaultValue(); + paramList.emplace_back(type, propArg->name, StringRef(defaultValue), + isOptional); continue; } const NamedAttribute &namedAttr = *arg.get(); @@ -3157,6 +3281,15 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder( } } + // Push all properties to the result. + for (const auto &namedProp : op.getProperties()) { + // Use the setter from the Properties struct since the conversion from the + // interface type (used in the builder argument) to the storage type (used + // in the state) is not necessarily trivial. + std::string setterName = op.getSetterName(namedProp.name); + body << formatv(" {0}.getOrAddProperties().{1}({2});\n", + builderOpState, setterName, namedProp.name); + } // Push all attributes to the result. for (const auto &namedAttr : op.getAttributes()) { auto &attr = namedAttr.attr; @@ -3996,17 +4129,19 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( // Generate the data member using the storage type. os << " using " << name << "Ty = " << prop.getStorageType() << ";\n" << " " << name << "Ty " << name; - if (prop.hasDefaultValue()) + if (prop.hasStorageTypeValueOverride()) + os << " = " << prop.getStorageTypeValueOverride(); + else if (prop.hasDefaultValue()) os << " = " << prop.getDefaultValue(); comparatorOs << " rhs." << name << " == this->" << name << " &&\n"; // Emit accessors using the interface type. const char *accessorFmt = R"decl(; - {0} get{1}() { + {0} get{1}() const { auto &propStorage = this->{2}; return {3}; } - void set{1}(const {0} &propValue) { + void set{1}({0} propValue) { auto &propStorage = this->{2}; {4}; } @@ -4274,6 +4409,11 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op); m->body() << " return odsAttrs;"; } + for (auto &namedProp : op.getProperties()) { + std::string name = op.getGetterName(namedProp.name); + emitPropGetter(genericAdaptorBase, op, name, namedProp.prop); + } + for (auto &namedAttr : op.getAttributes()) { const auto &name = namedAttr.name; const auto &attr = namedAttr.attr; @@ -4564,4 +4704,4 @@ static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions", [](const RecordKeeper &records, raw_ostream &os) { return emitOpDefs(records, os); - }); + }); \ No newline at end of file diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index a97d8760842a98aa929cd1439ae3ae1d6a4ce319..2129c4325c0c0ed1b94bb6675ae2534920fffe15 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -45,7 +45,7 @@ public: OpVariableElement(const VarT *var) : var(var) {} /// Get the variable. - const VarT *getVar() { return var; } + const VarT *getVar() const { return var; } protected: /// The op variable, e.g. a type or attribute constraint. @@ -64,11 +64,6 @@ struct AttributeVariable return attrType ? attrType->getBuilderCall() : std::nullopt; } - /// Return if this attribute refers to a UnitAttr. - bool isUnitAttr() const { - return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr"; - } - /// Indicate if this attribute is printed "qualified" (that is it is /// prefixed with the `#dialect.mnemonic`). bool shouldBeQualified() { return shouldBeQualifiedFlag; } @@ -98,6 +93,42 @@ using SuccessorVariable = /// This class represents a variable that refers to a property argument. using PropertyVariable = OpVariableElement; + +/// LLVM RTTI helper for attribute-like variables, that is, attributes or +/// properties. This allows for common handling of attributes and properties in +/// parts of the code that are oblivious to whether something is stored as an +/// attribute or a property. +struct AttributeLikeVariable : public VariableElement { + enum { AttributeLike = 1 << 0 }; + + static bool classof(const VariableElement *ve) { + return ve->getKind() == VariableElement::Attribute || + ve->getKind() == VariableElement::Property; + } + + static bool classof(const FormatElement *fe) { + return isa(fe) && classof(cast(fe)); + } + + /// Returns true if the variable is a UnitAttr or a UnitProperty. + bool isUnit() const { + if (const auto *attr = dyn_cast(this)) + return attr->getVar()->attr.getBaseAttr().getAttrDefName() == "UnitAttr"; + if (const auto *prop = dyn_cast(this)) { + return prop->getVar()->prop.getBaseProperty().getPropertyDefName() == + "UnitProperty"; + } + llvm_unreachable("Type that wasn't listed in classof()"); + } + + StringRef getName() const { + if (const auto *attr = dyn_cast(this)) + return attr->getVar()->name; + if (const auto *prop = dyn_cast(this)) + return prop->getVar()->name; + llvm_unreachable("Type that wasn't listed in classof()"); + } +}; } // namespace //===----------------------------------------------------------------------===// @@ -214,11 +245,11 @@ public: /// If the parsing element is a single UnitAttr element, then it returns the /// attribute variable. Otherwise, returns nullptr. - AttributeVariable * - getUnitAttrParsingElement(ArrayRef pelement) { + AttributeLikeVariable * + getUnitVariableParsingElement(ArrayRef pelement) { if (pelement.size() == 1) { - auto *attrElem = dyn_cast(pelement[0]); - if (attrElem && attrElem->isUnitAttr()) + auto *attrElem = dyn_cast(pelement[0]); + if (attrElem && attrElem->isUnit()) return attrElem; } return nullptr; @@ -488,6 +519,36 @@ const char *const enumAttrParserCode = R"( } )"; +/// The code snippet used to generate a parser call for a property. +/// {0}: The name of the property +/// {1}: The C++ class name of the operation +/// {2}: The property's parser code with appropriate substitutions performed +/// {3}: The description of the expected property for the error message. +const char *const propertyParserCode = R"( + auto {0}PropLoc = parser.getCurrentLocation(); + auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::ParseResult {{ + {2} + return ::mlir::success(); + }(result.getOrAddProperties<{1}::Properties>().{0}); + if (failed({0}PropParseResult)) {{ + return parser.emitError({0}PropLoc, "invalid value for property {0}, expected {3}"); + } +)"; + +/// The code snippet used to generate a parser call for a property. +/// {0}: The name of the property +/// {1}: The C++ class name of the operation +/// {2}: The property's parser code with appropriate substitutions performed +const char *const optionalPropertyParserCode = R"( + auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::OptionalParseResult {{ + {2} + return ::mlir::success(); + }(result.getOrAddProperties<{1}::Properties>().{0}); + if ({0}PropParseResult.has_value() && failed(*{0}PropParseResult)) {{ + return ::mlir::failure(); + } +)"; + /// The code snippet used to generate a parser call for an operand. /// /// {0}: The name of the operand. @@ -796,9 +857,9 @@ static void genElementParserStorage(FormatElement *element, const Operator &op, // If the anchor is a unit attribute, it won't be parsed directly so elide // it. - auto *anchor = dyn_cast(optional->getAnchor()); + auto *anchor = dyn_cast(optional->getAnchor()); FormatElement *elidedAnchorElement = nullptr; - if (anchor && anchor != elements.front() && anchor->isUnitAttr()) + if (anchor && anchor != elements.front() && anchor->isUnit()) elidedAnchorElement = anchor; for (FormatElement *childElement : elements) if (childElement != elidedAnchorElement) @@ -808,7 +869,7 @@ static void genElementParserStorage(FormatElement *element, const Operator &op, } else if (auto *oilist = dyn_cast(element)) { for (ArrayRef pelement : oilist->getParsingElements()) { - if (!oilist->getUnitAttrParsingElement(pelement)) + if (!oilist->getUnitVariableParsingElement(pelement)) for (FormatElement *element : pelement) genElementParserStorage(element, op, body); } @@ -1049,7 +1110,6 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n", var->name); } - } else if (auto *operand = dyn_cast(param)) { const NamedTypeConstraint *var = operand->getVar(); if (var->isOptional()) { @@ -1137,6 +1197,29 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body, validCaseKeywordsStr, errorMessage, attrAssignment); } +// Generate the parser for a property. +static void genPropertyParser(PropertyVariable *propVar, MethodBody &body, + StringRef opCppClassName, + bool requireParse = true) { + StringRef name = propVar->getVar()->name; + const Property &prop = propVar->getVar()->prop; + bool parseOptionally = + prop.hasDefaultValue() && !requireParse && prop.hasOptionalParser(); + FmtContext fmtContext; + fmtContext.addSubst("_parser", "parser"); + fmtContext.addSubst("_ctxt", "parser.getContext()"); + fmtContext.addSubst("_storage", "propStorage"); + + if (parseOptionally) { + body << formatv(optionalPropertyParserCode, name, opCppClassName, + tgfmt(prop.getOptionalParserCall(), &fmtContext)); + } else { + body << formatv(propertyParserCode, name, opCppClassName, + tgfmt(prop.getParserCall(), &fmtContext), + prop.getSummary()); + } +} + // Generate the parser for an attribute. static void genAttrParser(AttributeVariable *attr, MethodBody &body, FmtContext &attrTypeCtx, bool parseAsOptional, @@ -1213,14 +1296,16 @@ if (!dict) { } )decl"; - // TODO: properties might be optional as well. + // {0}: fromAttribute call + // {1}: property name + // {2}: isRequired const char *propFromAttrFmt = R"decl( auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr, - ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{ + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{ {0}; }; auto attr = dict.get("{1}"); -if (!attr) {{ +if (!attr && {2}) {{ emitError() << "expected key entry for {1} in DictionaryAttr to set " "Properties."; return ::mlir::failure(); @@ -1238,13 +1323,14 @@ if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError))) StringRef name = namedProperty.name; const Property &prop = namedProperty.prop; + bool isRequired = !prop.hasDefaultValue(); FmtContext fctx; body << formatv(propFromAttrFmt, tgfmt(prop.getConvertFromAttributeCall(), &fctx.addSubst("_attr", "propAttr") .addSubst("_storage", "propStorage") .addSubst("_diag", "emitError")), - name); + name, isRequired); } // Generate the setter for any attribute not parsed elsewhere. @@ -1331,20 +1417,24 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, // If the anchor is a unit attribute, we don't need to print it. When // parsing, we will add this attribute if this group is present. FormatElement *elidedAnchorElement = nullptr; - auto *anchorAttr = dyn_cast(optional->getAnchor()); - if (anchorAttr && anchorAttr != firstElement && - anchorAttr->isUnitAttr()) { - elidedAnchorElement = anchorAttr; + auto *anchorVar = dyn_cast(optional->getAnchor()); + if (anchorVar && anchorVar != firstElement && anchorVar->isUnit()) { + elidedAnchorElement = anchorVar; if (!thenGroup == optional->isInverted()) { - // Add the anchor unit attribute to the operation state. - if (useProperties) { + // Add the anchor unit attribute or property to the operation state + // or set the property to true. + if (isa(anchorVar)) { + body << formatv( + " result.getOrAddProperties<{1}::Properties>().{0} = true;", + anchorVar->getName(), opCppClassName); + } else if (useProperties) { body << formatv( " result.getOrAddProperties<{1}::Properties>().{0} = " "parser.getBuilder().getUnitAttr();", - anchorAttr->getVar()->name, opCppClassName); + anchorVar->getName(), opCppClassName); } else { - body << " result.addAttribute(\"" << anchorAttr->getVar()->name + body << " result.addAttribute(\"" << anchorVar->getName() << "\", parser.getBuilder().getUnitAttr());\n"; } } @@ -1368,6 +1458,12 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true, useProperties, opCppClassName); body << " if (" << attrVar->getVar()->name << "Attr) {\n"; + } else if (auto *propVar = dyn_cast(firstElement)) { + genPropertyParser(propVar, body, opCppClassName, /*requireParse=*/false); + body << llvm::formatv("if ({0}PropParseResult.has_value() && " + "succeeded(*{0}PropParseResult)) ", + propVar->getVar()->name) + << " {\n"; } else if (auto *literal = dyn_cast(firstElement)) { body << " if (::mlir::succeeded(parser.parseOptional"; genLiteralParser(literal->getSpelling(), body); @@ -1430,15 +1526,19 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, body << ")) {\n"; StringRef lelementName = lelement->getSpelling(); body << formatv(oilistParserCode, lelementName); - if (AttributeVariable *unitAttrElem = - oilist->getUnitAttrParsingElement(pelement)) { - if (useProperties) { + if (AttributeLikeVariable *unitVarElem = + oilist->getUnitVariableParsingElement(pelement)) { + if (isa(unitVarElem)) { + body << formatv( + " result.getOrAddProperties<{1}::Properties>().{0} = true;", + unitVarElem->getName(), opCppClassName); + } else if (useProperties) { body << formatv( " result.getOrAddProperties<{1}::Properties>().{0} = " "parser.getBuilder().getUnitAttr();", - unitAttrElem->getVar()->name, opCppClassName); + unitVarElem->getName(), opCppClassName); } else { - body << " result.addAttribute(\"" << unitAttrElem->getVar()->name + body << " result.addAttribute(\"" << unitVarElem->getName() << "\", UnitAttr::get(parser.getContext()));\n"; } } else { @@ -1468,6 +1568,8 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body, (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional()); genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties, opCppClassName); + } else if (auto *prop = dyn_cast(element)) { + genPropertyParser(prop, body, opCppClassName); } else if (auto *operand = dyn_cast(element)) { ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); @@ -1876,6 +1978,38 @@ const char *enumAttrBeginPrinterCode = R"( auto caseValueStr = {1}(caseValue); )"; +/// Generate a check that an optional or default-valued attribute or property +/// has a non-default value. For these purposes, the default value of an +/// optional attribute is its presence, even if the attribute itself has a +/// default value. +static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, + AttributeVariable &attrElement) { + Attribute attr = attrElement.getVar()->attr; + std::string getter = op.getGetterName(attrElement.getVar()->name); + bool optionalAndDefault = attr.isOptional() && attr.hasDefaultValue(); + if (optionalAndDefault) + body << "("; + if (attr.isOptional()) + body << getter << "Attr()"; + if (optionalAndDefault) + body << " && "; + if (attr.hasDefaultValue()) { + FmtContext fctx; + fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())"); + body << getter << "Attr() != " + << tgfmt(attr.getConstBuilderTemplate(), &fctx, + attr.getDefaultValue()); + } + if (optionalAndDefault) + body << ")"; +} + +static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, + PropertyVariable &propElement) { + body << op.getGetterName(propElement.getVar()->name) + << "() != " << propElement.getVar()->prop.getDefaultValue(); +} + /// Generate the printer for the 'prop-dict' directive. static void genPropDictPrinter(OperationFormat &fmt, Operator &op, MethodBody &body) { @@ -1904,6 +2038,15 @@ static void genPropDictPrinter(OperationFormat &fmt, Operator &op, body << " }\n"; } } + // Similarly, elide default-valued properties. + for (const NamedProperty &prop : op.getProperties()) { + if (prop.prop.hasDefaultValue()) { + body << " if (" << op.getGetterName(prop.name) + << "() == " << prop.prop.getDefaultValue() << ") {"; + body << " elidedProps.push_back(\"" << prop.name << "\");\n"; + body << " }\n"; + } + } body << " _odsPrinter << \" \";\n" << " printProperties(this->getContext(), _odsPrinter, " @@ -2031,7 +2174,6 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element, } else if (auto *property = dyn_cast(element)) { FmtContext ctx; - ctx.addSubst("_ctxt", "getContext()"); const NamedProperty *namedProperty = property->getVar(); ctx.addSubst("_storage", "getProperties()." + namedProperty->name); body << tgfmt(namedProperty->prop.getConvertFromStorageCall(), &ctx); @@ -2154,16 +2296,6 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, " }\n"; } -/// Generate a check that a DefaultValuedAttr has a value that is non-default. -static void genNonDefaultValueCheck(MethodBody &body, const Operator &op, - AttributeVariable &attrElement) { - FmtContext fctx; - Attribute attr = attrElement.getVar()->attr; - fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())"); - body << " && " << op.getGetterName(attrElement.getVar()->name) << "Attr() != " - << tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()); -} - /// Generate the check for the anchor of an optional group. static void genOptionalGroupPrinterAnchor(FormatElement *anchor, const Operator &op, @@ -2190,17 +2322,12 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor, genOptionalGroupPrinterAnchor(element->getInputs(), op, body); }) .Case([&](AttributeVariable *element) { - Attribute attr = element->getVar()->attr; - body << op.getGetterName(element->getVar()->name) << "Attr()"; - if (attr.isOptional()) - return; // done - if (attr.hasDefaultValue()) { - // Consider a default-valued attribute as present if it's not the - // default value. - genNonDefaultValueCheck(body, op, *element); - return; - } - llvm_unreachable("attribute must be optional or default-valued"); + // Consider a default-valued attribute as present if it's not the + // default value and an optional one present if it is set. + genNonDefaultValueCheck(body, op, *element); + }) + .Case([&](PropertyVariable *element) { + genNonDefaultValueCheck(body, op, *element); }) .Case([&](CustomDirective *ele) { body << '('; @@ -2276,10 +2403,10 @@ void OperationFormat::genElementPrinter(FormatElement *element, ArrayRef thenElements = optional->getThenElements(); ArrayRef elseElements = optional->getElseElements(); FormatElement *elidedAnchorElement = nullptr; - auto *anchorAttr = dyn_cast(anchor); + auto *anchorAttr = dyn_cast(anchor); if (anchorAttr && anchorAttr != thenElements.front() && (elseElements.empty() || anchorAttr != elseElements.front()) && - anchorAttr->isUnitAttr()) { + anchorAttr->isUnit()) { elidedAnchorElement = anchorAttr; } auto genElementPrinters = [&](ArrayRef elements) { @@ -2319,13 +2446,13 @@ void OperationFormat::genElementPrinter(FormatElement *element, for (VariableElement *var : vars) { TypeSwitch(var) .Case([&](AttributeVariable *attrEle) { - body << " || (" << op.getGetterName(attrEle->getVar()->name) - << "Attr()"; - Attribute attr = attrEle->getVar()->attr; - if (attr.hasDefaultValue()) { - // Don't print default-valued attributes. - genNonDefaultValueCheck(body, op, *attrEle); - } + body << " || ("; + genNonDefaultValueCheck(body, op, *attrEle); + body << ")"; + }) + .Case([&](PropertyVariable *propEle) { + body << " || ("; + genNonDefaultValueCheck(body, op, *propEle); body << ")"; }) .Case([&](OperandVariable *ele) { @@ -2352,7 +2479,7 @@ void OperationFormat::genElementPrinter(FormatElement *element, body << ") {\n"; genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace, lastWasPunctuation); - if (oilist->getUnitAttrParsingElement(pelement) == nullptr) { + if (oilist->getUnitVariableParsingElement(pelement) == nullptr) { for (FormatElement *element : pelement) genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation); @@ -2369,7 +2496,7 @@ void OperationFormat::genElementPrinter(FormatElement *element, return; } - // Emit the attribute dictionary. + // Emit the property dictionary. if (isa(element)) { genPropDictPrinter(*this, op, body); lastWasPunctuation = false; @@ -2408,6 +2535,13 @@ void OperationFormat::genElementPrinter(FormatElement *element, else body << "_odsPrinter.printStrippedAttrOrType(" << op.getGetterName(var->name) << "Attr());\n"; + } else if (auto *property = dyn_cast(element)) { + const NamedProperty *var = property->getVar(); + FmtContext fmtContext; + fmtContext.addSubst("_printer", "_odsPrinter"); + fmtContext.addSubst("_ctxt", "getContext()"); + fmtContext.addSubst("_storage", "getProperties()." + var->name); + body << tgfmt(var->prop.getPrinterCall(), &fmtContext) << ";\n"; } else if (auto *operand = dyn_cast(element)) { if (operand->getVar()->isVariadicOfVariadic()) { body << " ::llvm::interleaveComma(" @@ -2737,6 +2871,10 @@ static bool isOptionallyParsed(FormatElement *el) { Attribute attr = attrVar->getVar()->attr; return attr.isOptional() || attr.hasDefaultValue(); } + if (auto *propVar = dyn_cast(el)) { + const Property &prop = propVar->getVar()->prop; + return prop.hasDefaultValue() && prop.hasOptionalParser(); + } if (auto *operandVar = dyn_cast(el)) { const NamedTypeConstraint *operand = operandVar->getVar(); return operand->isOptional() || operand->isVariadic() || @@ -3141,10 +3279,9 @@ OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) { } if (const NamedProperty *property = findArg(op.getProperties(), name)) { - if (ctx != CustomDirectiveContext && ctx != RefDirectiveContext) + if (ctx == TypeDirectiveContext) return emitError( - loc, "properties currently only supported in `custom` directive"); - + loc, "properties cannot be used as children to a `type` directive"); if (ctx == RefDirectiveContext) { if (!seenProperties.count(property)) return emitError(loc, "property '" + name + @@ -3428,6 +3565,15 @@ LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element, "an oilist parsing group"); return success(); }) + // Only optional properties can be within an oilist parsing group. + .Case([&](PropertyVariable *propEle) { + if (!propEle->getVar()->prop.hasDefaultValue()) + return emitError( + loc, + "only default-valued or optional properties can be used in " + "an olist parsing group"); + return success(); + }) // Only optional-like(i.e. variadic) operands can be within an // oilist parsing group. .Case([&](OperandVariable *ele) { @@ -3557,6 +3703,16 @@ LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc, "can be used to anchor an optional group"); return success(); }) + // All properties can be within the optional group, but only optional + // properties can be the anchor. + .Case([&](PropertyVariable *propEle) { + Property prop = propEle->getVar()->prop; + if (isAnchor && !(prop.hasDefaultValue() && prop.hasOptionalParser())) + return emitError(loc, "only properties with default values " + "that can be optionally parsed " + "can be used to anchor an optional group"); + return success(); + }) // Only optional-like(i.e. variadic) operands can be within an optional // group. .Case([&](OperandVariable *ele) { @@ -3649,4 +3805,4 @@ void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) { // Generate the printer and parser based on the parsed format. format.genParser(op, opClass); format.genPrinter(op, opClass); -} +} \ No newline at end of file diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp index d1227b045d4ed3ba79b47f7a400c586dd36bbe95..c7350c218a186bc812a5e4ce5c6639f9e1330010 100644 --- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp +++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp @@ -23,6 +23,8 @@ using namespace mlir; namespace { constexpr static llvm::StringLiteral kAttrName = "dltest.layout"; constexpr static llvm::StringLiteral kEndiannesKeyName = "dltest.endianness"; +constexpr static llvm::StringLiteral kDefaultKeyName = + "dltest.default_memory_space"; constexpr static llvm::StringLiteral kAllocaKeyName = "dltest.alloca_memory_space"; constexpr static llvm::StringLiteral kProgramKeyName = @@ -83,6 +85,9 @@ struct CustomDataLayoutSpec StringAttr getAllocaMemorySpaceIdentifier(MLIRContext *context) const { return Builder(context).getStringAttr(kAllocaKeyName); } + StringAttr getDefaultMemorySpaceIdentifier(MLIRContext *context) const { + return Builder(context).getStringAttr(kDefaultKeyName); + } StringAttr getProgramMemorySpaceIdentifier(MLIRContext *context) const { return Builder(context).getStringAttr(kProgramKeyName); } @@ -201,6 +206,15 @@ struct SingleQueryType return Attribute(); } + Attribute getDefaultMemorySpace(DataLayoutEntryInterface entry) { + static bool executed = false; + if (executed) + llvm::report_fatal_error("repeated call"); + + executed = true; + return Attribute(); + } + Attribute getProgramMemorySpace(DataLayoutEntryInterface entry) { static bool executed = false; if (executed) @@ -458,6 +472,7 @@ module {} EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u); EXPECT_EQ(layout.getEndianness(), Attribute()); + EXPECT_EQ(layout.getDefaultMemorySpace(), Attribute()); EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute()); EXPECT_EQ(layout.getProgramMemorySpace(), Attribute()); EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute()); @@ -490,6 +505,7 @@ TEST(DataLayout, NullSpec) { EXPECT_EQ(layout.getTypeIndexBitwidth(IndexType::get(&ctx)), 64u); EXPECT_EQ(layout.getEndianness(), Attribute()); + EXPECT_EQ(layout.getDefaultMemorySpace(), Attribute()); EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute()); EXPECT_EQ(layout.getProgramMemorySpace(), Attribute()); EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute()); @@ -530,6 +546,7 @@ TEST(DataLayout, EmptySpec) { EXPECT_EQ(layout.getTypeIndexBitwidth(IndexType::get(&ctx)), 64u); EXPECT_EQ(layout.getEndianness(), Attribute()); + EXPECT_EQ(layout.getDefaultMemorySpace(), Attribute()); EXPECT_EQ(layout.getAllocaMemorySpace(), Attribute()); EXPECT_EQ(layout.getProgramMemorySpace(), Attribute()); EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute()); @@ -552,6 +569,7 @@ TEST(DataLayout, SpecWithEntries) { #dlti.dl_entry, #dlti.dl_entry, #dlti.dl_entry<"dltest.endianness", "little">, + #dlti.dl_entry<"dltest.default_memory_space", 1 : i32>, #dlti.dl_entry<"dltest.alloca_memory_space", 5 : i32>, #dlti.dl_entry<"dltest.program_memory_space", 3 : i32>, #dlti.dl_entry<"dltest.global_memory_space", 2 : i32>, @@ -588,6 +606,7 @@ TEST(DataLayout, SpecWithEntries) { EXPECT_EQ(layout.getTypePreferredAlignment(Float32Type::get(&ctx)), 64u); EXPECT_EQ(layout.getEndianness(), Builder(&ctx).getStringAttr("little")); + EXPECT_EQ(layout.getDefaultMemorySpace(), Builder(&ctx).getI32IntegerAttr(1)); EXPECT_EQ(layout.getAllocaMemorySpace(), Builder(&ctx).getI32IntegerAttr(5)); EXPECT_EQ(layout.getProgramMemorySpace(), Builder(&ctx).getI32IntegerAttr(3)); EXPECT_EQ(layout.getGlobalMemorySpace(), Builder(&ctx).getI32IntegerAttr(2));