blob: 031d23555b7e56723bd0130331eda511fbeb1189 [file] [log] [blame]
Olivier Deprezf4ef2d02021-04-20 13:36:24 +02001//===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
/// \file Shape utility for AMX.
/// AMX hardware requires the shape of each tile data register to be configured
/// before the register is used. The 2D shape consists of a row and a column.
/// In the AMX intrinsics interface the shape is passed as the 1st and 2nd
/// parameters, which are lowered to the 1st and 2nd machine operands of the
/// AMX pseudo instructions. The ShapeT class facilitates tile configuration
/// and register allocation; its row and column are those machine operands.
16//
17//===----------------------------------------------------------------------===//
18
19#ifndef LLVM_CODEGEN_TILESHAPEINFO_H
20#define LLVM_CODEGEN_TILESHAPEINFO_H
21
22#include "llvm/ADT/DenseMapInfo.h"
23#include "llvm/CodeGen/MachineInstr.h"
24#include "llvm/CodeGen/MachineOperand.h"
25#include "llvm/CodeGen/MachineRegisterInfo.h"
26#include "llvm/CodeGen/Register.h"
27#include <utility>
28
29namespace llvm {
30
31class ShapeT {
32public:
33 ShapeT(MachineOperand *Row, MachineOperand *Col,
34 const MachineRegisterInfo *MRI = nullptr)
35 : Row(Row), Col(Col) {
36 if (MRI)
37 deduceImm(MRI);
38 }
39 ShapeT()
40 : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
41 ColImm(InvalidImmShape) {}
42 bool operator==(const ShapeT &Shape) {
43 MachineOperand *R = Shape.Row;
44 MachineOperand *C = Shape.Col;
45 if (!R || !C)
46 return false;
47 if (!Row || !Col)
48 return false;
49 if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
50 return true;
51 if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape))
52 return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
53 return false;
54 }
55
56 bool operator!=(const ShapeT &Shape) { return !(*this == Shape); }
57
58 MachineOperand *getRow() const { return Row; }
59
60 MachineOperand *getCol() const { return Col; }
61
62 int64_t getRowImm() const { return RowImm; }
63
64 int64_t getColImm() const { return ColImm; }
65
66 bool isValid() { return (Row != nullptr) && (Col != nullptr); }
67
68 void deduceImm(const MachineRegisterInfo *MRI) {
69 // All def must be the same value, otherwise it is invalid MIs.
70 // Find the immediate.
71 // TODO copy propagation.
72 auto GetImm = [&](Register Reg) {
73 int64_t Imm = InvalidImmShape;
74 for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
75 const auto *MI = DefMO.getParent();
76 if (MI->isMoveImmediate()) {
77 Imm = MI->getOperand(1).getImm();
78 break;
79 }
80 }
81 return Imm;
82 };
83 RowImm = GetImm(Row->getReg());
84 ColImm = GetImm(Col->getReg());
85 }
86
87private:
88 static constexpr int64_t InvalidImmShape = -1;
89 MachineOperand *Row;
90 MachineOperand *Col;
91 int64_t RowImm;
92 int64_t ColImm;
93};
94
95} // namespace llvm
96
97#endif