Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-10 08:44:07

0001 //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
0002 //
0003 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
0004 // See https://llvm.org/LICENSE.txt for license information.
0005 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
0006 //
0007 //===----------------------------------------------------------------------===//
0008 //
0009 // This file defines the MatrixBuilder class, which is used as a convenient way
0010 // to lower matrix operations to LLVM IR.
0011 //
0012 //===----------------------------------------------------------------------===//
0013 
0014 #ifndef LLVM_IR_MATRIXBUILDER_H
0015 #define LLVM_IR_MATRIXBUILDER_H
0016 
0017 #include "llvm/IR/Constant.h"
0018 #include "llvm/IR/Constants.h"
0019 #include "llvm/IR/IRBuilder.h"
0020 #include "llvm/IR/InstrTypes.h"
0021 #include "llvm/IR/Instruction.h"
0022 #include "llvm/IR/IntrinsicInst.h"
0023 #include "llvm/IR/Type.h"
0024 #include "llvm/IR/Value.h"
0025 #include "llvm/Support/Alignment.h"
0026 
0027 namespace llvm {
0028 
0029 class Function;
0030 class Twine;
0031 class Module;
0032 
0033 class MatrixBuilder {
0034   IRBuilderBase &B;
0035   Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
0036 
0037   std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
0038                                                          Value *RHS) {
0039     assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
0040            "One of the operands must be a matrix (embedded in a vector)");
0041     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
0042       assert(!isa<ScalableVectorType>(LHS->getType()) &&
0043              "LHS Assumed to be fixed width");
0044       RHS = B.CreateVectorSplat(
0045           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
0046           "scalar.splat");
0047     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
0048       assert(!isa<ScalableVectorType>(RHS->getType()) &&
0049              "RHS Assumed to be fixed width");
0050       LHS = B.CreateVectorSplat(
0051           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
0052           "scalar.splat");
0053     }
0054     return {LHS, RHS};
0055   }
0056 
0057 public:
0058   MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {}
0059 
0060   /// Create a column major, strided matrix load.
0061   /// \p EltTy   - Matrix element type
0062   /// \p DataPtr - Start address of the matrix read
0063   /// \p Rows    - Number of rows in matrix (must be a constant)
0064   /// \p Columns - Number of columns in matrix (must be a constant)
0065   /// \p Stride  - Space between columns
0066   CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment,
0067                                   Value *Stride, bool IsVolatile, unsigned Rows,
0068                                   unsigned Columns, const Twine &Name = "") {
0069     auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
0070 
0071     Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
0072                     B.getInt32(Columns)};
0073     Type *OverloadedTypes[] = {RetType, Stride->getType()};
0074 
0075     Function *TheFn = Intrinsic::getOrInsertDeclaration(
0076         getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
0077 
0078     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
0079     Attribute AlignAttr =
0080         Attribute::getWithAlignment(Call->getContext(), Alignment);
0081     Call->addParamAttr(0, AlignAttr);
0082     return Call;
0083   }
0084 
0085   /// Create a column major, strided matrix store.
0086   /// \p Matrix  - Matrix to store
0087   /// \p Ptr     - Pointer to write back to
0088   /// \p Stride  - Space between columns
0089   CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
0090                                    Value *Stride, bool IsVolatile,
0091                                    unsigned Rows, unsigned Columns,
0092                                    const Twine &Name = "") {
0093     Value *Ops[] = {Matrix,           Ptr,
0094                     Stride,           B.getInt1(IsVolatile),
0095                     B.getInt32(Rows), B.getInt32(Columns)};
0096     Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
0097 
0098     Function *TheFn = Intrinsic::getOrInsertDeclaration(
0099         getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
0100 
0101     CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
0102     Attribute AlignAttr =
0103         Attribute::getWithAlignment(Call->getContext(), Alignment);
0104     Call->addParamAttr(1, AlignAttr);
0105     return Call;
0106   }
0107 
0108   /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
0109   /// rows and \p Columns columns.
0110   CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
0111                                   unsigned Columns, const Twine &Name = "") {
0112     auto *OpType = cast<VectorType>(Matrix->getType());
0113     auto *ReturnType =
0114         FixedVectorType::get(OpType->getElementType(), Rows * Columns);
0115 
0116     Type *OverloadedTypes[] = {ReturnType};
0117     Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
0118     Function *TheFn = Intrinsic::getOrInsertDeclaration(
0119         getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
0120 
0121     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
0122   }
0123 
0124   /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
0125   /// RHS.
0126   CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
0127                                  unsigned LHSColumns, unsigned RHSColumns,
0128                                  const Twine &Name = "") {
0129     auto *LHSType = cast<VectorType>(LHS->getType());
0130     auto *RHSType = cast<VectorType>(RHS->getType());
0131 
0132     auto *ReturnType =
0133         FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
0134 
0135     Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
0136                     B.getInt32(RHSColumns)};
0137     Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
0138 
0139     Function *TheFn = Intrinsic::getOrInsertDeclaration(
0140         getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
0141     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
0142   }
0143 
0144   /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
0145   /// ColumnIdx).
0146   Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
0147                             Value *ColumnIdx, unsigned NumRows) {
0148     return B.CreateInsertElement(
0149         Matrix, NewVal,
0150         B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
0151                                                ColumnIdx->getType(), NumRows)),
0152                     RowIdx));
0153   }
0154 
0155   /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
0156   /// matrixes.
0157   Value *CreateAdd(Value *LHS, Value *RHS) {
0158     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
0159     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
0160       assert(!isa<ScalableVectorType>(LHS->getType()) &&
0161              "LHS Assumed to be fixed width");
0162       RHS = B.CreateVectorSplat(
0163           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
0164           "scalar.splat");
0165     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
0166       assert(!isa<ScalableVectorType>(RHS->getType()) &&
0167              "RHS Assumed to be fixed width");
0168       LHS = B.CreateVectorSplat(
0169           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
0170           "scalar.splat");
0171     }
0172 
0173     return cast<VectorType>(LHS->getType())
0174                    ->getElementType()
0175                    ->isFloatingPointTy()
0176                ? B.CreateFAdd(LHS, RHS)
0177                : B.CreateAdd(LHS, RHS);
0178   }
0179 
0180   /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
0181   /// point matrixes.
0182   Value *CreateSub(Value *LHS, Value *RHS) {
0183     assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
0184     if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
0185       assert(!isa<ScalableVectorType>(LHS->getType()) &&
0186              "LHS Assumed to be fixed width");
0187       RHS = B.CreateVectorSplat(
0188           cast<VectorType>(LHS->getType())->getElementCount(), RHS,
0189           "scalar.splat");
0190     } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
0191       assert(!isa<ScalableVectorType>(RHS->getType()) &&
0192              "RHS Assumed to be fixed width");
0193       LHS = B.CreateVectorSplat(
0194           cast<VectorType>(RHS->getType())->getElementCount(), LHS,
0195           "scalar.splat");
0196     }
0197 
0198     return cast<VectorType>(LHS->getType())
0199                    ->getElementType()
0200                    ->isFloatingPointTy()
0201                ? B.CreateFSub(LHS, RHS)
0202                : B.CreateSub(LHS, RHS);
0203   }
0204 
0205   /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
0206   /// RHS.
0207   Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
0208     std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
0209     if (LHS->getType()->getScalarType()->isFloatingPointTy())
0210       return B.CreateFMul(LHS, RHS);
0211     return B.CreateMul(LHS, RHS);
0212   }
0213 
0214   /// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
0215   /// IsUnsigned indicates whether UDiv or SDiv should be used.
0216   Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
0217     assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
0218     assert(!isa<ScalableVectorType>(LHS->getType()) &&
0219            "LHS Assumed to be fixed width");
0220     RHS =
0221         B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
0222                             RHS, "scalar.splat");
0223     return cast<VectorType>(LHS->getType())
0224                    ->getElementType()
0225                    ->isFloatingPointTy()
0226                ? B.CreateFDiv(LHS, RHS)
0227                : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
0228   }
0229 
0230   /// Create an assumption that \p Idx is less than \p NumElements.
0231   void CreateIndexAssumption(Value *Idx, unsigned NumElements,
0232                              Twine const &Name = "") {
0233     Value *NumElts =
0234         B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
0235     auto *Cmp = B.CreateICmpULT(Idx, NumElts);
0236     if (isa<ConstantInt>(Cmp))
0237       assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!");
0238     else
0239       B.CreateAssumption(Cmp);
0240   }
0241 
0242   /// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
0243   /// a matrix with \p NumRows embedded in a vector.
0244   Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
0245                      Twine const &Name = "") {
0246     unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
0247                                  ColumnIdx->getType()->getScalarSizeInBits());
0248     Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
0249     RowIdx = B.CreateZExt(RowIdx, IntTy);
0250     ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
0251     Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
0252     return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
0253   }
0254 };
0255 
0256 } // end namespace llvm
0257 
0258 #endif // LLVM_IR_MATRIXBUILDER_H