blob: 3cfb9d13df94d094760ee816d81a5229057ced09 [file] [log] [blame]
//===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// The divergence analysis determines which instructions and branches are
/// divergent given a set of divergent source instructions.
///
//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
16#define LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
17
18#include "llvm/ADT/DenseSet.h"
19#include "llvm/Analysis/SyncDependenceAnalysis.h"
20#include "llvm/IR/Function.h"
21#include "llvm/Pass.h"
22#include <vector>
23
24namespace llvm {
25class Module;
26class Value;
27class Instruction;
28class Loop;
29class raw_ostream;
30class TargetTransformInfo;
31
32/// \brief Generic divergence analysis for reducible CFGs.
33///
34/// This analysis propagates divergence in a data-parallel context from sources
35/// of divergence to all users. It requires reducible CFGs. All assignments
36/// should be in SSA form.
37class DivergenceAnalysis {
38public:
39 /// \brief This instance will analyze the whole function \p F or the loop \p
40 /// RegionLoop.
41 ///
42 /// \param RegionLoop if non-null the analysis is restricted to \p RegionLoop.
43 /// Otherwise the whole function is analyzed.
44 /// \param IsLCSSAForm whether the analysis may assume that the IR in the
45 /// region in in LCSSA form.
46 DivergenceAnalysis(const Function &F, const Loop *RegionLoop,
47 const DominatorTree &DT, const LoopInfo &LI,
48 SyncDependenceAnalysis &SDA, bool IsLCSSAForm);
49
50 /// \brief The loop that defines the analyzed region (if any).
51 const Loop *getRegionLoop() const { return RegionLoop; }
52 const Function &getFunction() const { return F; }
53
54 /// \brief Whether \p BB is part of the region.
55 bool inRegion(const BasicBlock &BB) const;
56 /// \brief Whether \p I is part of the region.
57 bool inRegion(const Instruction &I) const;
58
59 /// \brief Mark \p UniVal as a value that is always uniform.
60 void addUniformOverride(const Value &UniVal);
61
62 /// \brief Mark \p DivVal as a value that is always divergent.
63 void markDivergent(const Value &DivVal);
64
65 /// \brief Propagate divergence to all instructions in the region.
66 /// Divergence is seeded by calls to \p markDivergent.
67 void compute();
68
69 /// \brief Whether any value was marked or analyzed to be divergent.
70 bool hasDetectedDivergence() const { return !DivergentValues.empty(); }
71
72 /// \brief Whether \p Val will always return a uniform value regardless of its
73 /// operands
74 bool isAlwaysUniform(const Value &Val) const;
75
76 /// \brief Whether \p Val is a divergent value
77 bool isDivergent(const Value &Val) const;
78
79 void print(raw_ostream &OS, const Module *) const;
80
81private:
82 bool updateTerminator(const Instruction &Term) const;
83 bool updatePHINode(const PHINode &Phi) const;
84
85 /// \brief Computes whether \p Inst is divergent based on the
86 /// divergence of its operands.
87 ///
88 /// \returns Whether \p Inst is divergent.
89 ///
90 /// This should only be called for non-phi, non-terminator instructions.
91 bool updateNormalInstruction(const Instruction &Inst) const;
92
93 /// \brief Mark users of live-out users as divergent.
94 ///
95 /// \param LoopHeader the header of the divergent loop.
96 ///
97 /// Marks all users of live-out values of the loop headed by \p LoopHeader
98 /// as divergent and puts them on the worklist.
99 void taintLoopLiveOuts(const BasicBlock &LoopHeader);
100
101 /// \brief Push all users of \p Val (in the region) to the worklist
102 void pushUsers(const Value &I);
103
104 /// \brief Push all phi nodes in @block to the worklist
105 void pushPHINodes(const BasicBlock &Block);
106
107 /// \brief Mark \p Block as join divergent
108 ///
109 /// A block is join divergent if two threads may reach it from different
110 /// incoming blocks at the same time.
111 void markBlockJoinDivergent(const BasicBlock &Block) {
112 DivergentJoinBlocks.insert(&Block);
113 }
114
115 /// \brief Whether \p Val is divergent when read in \p ObservingBlock.
116 bool isTemporalDivergent(const BasicBlock &ObservingBlock,
117 const Value &Val) const;
118
119 /// \brief Whether \p Block is join divergent
120 ///
121 /// (see markBlockJoinDivergent).
122 bool isJoinDivergent(const BasicBlock &Block) const {
123 return DivergentJoinBlocks.find(&Block) != DivergentJoinBlocks.end();
124 }
125
126 /// \brief Propagate control-induced divergence to users (phi nodes and
127 /// instructions).
128 //
129 // \param JoinBlock is a divergent loop exit or join point of two disjoint
130 // paths.
131 // \returns Whether \p JoinBlock is a divergent loop exit of \p TermLoop.
132 bool propagateJoinDivergence(const BasicBlock &JoinBlock,
133 const Loop *TermLoop);
134
135 /// \brief Propagate induced value divergence due to control divergence in \p
136 /// Term.
137 void propagateBranchDivergence(const Instruction &Term);
138
139 /// \brief Propagate divergent caused by a divergent loop exit.
140 ///
141 /// \param ExitingLoop is a divergent loop.
142 void propagateLoopDivergence(const Loop &ExitingLoop);
143
144private:
145 const Function &F;
146 // If regionLoop != nullptr, analysis is only performed within \p RegionLoop.
147 // Otw, analyze the whole function
148 const Loop *RegionLoop;
149
150 const DominatorTree &DT;
151 const LoopInfo &LI;
152
153 // Recognized divergent loops
154 DenseSet<const Loop *> DivergentLoops;
155
156 // The SDA links divergent branches to divergent control-flow joins.
157 SyncDependenceAnalysis &SDA;
158
159 // Use simplified code path for LCSSA form.
160 bool IsLCSSAForm;
161
162 // Set of known-uniform values.
163 DenseSet<const Value *> UniformOverrides;
164
165 // Blocks with joining divergent control from different predecessors.
166 DenseSet<const BasicBlock *> DivergentJoinBlocks;
167
168 // Detected/marked divergent values.
169 DenseSet<const Value *> DivergentValues;
170
171 // Internal worklist for divergence propagation.
172 std::vector<const Instruction *> Worklist;
173};
174
175/// \brief Divergence analysis frontend for GPU kernels.
176class GPUDivergenceAnalysis {
177 SyncDependenceAnalysis SDA;
178 DivergenceAnalysis DA;
179
180public:
181 /// Runs the divergence analysis on @F, a GPU kernel
182 GPUDivergenceAnalysis(Function &F, const DominatorTree &DT,
183 const PostDominatorTree &PDT, const LoopInfo &LI,
184 const TargetTransformInfo &TTI);
185
186 /// Whether any divergence was detected.
187 bool hasDivergence() const { return DA.hasDetectedDivergence(); }
188
189 /// The GPU kernel this analysis result is for
190 const Function &getFunction() const { return DA.getFunction(); }
191
192 /// Whether \p V is divergent.
193 bool isDivergent(const Value &V) const;
194
195 /// Whether \p V is uniform/non-divergent
196 bool isUniform(const Value &V) const { return !isDivergent(V); }
197
198 /// Print all divergent values in the kernel.
199 void print(raw_ostream &OS, const Module *) const;
200};
201
202} // namespace llvm
203
204#endif // LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H