diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 7a18a834410d927bcdf31ba6db4f69e8e9e384c7..21bcd4f3bf612b96f77c72dc6ead25bef1e1d79f 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -826,6 +826,8 @@ public:
   bool isProfitableToHoist(Instruction *I) const;
 
   bool isProfitableToLoopVersioning() const;
+
+  bool isProfitableToLoopPrefetch() const;
 
   bool useAA() const;
@@ -1798,6 +1800,7 @@ public:
   virtual bool isTruncateFree(Type *Ty1, Type *Ty2) = 0;
   virtual bool isProfitableToHoist(Instruction *I) = 0;
   virtual bool isProfitableToLoopVersioning() = 0;
+  virtual bool isProfitableToLoopPrefetch() = 0;
   virtual bool useAA() = 0;
   virtual bool isTypeLegal(Type *Ty) = 0;
   virtual unsigned getRegUsageForType(Type *Ty) = 0;
@@ -2291,6 +2294,9 @@ public:
   bool isProfitableToLoopVersioning() override {
     return Impl.isProfitableToLoopVersioning();
   }
+  bool isProfitableToLoopPrefetch() override {
+    return Impl.isProfitableToLoopPrefetch();
+  }
   bool useAA() override { return Impl.useAA(); }
   bool isTypeLegal(Type *Ty) override { return Impl.isTypeLegal(Ty); }
   unsigned getRegUsageForType(Type *Ty) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 14e635127b59d4f8814c5526bd5b1c888a3e1019..a345d337e56633f769a4fc8271969291e0b43f84 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -327,6 +327,8 @@ public:
   bool isProfitableToLoopVersioning() const { return false; }
 
+  bool isProfitableToLoopPrefetch() const { return false; }
+
   bool useAA() const { return false; }
 
   bool isTypeLegal(Type *Ty) const { return false; }
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index b8b74f6ab278b1e3a1d3eadbd75bdd22b33fe62b..bcc217414a4a14c97694f03dedc6d00db4b98d04 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -409,6 +409,8 @@ public:
   bool isProfitableToLoopVersioning() const { return false; }
 
+  bool isProfitableToLoopPrefetch() const { return false; }
+
   bool useAA() const { return getST()->useAA(); }
 
   bool isTypeLegal(Type *Ty) {
diff --git a/llvm/include/llvm/Transforms/Scalar/LoopIterationPrefetchBefore.h b/llvm/include/llvm/Transforms/Scalar/LoopIterationPrefetchBefore.h
new file mode 100644
index 0000000000000000000000000000000000000000..4e19782b03912d06d24fc155cf1ba0255e0c8aff
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Scalar/LoopIterationPrefetchBefore.h
@@ -0,0 +1,35 @@
+//===------ LoopIterationPrefetchBefore.h - Loop Iteration Prefetch Before Pass ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------------------===//
+//
+/// \file
+/// This file provides the interface of the Loop Iteration Prefetch Before pass.
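+///
+/// Typical invocation, following the tests added in this patch (the pass is
+/// additionally gated by TTI::isProfitableToLoopPrefetch on the target):
+///
+///   opt -passes=loop-iteration-prefetch-before \
+///       -loop-iteration-prefetch-before-funcs=foo -S input.ll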
+//
+//===---------------------------------------------------------------------------------===//
+#ifndef LLVM_TRANSFORMS_SCALAR_LLVM_LOOPITERATIONPREFETCH_BEFORE_H
+#define LLVM_TRANSFORMS_SCALAR_LLVM_LOOPITERATIONPREFETCH_BEFORE_H
+
+#include "llvm/IR/PassManager.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/Utils.h"
+
+namespace llvm {
+
+extern cl::opt<bool> EnableLoopIterationPrefetchBefore;
+
+class LoopIterationPrefetchBefore
+    : public PassInfoMixin<LoopIterationPrefetchBefore> {
+  static bool shouldRunOnFunction(Function &F);
+
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif // LLVM_TRANSFORMS_SCALAR_LLVM_LOOPITERATIONPREFETCH_BEFORE_H
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 5b91415bcb378c73ae154335f8c4d89df6753974..a19cdd8b4debfafb8b2581e712283e03f4521692 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -523,6 +523,10 @@ bool TargetTransformInfo::isProfitableToLoopVersioning() const {
   return TTIImpl->isProfitableToLoopVersioning();
 }
 
+bool TargetTransformInfo::isProfitableToLoopPrefetch() const {
+  return TTIImpl->isProfitableToLoopPrefetch();
+}
+
 bool TargetTransformInfo::useAA() const { return TTIImpl->useAA(); }
 
 bool TargetTransformInfo::isTypeLegal(Type *Ty) const {
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 5368db256cd909fae2f4aa117a1de83daa207c6f..778d3f0d2c026e431c9ec2dd721c67741fdf4c13 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -277,6 +277,8 @@
 #include "llvm/Transforms/Instrumentation/AI4CAnalysis.h"
 #endif
 
+#include "llvm/Transforms/Scalar/LoopIterationPrefetchBefore.h"
+
 using namespace llvm;
 
 static const Regex DefaultAliasRegex(
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index c9f3512da32b22083c835fb8eeecfafa4a38ed37..19c9ec6d19b837e1cf12150c6d17b8e9f97ab14e 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -98,6 +98,7 @@
 #include "llvm/Transforms/Scalar/LoopIdiomRecognize.h"
 #include "llvm/Transforms/Scalar/LoopInstSimplify.h"
 #include "llvm/Transforms/Scalar/LoopInterchange.h"
+#include "llvm/Transforms/Scalar/LoopIterationPrefetchBefore.h"
 #include "llvm/Transforms/Scalar/LoopLoadElimination.h"
 #include "llvm/Transforms/Scalar/LoopPassManager.h"
 #include "llvm/Transforms/Scalar/LoopRotation.h"
@@ -1254,6 +1255,9 @@ PassBuilder::buildModuleSimplificationPipeline(OptimizationLevel Level,
   if (EnableSyntheticCounts && !PGOOpt)
     MPM.addPass(SyntheticCountsPropagation());
 
+  if (EnableLoopIterationPrefetchBefore)
+    MPM.addPass(
+        createModuleToFunctionPassAdaptor(LoopIterationPrefetchBefore()));
+
 #if defined(ENABLE_AUTOTUNER)
 #if defined(ENABLE_ACPO)
   if (!PGOOpt && EnableACPOBWModel)
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 7dfa618b4506c13a73e679afddbad918ef6ba7d5..d86fecc63a3b7ebb04004236d74b6f8f9d40fbd5 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -370,6 +370,7 @@ FUNCTION_PASS("guard-widening", GuardWideningPass())
 FUNCTION_PASS("load-store-vectorizer", LoadStoreVectorizerPass())
 FUNCTION_PASS("loop-simplify", LoopSimplifyPass())
 FUNCTION_PASS("loop-sink", LoopSinkPass())
+FUNCTION_PASS("loop-iteration-prefetch-before", LoopIterationPrefetchBefore())
 FUNCTION_PASS("lowerinvoke", LowerInvokePass())
 FUNCTION_PASS("lowerswitch", LowerSwitchPass())
 FUNCTION_PASS("mem2reg", PromotePass())
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index e7927352a5ca7887d242f8ed2962a759331f4a9c..e2ed21f856b46389a84f1440e2c7c268c138dcc3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -466,6 +466,14 @@ bool AArch64TTIImpl::isProfitableToLoopVersioning() const {
   return ST->isHiSiliconProc() || ForceEnableExperimentalOpt;
 }
 
+bool AArch64TTIImpl::isProfitableToLoopPrefetch() const {
+  // Proven to work well on HiSilicon processors. You can enable the
+  // optimization experimentally on other platforms with
+  // -mllvm -force-enable-experimental-optimization if you want to
+  // test it there.
+  return ST->isHiSiliconProc() || ForceEnableExperimentalOpt;
+}
+
 InstructionCost
 AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                       TTI::TargetCostKind CostKind) {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index be62dabaa8c82966bb83f51e8c2c3f44f631a76a..5eed5bbd4db90ad69eb5ce30097731efef9f620a 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -98,6 +98,8 @@ public:
   bool isProfitableToLoopVersioning() const;
 
+  bool isProfitableToLoopPrefetch() const;
+
   /// @}
 
   /// \name Vector TTI Implementations
diff --git a/llvm/lib/Transforms/Scalar/CMakeLists.txt b/llvm/lib/Transforms/Scalar/CMakeLists.txt
index e5a82ea8f923ff5527626d1ff900ebbc565f5b29..d3060f02374b6b6bf30a763104ad79a0d67c2a32 100644
--- a/llvm/lib/Transforms/Scalar/CMakeLists.txt
+++ b/llvm/lib/Transforms/Scalar/CMakeLists.txt
@@ -36,6 +36,7 @@ add_llvm_component_library(LLVMScalarOpts
   LoopIdiomRecognize.cpp
   LoopInstSimplify.cpp
   LoopInterchange.cpp
+  LoopIterationPrefetchBefore.cpp
   LoopFlatten.cpp
   LoopLoadElimination.cpp
   LoopPassManager.cpp
diff --git a/llvm/lib/Transforms/Scalar/LoopIterationPrefetchBefore.cpp b/llvm/lib/Transforms/Scalar/LoopIterationPrefetchBefore.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6407d90eeadf318d75293bd80d31a221f3902f51
--- /dev/null
+++ b/llvm/lib/Transforms/Scalar/LoopIterationPrefetchBefore.cpp
@@ -0,0 +1,392 @@
+//===------ LoopIterationPrefetchBefore.cpp - Loop Iteration Prefetch Before Pass ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-----------------------------------------------------------------------------------===//
+//
+// This file implements the Loop Iteration Prefetch Before pass. For each
+// selected function, the pass emits a prefetch loop in front of every
+// eligible innermost loop; that prefetch loop issues prefetching calls
+// covering all data loaded in the original loop. The functions to transform
+// are named via the -loop-iteration-prefetch-before-funcs command-line
+// option.
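+//
+// As an illustrative sketch (hand-written here, not the pass's literal
+// output), a selected loop such as
+//
+//   for (i = 0; i < n; ++i)
+//     sum += x[i];
+//
+// conceptually becomes
+//
+//   for (j = 0; j < total_bytes / sizeof(float); j += cacheline / sizeof(float))
+//     prefetch(&x[j]);                    // emitted as llvm.prefetch
+//   for (i = 0; i < n; ++i)
+//     sum += x[i];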
+//
+//===-----------------------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LoopIterationPrefetchBefore.h"
+
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
+
+#include <cstdlib>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#define DEBUG_TYPE "loop-iteration-prefetch-before"
+
+using namespace llvm;
+
+namespace llvm {
+
+using LoopIterPrefBfArgType = std::string;
+struct LoopIterPrefBfArgParser : public cl::parser<LoopIterPrefBfArgType> {
+  explicit LoopIterPrefBfArgParser(cl::Option &O)
+      : cl::parser<LoopIterPrefBfArgType>(O) {}
+  bool parse(cl::Option &O, StringRef ArgName, StringRef ArgValue,
+             LoopIterPrefBfArgType &Val) {
+    StringRef Func = ArgValue.trim();
+    if (Func.empty())
+      return O.error("Invalid argument '" + ArgValue +
+                     "'. The input value should not be empty.");
+    Val = Func.str();
+    return false;
+  }
+};
+
+cl::opt<bool> EnableLoopIterationPrefetchBefore(
+    "enable-loop-iteration-prefetch-before", cl::init(false), cl::ReallyHidden,
+    cl::desc("Enable Loop Iteration Prefetch Before Pass"));
+
+static std::vector<LoopIterPrefBfArgType> LoopIterationPrefetchBeforeArgs;
+static cl::list<LoopIterPrefBfArgType, std::vector<LoopIterPrefBfArgType>,
+                LoopIterPrefBfArgParser>
+    LoopIterationPrefetchBeforeCmdArgs(
+        "loop-iteration-prefetch-before-funcs",
+        cl::desc("Specify functions for prefetching before"),
+        cl::CommaSeparated, cl::Hidden,
+        cl::location(LoopIterationPrefetchBeforeArgs),
+        cl::callback([](const LoopIterPrefBfArgType &) {
+          EnableLoopIterationPrefetchBefore = true;
+        }));
+
+bool LoopIterationPrefetchBefore::shouldRunOnFunction(Function &F) {
+  static std::unordered_set<std::string> M(
+      LoopIterationPrefetchBeforeArgs.begin(),
+      LoopIterationPrefetchBeforeArgs.end());
+
+  auto I = M.find(F.getName().str());
+  if (I == M.end())
+    return false;
+  M.erase(I); // Ensure each function is processed only once.
+  return true;
+}
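+
+// One record per prefetched address: the source memory instruction, its
+// pointer operand, the loop-invariant base address (materialized later in
+// emitPrefetchBasicBlocksBeforeLoop), and the pointer's add-rec SCEV.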
+struct LoopIterPrefBfInst {
+  Instruction *MemI;
+  Value *PtrValue;
+  Value *BaseValue;
+  const SCEVAddRecExpr *LSCEVAddRec;
+  LoopIterPrefBfInst(Instruction *I, Value *Val, const SCEVAddRecExpr *E)
+      : MemI{I}, PtrValue{Val}, BaseValue{nullptr}, LSCEVAddRec{E} {}
+};
+
+class LoopIterPrefBfRecord {
+  Loop *L;
+  SmallVector<LoopIterPrefBfInst> Prefetches;
+  Instruction *PrefInsertPt;
+  Value *PrefLoopIV;
+
+  void initializePrefetches(Function &F, FunctionAnalysisManager &AM);
+
+public:
+  bool LoopReady;
+  LoopIterPrefBfRecord(Loop *LP, Function &F, FunctionAnalysisManager &AM)
+      : L{LP}, PrefInsertPt{nullptr}, PrefLoopIV{nullptr} {
+    initializePrefetches(F, AM);
+  }
+
+  bool empty() const { return Prefetches.empty(); }
+  void emitPrefetchBasicBlocksBeforeLoop(Function &F,
+                                         FunctionAnalysisManager &AM);
+  void emitPrefetchCallsBeforeLoop(Function &F, FunctionAnalysisManager &AM);
+};
+
+void LoopIterPrefBfRecord::initializePrefetches(Function &F,
+                                                FunctionAnalysisManager &AM) {
+  ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
+  TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
+  // Assume a 64-byte cache line when TTI does not provide a size.
+  int64_t CacheLineSize =
+      TTI.getCacheLineSize() > 0 ? TTI.getCacheLineSize() : 64;
+
+  // Find all load instructions that should be prefetched in the loop and
+  // populate the `Prefetches` member.
+  for (const auto *BB : L->blocks()) {
+    for (auto &I : *BB) {
+      LoadInst *MemI = dyn_cast<LoadInst>(&I);
+      Value *PtrValue = nullptr;
+      if (MemI != nullptr) {
+        PtrValue = MemI->getPointerOperand();
+      } else {
+        IntrinsicInst *MemIIntrinsic = dyn_cast<IntrinsicInst>(&I);
+        if (MemIIntrinsic == nullptr)
+          continue;
+        switch (MemIIntrinsic->getIntrinsicID()) {
+        case Intrinsic::aarch64_neon_ld1x2:
+        case Intrinsic::aarch64_neon_ld1x3:
+        case Intrinsic::aarch64_neon_ld1x4:
+        case Intrinsic::aarch64_neon_ld2:
+        case Intrinsic::aarch64_neon_ld3:
+        case Intrinsic::aarch64_neon_ld4:
+        case Intrinsic::aarch64_neon_ld2lane:
+        case Intrinsic::aarch64_neon_ld3lane:
+        case Intrinsic::aarch64_neon_ld4lane:
+        case Intrinsic::aarch64_neon_ld2r:
+        case Intrinsic::aarch64_neon_ld3r:
+        case Intrinsic::aarch64_neon_ld4r:
+        case Intrinsic::masked_load: // SVE load
+        {
+          PtrValue = MemIIntrinsic->getArgOperand(0);
+          break;
+        }
+        default:
+          continue;
+        }
+      }
+
+      unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
+      if (!TTI.shouldPrefetchAddressSpace(PtrAddrSpace))
+        continue;
+      if (L->isLoopInvariant(PtrValue))
+        continue;
+
+      const SCEV *LSCEV = SE.getSCEV(PtrValue);
+      const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
+      if (LSCEVAddRec == nullptr)
+        continue;
+
+      // Skip pointers that fall within a cache line of an already recorded
+      // prefetch.
+      bool DupPref = false;
+      for (auto &Pref : Prefetches) {
+        const SCEV *PtrDiff = SE.getMinusSCEV(LSCEVAddRec, Pref.LSCEVAddRec);
+        if (const SCEVConstant *ConstPtrDiff =
+                dyn_cast<SCEVConstant>(PtrDiff)) {
+          int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
+          if (PD < CacheLineSize) {
+            DupPref = true;
+            break;
+          }
+        }
+      }
+      if (!DupPref)
+        // Record &I (not MemI) so intrinsic loads keep a valid instruction.
+        Prefetches.push_back(LoopIterPrefBfInst(&I, PtrValue, LSCEVAddRec));
+    }
+  }
+}
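+
+// Emit the skeleton of the prefetch loop in front of the original loop.
+// Schematically, the control flow becomes:
+//
+//   preheader -> new prefetch loop -> new preheader -> original header
+//
+// The prefetch calls themselves are inserted later by
+// emitPrefetchCallsBeforeLoop, once the new blocks exist.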
+void LoopIterPrefBfRecord::emitPrefetchBasicBlocksBeforeLoop(
+    Function &F, FunctionAnalysisManager &AM) {
+  ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
+  TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
+  // Assume a 64-byte cache line when TTI does not provide a size.
+  int64_t CacheLineSize =
+      TTI.getCacheLineSize() > 0 ? TTI.getCacheLineSize() : 64;
+
+  LoopReady = false;
+
+  const DataLayout &DL = F.getParent()->getDataLayout();
+  Type *FloatTy = Type::getFloatTy(F.getContext());
+  uint64_t FloatByteLength = DL.getTypeAllocSize(FloatTy);
+
+  BasicBlock *PreHeader = L->getLoopPreheader();
+  SCEVExpander SCEVE(SE, PreHeader->getModule()->getDataLayout(),
+                     "pref.loopIV");
+  const SCEV *Step = nullptr;
+  // Find the base address for each MemI.
+  for (auto &P : Prefetches) {
+    P.BaseValue = P.PtrValue;
+    const SCEV *LSCEV = SE.getSCEV(P.BaseValue);
+    const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
+    const SCEV *Start = LSCEVAddRec->getStart();
+    // If the start address is not loop invariant, give up on this
+    // instruction.
+    if (!SE.isLoopInvariant(Start, L)) {
+      P.BaseValue = nullptr;
+      continue;
+    }
+    P.BaseValue = SCEVE.expandCodeFor(Start, Start->getType(),
+                                      PreHeader->getTerminator());
+
+    // Assume the first step is the real step.
+    if (Step == nullptr)
+      Step = LSCEVAddRec->getStepRecurrence(SE);
+  }
+
+  // Compute the bound of the new loop.
+  Value *LoopBound = nullptr;
+  LLVMContext &Context = F.getContext();
+  const SCEV *BackedgeTakenCount = SE.getBackedgeTakenCount(L);
+  if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
+      SE.isLoopInvariant(BackedgeTakenCount, L)) {
+    // The new loop prefetches CacheLineSize bytes per iteration. Since the
+    // data is addressed as float, the new loop step is
+    // CacheLineSize / sizeof(float) floats per iteration, and the new loop
+    // bound is the total number of floats to prefetch, i.e.
+    // TotalBytes / sizeof(float), where TotalBytes is
+    // Step * BackedgeTakenCount of the original loop.
+    const SCEV *NextLSCEV =
+        SE.getUDivExpr(SE.getMulExpr(Step, BackedgeTakenCount),
+                       SE.getConstant(Type::getInt64Ty(Context),
+                                      DL.getTypeAllocSize(FloatTy)));
+
+    LoopBound = SCEVE.expandCodeFor(NextLSCEV, Type::getInt64Ty(Context),
+                                    PreHeader->getTerminator());
+  } else {
+    return;
+  }
+
+  // Create and insert the new loop basic blocks.
+  BasicBlock *Header = L->getHeader();
+
+  BasicBlock *NewLoopBB = BasicBlock::Create(Context, "", &F, Header);
+  BasicBlock *NewPreheadBB = BasicBlock::Create(Context, "", &F, Header);
+  BranchInst *PreheaderTermBranch =
+      cast<BranchInst>(PreHeader->getTerminator());
+  PreheaderTermBranch->setSuccessor(0, NewLoopBB);
+
+  IRBuilder<> Builder(NewLoopBB);
+
+  // Incoming values from the old preheader now flow through the new one.
+  for (PHINode &PHI : Header->phis()) {
+    for (unsigned i = 0, nums = PHI.getNumIncomingValues(); i < nums; i++) {
+      if (PHI.getIncomingBlock(i) == PreHeader)
+        PHI.setIncomingBlock(i, NewPreheadBB);
+    }
+  }
+
+  Builder.SetInsertPoint(NewLoopBB);
+
+  PHINode *IndVar = Builder.CreatePHI(Type::getInt64Ty(Context), 2);
+
+  uint64_t LoopStep = (uint64_t)CacheLineSize / FloatByteLength;
+  if (LoopStep == 0)
+    LoopStep = 1;
+  auto *IndVarNext = Builder.CreateAdd(
+      IndVar, ConstantInt::get(Type::getInt64Ty(Context), LoopStep),
+      IndVar->getName() + ".next", /*HasNUW=*/true, /*HasNSW=*/true);
+
+  Value *LoopCond = Builder.CreateICmpULT(IndVarNext, LoopBound);
+  Builder.CreateCondBr(LoopCond, NewLoopBB, NewPreheadBB);
+
+  IndVar->addIncoming(ConstantInt::get(Type::getInt64Ty(Context), 0),
+                      PreHeader);
+  IndVar->addIncoming(IndVarNext, NewLoopBB);
+
+  IRBuilder<> BuilderForPrehead(NewPreheadBB);
+  BuilderForPrehead.CreateBr(Header);
+
+  PrefInsertPt = NewLoopBB->getFirstNonPHI();
+  PrefLoopIV = IndVar;
+  LoopReady = true;
+}
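+
+// Emit one llvm.prefetch per iteration of the new loop for each surviving
+// base address. The intrinsic arguments are: rw = 0 (read), locality = 3
+// (high, keep in cache), cache type = 1 (data cache).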
+void LoopIterPrefBfRecord::emitPrefetchCallsBeforeLoop(
+    Function &F, FunctionAnalysisManager &AM) {
+  IRBuilder<> Builder(F.getContext());
+  for (auto &P : Prefetches) {
+    if (P.BaseValue == nullptr)
+      continue;
+    Builder.SetInsertPoint(PrefInsertPt);
+    Type *FloatTy = Type::getFloatTy(F.getContext());
+    Value *PtrForLoad = Builder.CreateGEP(FloatTy, P.BaseValue, PrefLoopIV);
+
+    Function *PrefetchFunc = Intrinsic::getDeclaration(
+        F.getParent(), Intrinsic::prefetch, PtrForLoad->getType());
+
+    Type *I32 = Type::getInt32Ty(F.getContext());
+    Builder.CreateCall(PrefetchFunc,
+                       {PtrForLoad, ConstantInt::get(I32, 0),
+                        ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
+  }
+}
+
+PreservedAnalyses
+LoopIterationPrefetchBefore::run(Function &F, FunctionAnalysisManager &AM) {
+  TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
+  if (!TTI.isProfitableToLoopPrefetch())
+    return PreservedAnalyses::all();
+
+  if (!shouldRunOnFunction(F))
+    return PreservedAnalyses::all();
+
+  LLVM_DEBUG(dbgs() << "Try LoopIterationPrefetchBefore on Function "
+                    << F.getName() << "\n");
+
+  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
+  ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
+  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
+
+  // Ensure loops are in simplified form as a prerequisite for this pass.
+  bool Changed = false;
+  for (Loop *L : LI)
+    Changed |=
+        simplifyLoop(L, &DT, &LI, &SE, &AC, nullptr, /*PreserveLCSSA=*/false);
+
+  // Collect the loops in front of which a prefetch loop should be inserted.
+  SmallVector<LoopIterPrefBfRecord> LoopIterationPrefetches;
+  for (Loop *I : LI) {
+    for (Loop *L : depth_first(I)) {
+      LLVM_DEBUG(dbgs() << "Try LoopIterationPrefetchBefore on Loop "
+                        << L->getName() << "\n");
+
+      if (!L->isInnermost()) {
+        LLVM_DEBUG(dbgs() << L->getName() << " is not an innermost loop\n");
+        continue;
+      }
+
+      if (!L->isLoopSimplifyForm()) {
+        LLVM_DEBUG(dbgs() << L->getName()
+                          << " is not in loop-simplify form\n");
+        continue;
+      }
+
+      // Find all load instructions that should be prefetched.
+      LoopIterPrefBfRecord PrefRecord(L, F, AM);
+      if (PrefRecord.empty()) {
+        LLVM_DEBUG(dbgs() << L->getName() << " has no prefetchable loads\n");
+        continue;
+      }
+      LoopIterationPrefetches.push_back(PrefRecord);
+    }
+  }
+
+  if (LoopIterationPrefetches.empty())
+    return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
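+
+  // Emit the block skeletons first, invalidate the analyses (the CFG has
+  // changed), and only then emit the prefetch calls into the new blocks.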
+  for (auto &Record : LoopIterationPrefetches)
+    Record.emitPrefetchBasicBlocksBeforeLoop(F, AM);
+
+  AM.invalidate(F, PreservedAnalyses::none());
+
+  for (auto &Record : LoopIterationPrefetches) {
+    if (Record.LoopReady)
+      Record.emitPrefetchCallsBeforeLoop(F, AM);
+  }
+
+  return PreservedAnalyses::none();
+}
+
+} // namespace llvm
diff --git a/llvm/test/Transforms/LoopIterationPrefetchBefore/basic-neon.ll b/llvm/test/Transforms/LoopIterationPrefetchBefore/basic-neon.ll
new file mode 100644
index 0000000000000000000000000000000000000000..c895ab5be954e6d9d2bee2187a0679b3ed92016a
--- /dev/null
+++ b/llvm/test/Transforms/LoopIterationPrefetchBefore/basic-neon.ll
@@ -0,0 +1,88 @@
+; REQUIRES: aarch64-registered-target
+; RUN: opt -mtriple aarch64-linux-gnu -passes=loop-iteration-prefetch-before \
+; RUN:   -force-enable-experimental-optimization=true \
+; RUN:   -loop-iteration-prefetch-before-funcs=foo -S < %s | FileCheck %s
+
+define float @foo(ptr noundef %x, i64 noundef %d) {
+; CHECK-LABEL: define float @foo
+; CHECK: while.body.lr.ph:
+; CHECK-NEXT: %add.ptr = getelementptr inbounds float, ptr %x, i64 %d
+; CHECK-NEXT: [[TMP0:%.*]] = add i64 %d, -16
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[TMP0]], 4
+; CHECK-NEXT: [[TMP2:%.*]] = shl i64 [[TMP1]], 6
+; CHECK-NEXT: [[TMP3:%.*]] = lshr i64 [[TMP2]], 2
+; CHECK-NEXT: br label %[[TMP4:.*]]
+
+; CHECK: [[TMP4]]:
+; CHECK-NEXT: [[TMP5:%.*]] = phi i64 [ 0, %while.body.lr.ph ], [ [[NEXT:%.*]], %[[TMP4]] ]
+; CHECK-NEXT: [[TMP6:%.*]] = getelementptr float, ptr %x, i64 [[TMP5]]
+; CHECK-NEXT: call void @llvm.prefetch.p0(ptr [[TMP6]], i32 0, i32 3, i32 1)
+; CHECK-NEXT: [[NEXT]] = add nuw nsw i64 [[TMP5]], 16
+; CHECK-NEXT: [[TMP7:%.*]] = icmp ult i64 [[NEXT]], [[TMP3]]
+; CHECK-NEXT: br i1 [[TMP7]], label %[[TMP4]], label %[[TMP8:.*]]
+
+; CHECK: [[TMP8]]:
+; CHECK-NEXT: br label %while.body
+
+; CHECK: while.body:
+; CHECK-NEXT: %d.addr.024 = phi i64 [ %d, %[[TMP8]] ], [ %sub, %while.body ]
+; CHECK-NEXT: %sum_.023 = phi <4 x float> [ zeroinitializer, %[[TMP8]] ], [ %add.i, %while.body ]
+; CHECK-NEXT: %idx.neg = sub i64 0, %d.addr.024
+; CHECK-NEXT: %add.ptr1 = getelementptr inbounds float, ptr %add.ptr, i64 %idx.neg
+; CHECK-NEXT: %vld1xN = tail call { <4 x float>, <4 x float>, <4 x float>, <4 x float> } @llvm.aarch64.neon.ld1x4.v4f32.p0(ptr nonnull %add.ptr1)
+; CHECK-NEXT: %vld1xN.fca.0.extract = extractvalue { <4 x float>, <4 x float>, <4 x float>, <4 x float> } %vld1xN, 0
+; CHECK-NEXT: %vld1xN.fca.1.extract = extractvalue { <4 x float>, <4 x float>, <4 x float>, <4 x float> } %vld1xN, 1
+; CHECK-NEXT: %vld1xN.fca.2.extract = extractvalue { <4 x float>, <4 x float>, <4 x float>, <4 x float> } %vld1xN, 2
+; CHECK-NEXT: %vld1xN.fca.3.extract = extractvalue { <4 x float>, <4 x float>, <4 x float>, <4 x float> } %vld1xN, 3
+; CHECK-NEXT: %add.i21 = fadd contract <4 x float> %sum_.023, %vld1xN.fca.0.extract
+; CHECK-NEXT: %add.i20 = fadd contract <4 x float> %vld1xN.fca.1.extract, %add.i21
+; CHECK-NEXT: %add.i19 = fadd contract <4 x float> %vld1xN.fca.2.extract, %add.i20
+; CHECK-NEXT: %add.i = fadd contract <4 x float> %vld1xN.fca.3.extract, %add.i19
+; CHECK-NEXT: %sub = add i64 %d.addr.024, -16
+; CHECK-NEXT: %cmp = icmp ugt i64 %sub, 15
+; CHECK-NEXT: br i1 %cmp, label %while.body, label %[[LOOPEXIT:.*]]
+
+; CHECK: [[LOOPEXIT]]:
+; CHECK-NEXT: br label %while.end
+
+; CHECK: while.end:
+; CHECK-NEXT: %sum_.0.lcssa = phi <4 x float> [ zeroinitializer, %entry ], [ %add.i, %[[LOOPEXIT]] ]
+; CHECK-NEXT: %vaddvq_f32.i = tail call contract float @llvm.aarch64.neon.faddv.f32.v4f32(<4 x float> %sum_.0.lcssa)
+; CHECK-NEXT: ret float %vaddvq_f32.i
+
+entry:
+  %cmp22 = icmp ugt i64 %d, 15
+  br i1 %cmp22, label %while.body.lr.ph, label %while.end
+
+while.body.lr.ph:                                 ; preds = %entry
+  %add.ptr = getelementptr inbounds float, ptr %x, i64 %d
+  br label %while.body
+
+while.body:                                       ; preds = %while.body.lr.ph, %while.body
+  %d.addr.024 = phi i64 [ %d, %while.body.lr.ph ], [ %sub, %while.body ]
+  %sum_.023 = phi <4 x float> [ zeroinitializer, %while.body.lr.ph ], [ %add.i, %while.body ]
+  %idx.neg = sub i64 0, %d.addr.024
+  %add.ptr1 = getelementptr inbounds float, ptr %add.ptr, i64 %idx.neg
+  %vld1xN = tail call { <4 x float>, <4 x float>, <4 x float>, <4 x float> } @llvm.aarch64.neon.ld1x4.v4f32.p0(ptr nonnull %add.ptr1)
+  %vld1xN.fca.0.extract = extractvalue { <4 x float>, <4 x float>, <4 x float>, <4 x float> } %vld1xN, 0
+  %vld1xN.fca.1.extract = extractvalue { <4 x float>, <4 x float>, <4 x float>, <4 x float> } %vld1xN, 1
+  %vld1xN.fca.2.extract = extractvalue { <4 x float>, <4 x float>, <4 x float>, <4 x float> } %vld1xN, 2
+  %vld1xN.fca.3.extract = extractvalue { <4 x float>, <4 x float>, <4 x float>, <4 x float> } %vld1xN, 3
+  %add.i21 = fadd contract <4 x float> %sum_.023, %vld1xN.fca.0.extract
+  %add.i20 = fadd contract <4 x float> %vld1xN.fca.1.extract, %add.i21
+  %add.i19 = fadd contract <4 x float> %vld1xN.fca.2.extract, %add.i20
+  %add.i = fadd contract <4 x float> %vld1xN.fca.3.extract, %add.i19
+  %sub = add i64 %d.addr.024, -16
+  %cmp = icmp ugt i64 %sub, 15
+  br i1 %cmp, label %while.body, label %while.end
+
+while.end:                                        ; preds = %while.body, %entry
+  %sum_.0.lcssa = phi <4 x float> [ zeroinitializer, %entry ], [ %add.i, %while.body ]
+  %vaddvq_f32.i = tail call contract float @llvm.aarch64.neon.faddv.f32.v4f32(<4 x float> %sum_.0.lcssa)
+  ret float %vaddvq_f32.i
+}
+
+declare { <4 x float>, <4 x float>, <4 x float>, <4 x float> } @llvm.aarch64.neon.ld1x4.v4f32.p0(ptr)
+
+declare float @llvm.aarch64.neon.faddv.f32.v4f32(<4 x float>)
diff --git a/llvm/test/Transforms/LoopIterationPrefetchBefore/basic-sve.ll b/llvm/test/Transforms/LoopIterationPrefetchBefore/basic-sve.ll
new file mode 100644
index 0000000000000000000000000000000000000000..3fa381321ef4513c0ae5603339a43b2fb3967c8d
--- /dev/null
+++ b/llvm/test/Transforms/LoopIterationPrefetchBefore/basic-sve.ll
@@ -0,0 +1,81 @@
+; REQUIRES: aarch64-registered-target
+; RUN: opt -mtriple aarch64-linux-gnu -passes=loop-iteration-prefetch-before \
+; RUN:   -force-enable-experimental-optimization=true \
+; RUN:   -loop-iteration-prefetch-before-funcs=foo -S < %s | FileCheck %s
+
+define float @foo(ptr nocapture noundef readonly %x, i64 noundef %d) {
+; CHECK-LABEL: define float @foo
+; CHECK: entry:
+; CHECK-NEXT: %cmp5.not = icmp eq i64 %d, 0
+; CHECK-NEXT: br i1 %cmp5.not, label %for.cond.cleanup, label %[[PREHEADER:.*]]
+
+; CHECK: [[PREHEADER]]:
+; CHECK-NEXT: [[TMP0:%.*]] = add i64 %d, -1
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[TMP0]], 2
+; CHECK-NEXT: [[TMP2:%.*]] = shl i64 [[TMP1]], 4
+; CHECK-NEXT: [[TMP3:%.*]] = lshr i64 [[TMP2]], 2
+; CHECK-NEXT: br label %[[TMP6:.*]]
+
+; CHECK: [[LOOPEXIT:.*]]:
+; CHECK-NEXT: br label %for.cond.cleanup
+
+; CHECK: for.cond.cleanup:
+; CHECK-NEXT: %sum.0.lcssa = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ [[TMP13:%.*]], %[[LOOPEXIT]] ]
+; CHECK-NEXT: [[TMP4:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+; CHECK-NEXT: [[TMP5:%.*]] = tail call contract float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> [[TMP4]], <vscale x 4 x float> %sum.0.lcssa)
+; CHECK-NEXT: ret float [[TMP5]]
+
+; CHECK: [[TMP6]]:
+; CHECK-NEXT: [[TMP7:%.*]] = phi i64 [ 0, %[[PREHEADER]] ], [ [[NEXT:%.*]], %[[TMP6]] ]
+; CHECK-NEXT: [[TMP8:%.*]] = getelementptr float, ptr %x, i64 [[TMP7]]
+; CHECK-NEXT: call void @llvm.prefetch.p0(ptr [[TMP8]], i32 0, i32 3, i32 1)
+; CHECK-NEXT: [[NEXT]] = add nuw nsw i64 [[TMP7]], 16
+; CHECK-NEXT: [[TMP9:%.*]] = icmp ult i64 [[NEXT]], [[TMP3]]
+; CHECK-NEXT: br i1 [[TMP9]], label %[[TMP6]], label %[[TMP10:.*]]
+
+; CHECK: [[TMP10]]:
+; CHECK-NEXT: br label %for.body
+
+; CHECK: for.body:
+; CHECK-NEXT: %sum.07 = phi <vscale x 4 x float> [ [[TMP13]], %for.body ], [ zeroinitializer, %[[TMP10]] ]
+; CHECK-NEXT: %i.06 = phi i64 [ %add, %for.body ], [ 0, %[[TMP10]] ]
+; CHECK-NEXT: [[TMP11:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.whilelo.nxv4i1.i64(i64 %i.06, i64 %d)
+; CHECK-NEXT: %arrayidx = getelementptr inbounds float, ptr %x, i64 %i.06
+; CHECK-NEXT: [[TMP12:%.*]] = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %arrayidx, i32 1, <vscale x 4 x i1> [[TMP11]], <vscale x 4 x float> zeroinitializer)
+; CHECK-NEXT: [[TMP13]] = tail call contract <vscale x 4 x float> @llvm.aarch64.sve.fsub.nxv4f32(<vscale x 4 x i1> [[TMP11]], <vscale x 4 x float> %sum.07, <vscale x 4 x float> [[TMP12]])
+; CHECK-NEXT: %add = add nuw i64 %i.06, 4
+; CHECK-NEXT: %cmp = icmp ult i64 %add, %d
+; CHECK-NEXT: br i1 %cmp, label %for.body, label %[[LOOPEXIT]]
+
+entry:
+  %cmp5.not = icmp eq i64 %d, 0
+  br i1 %cmp5.not, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  %sum.0.lcssa = phi <vscale x 4 x float> [ zeroinitializer, %entry ], [ %4, %for.body ]
+  %0 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %1 = tail call contract float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1> %0, <vscale x 4 x float> %sum.0.lcssa)
+  ret float %1
+
+for.body:                                         ; preds = %entry, %for.body
+  %sum.07 = phi <vscale x 4 x float> [ %4, %for.body ], [ zeroinitializer, %entry ]
+  %i.06 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %2 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.whilelo.nxv4i1.i64(i64 %i.06, i64 %d)
+  %arrayidx = getelementptr inbounds float, ptr %x, i64 %i.06
+  %3 = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %arrayidx, i32 1, <vscale x 4 x i1> %2, <vscale x 4 x float> zeroinitializer)
+  %4 = tail call contract <vscale x 4 x float> @llvm.aarch64.sve.fsub.nxv4f32(<vscale x 4 x i1> %2, <vscale x 4 x float> %sum.07, <vscale x 4 x float> %3)
+  %add = add nuw i64 %i.06, 4
+  %cmp = icmp ult i64 %add, %d
+  br i1 %cmp, label %for.body, label %for.cond.cleanup
+}
+
+declare <vscale x 4 x i1> @llvm.aarch64.sve.whilelo.nxv4i1.i64(i64, i64)
+
+declare <vscale x 4 x float> @llvm.aarch64.sve.fsub.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>, <vscale x 4 x float>)
+
+declare <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 immarg)
+
+declare float @llvm.aarch64.sve.faddv.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>)
+
+declare <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr nocapture, i32 immarg, <vscale x 4 x i1>, <vscale x 4 x float>)
diff --git a/llvm/test/Transforms/LoopIterationPrefetchBefore/basic.ll b/llvm/test/Transforms/LoopIterationPrefetchBefore/basic.ll
new file mode 100644
index 0000000000000000000000000000000000000000..6db1979744a3303463a10369078b147331194b7c
--- /dev/null
+++ b/llvm/test/Transforms/LoopIterationPrefetchBefore/basic.ll
@@ -0,0 +1,52 @@
+; REQUIRES: aarch64-registered-target
+; RUN: opt -mtriple aarch64-linux-gnu -passes=loop-iteration-prefetch-before \
+; RUN:   -force-enable-experimental-optimization=true \
+; RUN:   -loop-iteration-prefetch-before-funcs=foo -S < %s | FileCheck %s
+
+define void @foo(ptr nocapture %a, ptr nocapture readonly %b) {
+; CHECK-LABEL: define void @foo
+; CHECK: entry:
+; CHECK-NEXT: br label %[[TMP0:.*]]
+
+; CHECK: [[TMP0]]:
+; CHECK-NEXT: [[TMP1:%.*]] = phi i64 [ 0, %entry ], [ [[NEXT:%.*]], %[[TMP0]] ]
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr float, ptr %b, i64 [[TMP1]]
+; CHECK-NEXT: call void @llvm.prefetch.p0(ptr [[TMP2]], i32 0, i32 3, i32 1)
+; CHECK-NEXT: [[NEXT]] = add nuw nsw i64 [[TMP1]], 16
+; CHECK-NEXT: [[TMP3:%.*]] = icmp ult i64 [[NEXT]], 3198
+; CHECK-NEXT: br i1 [[TMP3]], label %[[TMP0]], label %[[TMP4:.*]]
+
+; CHECK: [[TMP4]]:
+; CHECK-NEXT: br label %for.body
+
+; CHECK: for.body:
+; CHECK-NEXT: %indvars.iv = phi i64 [ 0, %[[TMP4]] ], [ %indvars.iv.next, %for.body ]
+; CHECK-NEXT: %arrayidx = getelementptr inbounds double, ptr %b, i64 %indvars.iv
+; CHECK-NEXT: [[TMP5:%.*]] = load double, ptr %arrayidx, align 8
+; CHECK-NEXT: %add = fadd double [[TMP5]], 1.000000e+00
+; CHECK-NEXT: %arrayidx2 = getelementptr inbounds double, ptr %a, i64 %indvars.iv
+; CHECK-NEXT: store double %add, ptr %arrayidx2, align 8
+; CHECK-NEXT: %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+; CHECK-NEXT: %exitcond = icmp eq i64 %indvars.iv.next, 1600
+; CHECK-NEXT: br i1 %exitcond, label %for.end, label %for.body
+
+; CHECK: for.end:
+; CHECK-NEXT: ret void
+
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %arrayidx = getelementptr inbounds double, ptr %b, i64 %indvars.iv
+  %0 = load double, ptr %arrayidx, align 8
+  %add = fadd double %0, 1.000000e+00
+  %arrayidx2 = getelementptr inbounds double, ptr %a, i64 %indvars.iv
+  store double %add, ptr %arrayidx2, align 8
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp eq i64 %indvars.iv.next, 1600
+  br i1 %exitcond, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.body
+  ret void
+}