Olivier Deprez | f4ef2d0 | 2021-04-20 13:36:24 +0200 | [diff] [blame] | 1 | //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file defines the MatrixBuilder class, which is used as a convenient way |
| 10 | // to lower matrix operations to LLVM IR. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #ifndef LLVM_IR_MATRIXBUILDER_H |
| 15 | #define LLVM_IR_MATRIXBUILDER_H |
| 16 | |
| 17 | #include "llvm/IR/Constant.h" |
| 18 | #include "llvm/IR/Constants.h" |
| 19 | #include "llvm/IR/IRBuilder.h" |
| 20 | #include "llvm/IR/InstrTypes.h" |
| 21 | #include "llvm/IR/Instruction.h" |
| 22 | #include "llvm/IR/IntrinsicInst.h" |
| 23 | #include "llvm/IR/Type.h" |
| 24 | #include "llvm/IR/Value.h" |
| 25 | #include "llvm/Support/Alignment.h" |
| 26 | |
| 27 | namespace llvm { |
| 28 | |
// Forward declarations of the LLVM IR classes named in MatrixBuilder's
// interface and helpers.
class Function;
class Module;
class Twine;
| 32 | |
| 33 | template <class IRBuilderTy> class MatrixBuilder { |
| 34 | IRBuilderTy &B; |
| 35 | Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } |
| 36 | |
| 37 | std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, |
| 38 | Value *RHS) { |
| 39 | assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && |
| 40 | "One of the operands must be a matrix (embedded in a vector)"); |
| 41 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
| 42 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
| 43 | "LHS Assumed to be fixed width"); |
| 44 | RHS = B.CreateVectorSplat( |
| 45 | cast<VectorType>(LHS->getType())->getElementCount(), RHS, |
| 46 | "scalar.splat"); |
| 47 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
| 48 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
| 49 | "RHS Assumed to be fixed width"); |
| 50 | LHS = B.CreateVectorSplat( |
| 51 | cast<VectorType>(RHS->getType())->getElementCount(), LHS, |
| 52 | "scalar.splat"); |
| 53 | } |
| 54 | return {LHS, RHS}; |
| 55 | } |
| 56 | |
| 57 | public: |
| 58 | MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {} |
| 59 | |
| 60 | /// Create a column major, strided matrix load. |
| 61 | /// \p DataPtr - Start address of the matrix read |
| 62 | /// \p Rows - Number of rows in matrix (must be a constant) |
| 63 | /// \p Columns - Number of columns in matrix (must be a constant) |
| 64 | /// \p Stride - Space between columns |
| 65 | CallInst *CreateColumnMajorLoad(Value *DataPtr, Align Alignment, |
| 66 | Value *Stride, bool IsVolatile, unsigned Rows, |
| 67 | unsigned Columns, const Twine &Name = "") { |
| 68 | |
| 69 | // Deal with the pointer |
| 70 | PointerType *PtrTy = cast<PointerType>(DataPtr->getType()); |
| 71 | Type *EltTy = PtrTy->getElementType(); |
| 72 | |
| 73 | auto *RetType = FixedVectorType::get(EltTy, Rows * Columns); |
| 74 | |
| 75 | Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows), |
| 76 | B.getInt32(Columns)}; |
| 77 | Type *OverloadedTypes[] = {RetType}; |
| 78 | |
| 79 | Function *TheFn = Intrinsic::getDeclaration( |
| 80 | getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes); |
| 81 | |
| 82 | CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
| 83 | Attribute AlignAttr = |
| 84 | Attribute::getWithAlignment(Call->getContext(), Alignment); |
| 85 | Call->addAttribute(1, AlignAttr); |
| 86 | return Call; |
| 87 | } |
| 88 | |
| 89 | /// Create a column major, strided matrix store. |
| 90 | /// \p Matrix - Matrix to store |
| 91 | /// \p Ptr - Pointer to write back to |
| 92 | /// \p Stride - Space between columns |
| 93 | CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, |
| 94 | Value *Stride, bool IsVolatile, |
| 95 | unsigned Rows, unsigned Columns, |
| 96 | const Twine &Name = "") { |
| 97 | Value *Ops[] = {Matrix, Ptr, |
| 98 | Stride, B.getInt1(IsVolatile), |
| 99 | B.getInt32(Rows), B.getInt32(Columns)}; |
| 100 | Type *OverloadedTypes[] = {Matrix->getType()}; |
| 101 | |
| 102 | Function *TheFn = Intrinsic::getDeclaration( |
| 103 | getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes); |
| 104 | |
| 105 | CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
| 106 | Attribute AlignAttr = |
| 107 | Attribute::getWithAlignment(Call->getContext(), Alignment); |
| 108 | Call->addAttribute(2, AlignAttr); |
| 109 | return Call; |
| 110 | } |
| 111 | |
| 112 | /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows |
| 113 | /// rows and \p Columns columns. |
| 114 | CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows, |
| 115 | unsigned Columns, const Twine &Name = "") { |
| 116 | auto *OpType = cast<VectorType>(Matrix->getType()); |
| 117 | auto *ReturnType = |
| 118 | FixedVectorType::get(OpType->getElementType(), Rows * Columns); |
| 119 | |
| 120 | Type *OverloadedTypes[] = {ReturnType}; |
| 121 | Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)}; |
| 122 | Function *TheFn = Intrinsic::getDeclaration( |
| 123 | getModule(), Intrinsic::matrix_transpose, OverloadedTypes); |
| 124 | |
| 125 | return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
| 126 | } |
| 127 | |
| 128 | /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p |
| 129 | /// RHS. |
| 130 | CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, |
| 131 | unsigned LHSColumns, unsigned RHSColumns, |
| 132 | const Twine &Name = "") { |
| 133 | auto *LHSType = cast<VectorType>(LHS->getType()); |
| 134 | auto *RHSType = cast<VectorType>(RHS->getType()); |
| 135 | |
| 136 | auto *ReturnType = |
| 137 | FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns); |
| 138 | |
| 139 | Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns), |
| 140 | B.getInt32(RHSColumns)}; |
| 141 | Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType}; |
| 142 | |
| 143 | Function *TheFn = Intrinsic::getDeclaration( |
| 144 | getModule(), Intrinsic::matrix_multiply, OverloadedTypes); |
| 145 | return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); |
| 146 | } |
| 147 | |
| 148 | /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p |
| 149 | /// ColumnIdx). |
| 150 | Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, |
| 151 | Value *ColumnIdx, unsigned NumRows) { |
| 152 | return B.CreateInsertElement( |
| 153 | Matrix, NewVal, |
| 154 | B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get( |
| 155 | ColumnIdx->getType(), NumRows)), |
| 156 | RowIdx)); |
| 157 | } |
| 158 | |
| 159 | /// Add matrixes \p LHS and \p RHS. Support both integer and floating point |
| 160 | /// matrixes. |
| 161 | Value *CreateAdd(Value *LHS, Value *RHS) { |
| 162 | assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); |
| 163 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
| 164 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
| 165 | "LHS Assumed to be fixed width"); |
| 166 | RHS = B.CreateVectorSplat( |
| 167 | cast<VectorType>(LHS->getType())->getElementCount(), RHS, |
| 168 | "scalar.splat"); |
| 169 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
| 170 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
| 171 | "RHS Assumed to be fixed width"); |
| 172 | LHS = B.CreateVectorSplat( |
| 173 | cast<VectorType>(RHS->getType())->getElementCount(), LHS, |
| 174 | "scalar.splat"); |
| 175 | } |
| 176 | |
| 177 | return cast<VectorType>(LHS->getType()) |
| 178 | ->getElementType() |
| 179 | ->isFloatingPointTy() |
| 180 | ? B.CreateFAdd(LHS, RHS) |
| 181 | : B.CreateAdd(LHS, RHS); |
| 182 | } |
| 183 | |
| 184 | /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating |
| 185 | /// point matrixes. |
| 186 | Value *CreateSub(Value *LHS, Value *RHS) { |
| 187 | assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); |
| 188 | if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { |
| 189 | assert(!isa<ScalableVectorType>(LHS->getType()) && |
| 190 | "LHS Assumed to be fixed width"); |
| 191 | RHS = B.CreateVectorSplat( |
| 192 | cast<VectorType>(LHS->getType())->getElementCount(), RHS, |
| 193 | "scalar.splat"); |
| 194 | } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { |
| 195 | assert(!isa<ScalableVectorType>(RHS->getType()) && |
| 196 | "RHS Assumed to be fixed width"); |
| 197 | LHS = B.CreateVectorSplat( |
| 198 | cast<VectorType>(RHS->getType())->getElementCount(), LHS, |
| 199 | "scalar.splat"); |
| 200 | } |
| 201 | |
| 202 | return cast<VectorType>(LHS->getType()) |
| 203 | ->getElementType() |
| 204 | ->isFloatingPointTy() |
| 205 | ? B.CreateFSub(LHS, RHS) |
| 206 | : B.CreateSub(LHS, RHS); |
| 207 | } |
| 208 | |
| 209 | /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p |
| 210 | /// RHS. |
| 211 | Value *CreateScalarMultiply(Value *LHS, Value *RHS) { |
| 212 | std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); |
| 213 | if (LHS->getType()->getScalarType()->isFloatingPointTy()) |
| 214 | return B.CreateFMul(LHS, RHS); |
| 215 | return B.CreateMul(LHS, RHS); |
| 216 | } |
| 217 | |
| 218 | /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix. |
| 219 | Value *CreateExtractElement(Value *Matrix, Value *RowIdx, Value *ColumnIdx, |
| 220 | unsigned NumRows, Twine const &Name = "") { |
| 221 | |
| 222 | unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(), |
| 223 | ColumnIdx->getType()->getScalarSizeInBits()); |
| 224 | Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth); |
| 225 | RowIdx = B.CreateZExt(RowIdx, IntTy); |
| 226 | ColumnIdx = B.CreateZExt(ColumnIdx, IntTy); |
| 227 | Value *NumRowsV = B.getIntN(MaxWidth, NumRows); |
| 228 | return B.CreateExtractElement( |
| 229 | Matrix, B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx), |
| 230 | "matext"); |
| 231 | } |
| 232 | }; |
| 233 | |
| 234 | } // end namespace llvm |
| 235 | |
| 236 | #endif // LLVM_IR_MATRIXBUILDER_H |