File indexing completed on 2026-05-10 08:44:07
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014 #ifndef LLVM_IR_MATRIXBUILDER_H
0015 #define LLVM_IR_MATRIXBUILDER_H
0016
0017 #include "llvm/IR/Constant.h"
0018 #include "llvm/IR/Constants.h"
0019 #include "llvm/IR/IRBuilder.h"
0020 #include "llvm/IR/InstrTypes.h"
0021 #include "llvm/IR/Instruction.h"
0022 #include "llvm/IR/IntrinsicInst.h"
0023 #include "llvm/IR/Type.h"
0024 #include "llvm/IR/Value.h"
0025 #include "llvm/Support/Alignment.h"
0026
0027 namespace llvm {
0028
0029 class Function;
0030 class Twine;
0031 class Module;
0032
0033 class MatrixBuilder {
0034 IRBuilderBase &B;
0035 Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
0036
0037 std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
0038 Value *RHS) {
0039 assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
0040 "One of the operands must be a matrix (embedded in a vector)");
0041 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
0042 assert(!isa<ScalableVectorType>(LHS->getType()) &&
0043 "LHS Assumed to be fixed width");
0044 RHS = B.CreateVectorSplat(
0045 cast<VectorType>(LHS->getType())->getElementCount(), RHS,
0046 "scalar.splat");
0047 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
0048 assert(!isa<ScalableVectorType>(RHS->getType()) &&
0049 "RHS Assumed to be fixed width");
0050 LHS = B.CreateVectorSplat(
0051 cast<VectorType>(RHS->getType())->getElementCount(), LHS,
0052 "scalar.splat");
0053 }
0054 return {LHS, RHS};
0055 }
0056
0057 public:
0058 MatrixBuilder(IRBuilderBase &Builder) : B(Builder) {}
0059
0060
0061
0062
0063
0064
0065
0066 CallInst *CreateColumnMajorLoad(Type *EltTy, Value *DataPtr, Align Alignment,
0067 Value *Stride, bool IsVolatile, unsigned Rows,
0068 unsigned Columns, const Twine &Name = "") {
0069 auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
0070
0071 Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
0072 B.getInt32(Columns)};
0073 Type *OverloadedTypes[] = {RetType, Stride->getType()};
0074
0075 Function *TheFn = Intrinsic::getOrInsertDeclaration(
0076 getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
0077
0078 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
0079 Attribute AlignAttr =
0080 Attribute::getWithAlignment(Call->getContext(), Alignment);
0081 Call->addParamAttr(0, AlignAttr);
0082 return Call;
0083 }
0084
0085
0086
0087
0088
0089 CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
0090 Value *Stride, bool IsVolatile,
0091 unsigned Rows, unsigned Columns,
0092 const Twine &Name = "") {
0093 Value *Ops[] = {Matrix, Ptr,
0094 Stride, B.getInt1(IsVolatile),
0095 B.getInt32(Rows), B.getInt32(Columns)};
0096 Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
0097
0098 Function *TheFn = Intrinsic::getOrInsertDeclaration(
0099 getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
0100
0101 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
0102 Attribute AlignAttr =
0103 Attribute::getWithAlignment(Call->getContext(), Alignment);
0104 Call->addParamAttr(1, AlignAttr);
0105 return Call;
0106 }
0107
0108
0109
0110 CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
0111 unsigned Columns, const Twine &Name = "") {
0112 auto *OpType = cast<VectorType>(Matrix->getType());
0113 auto *ReturnType =
0114 FixedVectorType::get(OpType->getElementType(), Rows * Columns);
0115
0116 Type *OverloadedTypes[] = {ReturnType};
0117 Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
0118 Function *TheFn = Intrinsic::getOrInsertDeclaration(
0119 getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
0120
0121 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
0122 }
0123
0124
0125
0126 CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
0127 unsigned LHSColumns, unsigned RHSColumns,
0128 const Twine &Name = "") {
0129 auto *LHSType = cast<VectorType>(LHS->getType());
0130 auto *RHSType = cast<VectorType>(RHS->getType());
0131
0132 auto *ReturnType =
0133 FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
0134
0135 Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
0136 B.getInt32(RHSColumns)};
0137 Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
0138
0139 Function *TheFn = Intrinsic::getOrInsertDeclaration(
0140 getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
0141 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
0142 }
0143
0144
0145
0146 Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
0147 Value *ColumnIdx, unsigned NumRows) {
0148 return B.CreateInsertElement(
0149 Matrix, NewVal,
0150 B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
0151 ColumnIdx->getType(), NumRows)),
0152 RowIdx));
0153 }
0154
0155
0156
0157 Value *CreateAdd(Value *LHS, Value *RHS) {
0158 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
0159 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
0160 assert(!isa<ScalableVectorType>(LHS->getType()) &&
0161 "LHS Assumed to be fixed width");
0162 RHS = B.CreateVectorSplat(
0163 cast<VectorType>(LHS->getType())->getElementCount(), RHS,
0164 "scalar.splat");
0165 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
0166 assert(!isa<ScalableVectorType>(RHS->getType()) &&
0167 "RHS Assumed to be fixed width");
0168 LHS = B.CreateVectorSplat(
0169 cast<VectorType>(RHS->getType())->getElementCount(), LHS,
0170 "scalar.splat");
0171 }
0172
0173 return cast<VectorType>(LHS->getType())
0174 ->getElementType()
0175 ->isFloatingPointTy()
0176 ? B.CreateFAdd(LHS, RHS)
0177 : B.CreateAdd(LHS, RHS);
0178 }
0179
0180
0181
0182 Value *CreateSub(Value *LHS, Value *RHS) {
0183 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
0184 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
0185 assert(!isa<ScalableVectorType>(LHS->getType()) &&
0186 "LHS Assumed to be fixed width");
0187 RHS = B.CreateVectorSplat(
0188 cast<VectorType>(LHS->getType())->getElementCount(), RHS,
0189 "scalar.splat");
0190 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
0191 assert(!isa<ScalableVectorType>(RHS->getType()) &&
0192 "RHS Assumed to be fixed width");
0193 LHS = B.CreateVectorSplat(
0194 cast<VectorType>(RHS->getType())->getElementCount(), LHS,
0195 "scalar.splat");
0196 }
0197
0198 return cast<VectorType>(LHS->getType())
0199 ->getElementType()
0200 ->isFloatingPointTy()
0201 ? B.CreateFSub(LHS, RHS)
0202 : B.CreateSub(LHS, RHS);
0203 }
0204
0205
0206
0207 Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
0208 std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
0209 if (LHS->getType()->getScalarType()->isFloatingPointTy())
0210 return B.CreateFMul(LHS, RHS);
0211 return B.CreateMul(LHS, RHS);
0212 }
0213
0214
0215
0216 Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
0217 assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
0218 assert(!isa<ScalableVectorType>(LHS->getType()) &&
0219 "LHS Assumed to be fixed width");
0220 RHS =
0221 B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
0222 RHS, "scalar.splat");
0223 return cast<VectorType>(LHS->getType())
0224 ->getElementType()
0225 ->isFloatingPointTy()
0226 ? B.CreateFDiv(LHS, RHS)
0227 : (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
0228 }
0229
0230
0231 void CreateIndexAssumption(Value *Idx, unsigned NumElements,
0232 Twine const &Name = "") {
0233 Value *NumElts =
0234 B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
0235 auto *Cmp = B.CreateICmpULT(Idx, NumElts);
0236 if (isa<ConstantInt>(Cmp))
0237 assert(cast<ConstantInt>(Cmp)->isOne() && "Index must be valid!");
0238 else
0239 B.CreateAssumption(Cmp);
0240 }
0241
0242
0243
0244 Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
0245 Twine const &Name = "") {
0246 unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
0247 ColumnIdx->getType()->getScalarSizeInBits());
0248 Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
0249 RowIdx = B.CreateZExt(RowIdx, IntTy);
0250 ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
0251 Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
0252 return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
0253 }
0254 };
0255
0256 }
0257
0258 #endif