Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-10 08:43:38

0001 //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===//
0002 //
0003 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
0004 // See https://llvm.org/LICENSE.txt for license information.
0005 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
0006 //
0007 //===----------------------------------------------------------------------===//
0008 //
0009 /// \file Shape utility for AMX.
/// AMX hardware requires the shape of a tile data register to be configured
/// before the register is used. The 2D shape consists of a row and a column.
/// In the AMX intrinsics interface the shape is passed as the 1st and 2nd
/// parameters, which are lowered to the 1st and 2nd machine operands of the
/// AMX pseudo instructions. The ShapeT class supports tile configuration and
/// register allocation; its row and column are those machine operands.
0016 //
0017 //===----------------------------------------------------------------------===//
0018 
0019 #ifndef LLVM_CODEGEN_TILESHAPEINFO_H
0020 #define LLVM_CODEGEN_TILESHAPEINFO_H
0021 
0022 #include "llvm/CodeGen/MachineInstr.h"
0023 #include "llvm/CodeGen/MachineOperand.h"
0024 #include "llvm/CodeGen/MachineRegisterInfo.h"
0025 #include "llvm/CodeGen/Register.h"
0026 
0027 namespace llvm {
0028 
0029 class ShapeT {
0030 public:
0031   ShapeT(MachineOperand *Row, MachineOperand *Col,
0032          const MachineRegisterInfo *MRI = nullptr)
0033       : Row(Row), Col(Col) {
0034     if (MRI)
0035       deduceImm(MRI);
0036   }
0037   // When ShapeT has multiple shapes, we only use Shapes (never use Row and Col)
0038   // and ImmShapes. Due to the most case is only one shape (just simply use
0039   // Shape.Row or Shape.Col), so here we don't merge Row and Col into vector
0040   // Shapes to keep the speed and code simplicity.
0041   // TODO: The upper solution is a temporary way to minimize current tile
0042   // register allocation code changes. It can not handle both Reg shape and
0043   // Imm shape for different shapes (e.g. shape 1 is reg shape while shape 2
0044   // is imm shape). Refine me when we have more multi-tile shape instructions!
0045   ShapeT(ArrayRef<MachineOperand *> ShapesOperands,
0046          const MachineRegisterInfo *MRI = nullptr)
0047       : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
0048         ColImm(InvalidImmShape) {
0049     assert(ShapesOperands.size() % 2 == 0 && "Miss row or col!");
0050 
0051     for (auto *Shape : ShapesOperands)
0052       Shapes.push_back(Shape);
0053 
0054     if (MRI)
0055       deduceImm(MRI);
0056   }
0057   ShapeT()
0058       : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
0059         ColImm(InvalidImmShape) {}
0060   // TODO: We need to extern cmp operator for multi-shapes if
0061   // we have requirement in the future.
0062   bool operator==(const ShapeT &Shape) const {
0063     MachineOperand *R = Shape.Row;
0064     MachineOperand *C = Shape.Col;
0065     if (!R || !C)
0066       return false;
0067     if (!Row || !Col)
0068       return false;
0069     if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
0070       return true;
0071     if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape))
0072       return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
0073     return false;
0074   }
0075 
0076   bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); }
0077 
0078   MachineOperand *getRow(unsigned I = 0) const {
0079     if (Shapes.empty())
0080       return Row;
0081     assert(Shapes.size() / 2 >= I && "Get invalid row from id!");
0082     return Shapes[I * 2];
0083   }
0084 
0085   MachineOperand *getCol(unsigned I = 0) const {
0086     if (Shapes.empty())
0087       return Col;
0088     assert(Shapes.size() / 2 >= I && "Get invalid col from id!");
0089     return Shapes[I * 2 + 1];
0090   }
0091 
0092   int64_t getRowImm(unsigned I = 0) const {
0093     if (ImmShapes.empty())
0094       return RowImm;
0095     assert(ImmShapes.size() / 2 >= I && "Get invalid imm row from id!");
0096     return ImmShapes[I * 2];
0097   }
0098 
0099   int64_t getColImm(unsigned I = 0) const {
0100     if (ImmShapes.empty())
0101       return ColImm;
0102     assert(ImmShapes.size() / 2 >= I && "Get invalid imm col from id!");
0103     return ImmShapes[I * 2 + 1];
0104   }
0105 
0106   unsigned getShapeNum() {
0107     if (Shapes.empty())
0108       return isValid() ? 1 : 0;
0109     else
0110       return Shapes.size() / 2;
0111   }
0112 
0113   bool isValid() { return (Row != nullptr) && (Col != nullptr); }
0114 
0115   void deduceImm(const MachineRegisterInfo *MRI) {
0116     // All def must be the same value, otherwise it is invalid MIs.
0117     // Find the immediate.
0118     // TODO copy propagation.
0119     auto GetImm = [&](Register Reg) {
0120       int64_t Imm = InvalidImmShape;
0121       for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
0122         const auto *MI = DefMO.getParent();
0123         if (MI->isMoveImmediate()) {
0124           assert(MI->getNumOperands() == 2 &&
0125                  "Unsupported number of operands in instruction for setting "
0126                  "row/column.");
0127           if (MI->getOperand(1).isImm()) {
0128             Imm = MI->getOperand(1).getImm();
0129           } else {
0130             assert(MI->getOperand(1).isImplicit() &&
0131                    "Operand 1 is assumed to be implicit.");
0132             Imm = 0;
0133           }
0134           break;
0135         }
0136       }
0137       return Imm;
0138     };
0139     if (Shapes.empty()) { // Single Shape
0140       RowImm = GetImm(Row->getReg());
0141       ColImm = GetImm(Col->getReg());
0142       // The number of rows of 2nd destination buffer is assigned by the one of
0143       // 1st destination buffer. If the column size is equal to zero, the row
0144       // size should be reset to zero too.
0145       if (ColImm == 0)
0146         Row = Col;
0147     } else { // Multiple Shapes
0148       for (auto *Shape : Shapes) {
0149         int64_t ImmShape = GetImm(Shape->getReg());
0150         ImmShapes.push_back(ImmShape);
0151       }
0152     }
0153   }
0154 
0155 private:
0156   static constexpr int64_t InvalidImmShape = -1;
0157   MachineOperand *Row;
0158   MachineOperand *Col;
0159   int64_t RowImm = -1;
0160   int64_t ColImm = -1;
0161   // Multiple Shapes
0162   SmallVector<MachineOperand *, 0> Shapes;
0163   SmallVector<int64_t, 0> ImmShapes;
0164 };
0165 
0166 } // namespace llvm
0167 
0168 #endif