blob: 1afccf84ee4826fbbda9b09a0c32049b5dd1042e [file] [log] [blame]
//===- MLInlineAdvisor.h - ML-based InlineAdvisor factories -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H
10#define LLVM_ANALYSIS_MLINLINEADVISOR_H
11
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/InlineAdvisor.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/IR/PassManager.h"

#include <map>
#include <memory>
#include <unordered_map>
19
20namespace llvm {
21class Module;
22class MLInlineAdvice;
23
24class MLInlineAdvisor : public InlineAdvisor {
25public:
26 MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM,
27 std::unique_ptr<MLModelRunner> ModelRunner);
28
29 CallGraph *callGraph() const { return CG.get(); }
30 virtual ~MLInlineAdvisor() = default;
31
32 void onPassEntry() override;
33
34 int64_t getIRSize(const Function &F) const { return F.getInstructionCount(); }
35 void onSuccessfulInlining(const MLInlineAdvice &Advice,
36 bool CalleeWasDeleted);
37
38 bool isForcedToStop() const { return ForceStop; }
39 int64_t getLocalCalls(Function &F);
40 const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); }
41
42protected:
43 std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override;
44
45 std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB,
46 bool Advice) override;
47
48 virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB);
49
50 virtual std::unique_ptr<MLInlineAdvice>
51 getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE);
52
53 Module &M;
54 std::unique_ptr<MLModelRunner> ModelRunner;
55
56private:
57 int64_t getModuleIRSize() const;
58
59 std::unique_ptr<CallGraph> CG;
60
61 int64_t NodeCount = 0;
62 int64_t EdgeCount = 0;
63 std::map<const Function *, unsigned> FunctionLevels;
64 const int32_t InitialIRSize = 0;
65 int32_t CurrentIRSize = 0;
66
67 bool ForceStop = false;
68};
69
70/// InlineAdvice that tracks changes post inlining. For that reason, it only
71/// overrides the "successful inlining" extension points.
72class MLInlineAdvice : public InlineAdvice {
73public:
74 MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
75 OptimizationRemarkEmitter &ORE, bool Recommendation)
76 : InlineAdvice(Advisor, CB, ORE, Recommendation),
77 CallerIRSize(Advisor->isForcedToStop() ? 0
78 : Advisor->getIRSize(*Caller)),
79 CalleeIRSize(Advisor->isForcedToStop() ? 0
80 : Advisor->getIRSize(*Callee)),
81 CallerAndCalleeEdges(Advisor->isForcedToStop()
82 ? 0
83 : (Advisor->getLocalCalls(*Caller) +
84 Advisor->getLocalCalls(*Callee))) {}
85 virtual ~MLInlineAdvice() = default;
86
87 void recordInliningImpl() override;
88 void recordInliningWithCalleeDeletedImpl() override;
89 void recordUnsuccessfulInliningImpl(const InlineResult &Result) override;
90 void recordUnattemptedInliningImpl() override;
91
92 Function *getCaller() const { return Caller; }
93 Function *getCallee() const { return Callee; }
94
95 const int64_t CallerIRSize;
96 const int64_t CalleeIRSize;
97 const int64_t CallerAndCalleeEdges;
98
99private:
100 void reportContextForRemark(DiagnosticInfoOptimizationBase &OR);
101
102 MLInlineAdvisor *getAdvisor() const {
103 return static_cast<MLInlineAdvisor *>(Advisor);
104 };
105};
106
107} // namespace llvm
108
109#endif // LLVM_ANALYSIS_MLINLINEADVISOR_H