Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20895,12 +20895,12 @@ Overview:

The '``llvm.matrix.column.major.load.*``' intrinsics load a ``<Rows> x <Cols>``
matrix using a stride of ``%Stride`` to compute the start address of the
different columns. The offset is computed using ``%Stride``'s bitwidth. This
allows for convenient loading of sub matrixes. If ``<IsVolatile>`` is true, the
intrinsic is considered a :ref:`volatile memory access <volatile>`. The result
matrix is returned in the result vector. If the ``%Ptr`` argument is known to
be aligned to some boundary, this can be specified as an attribute on the
argument.
different columns. This allows for convenient loading of sub-matrices.
Independent of ``%Stride``'s bitwidth, the offset is computed using the target
data layout's pointer index type. If ``<IsVolatile>`` is true, the intrinsic is
considered a :ref:`volatile memory access <volatile>`. The result matrix is
returned in the result vector. If the ``%Ptr`` argument is known to be aligned
to some boundary, this can be specified as an attribute on the argument.

Arguments:
""""""""""
Expand Down Expand Up @@ -20935,9 +20935,9 @@ Overview:

The '``llvm.matrix.column.major.store.*``' intrinsics store the ``<Rows> x
<Cols>`` matrix in ``%In`` to memory using a stride of ``%Stride`` between
columns. The offset is computed using ``%Stride``'s bitwidth. If
``<IsVolatile>`` is true, the intrinsic is considered a
:ref:`volatile memory access <volatile>`.
columns. Independent of ``%Stride``'s bitwidth, the offset is computed using
the target data layout's pointer index type. If ``<IsVolatile>`` is true, the
intrinsic is considered a :ref:`volatile memory access <volatile>`.

If the ``%Ptr`` argument is known to be aligned to some boundary, this can be
specified as an attribute on the argument.
Expand Down
43 changes: 31 additions & 12 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1295,6 +1295,24 @@ class LowerMatrixIntrinsics {
return commonAlignment(InitialAlign, ElementSizeInBits / 8);
}

/// Return the integer type that the target data layout uses for indexing
/// pointers in \p Ptr's address space (DL::getIndexType always yields an
/// integer type for a pointer operand, so the cast is safe).
IntegerType *getIndexType(Value *Ptr) const {
return cast<IntegerType>(DL.getIndexType(Ptr->getType()));
}

/// Return the constant \p V as an integer of the data layout's pointer
/// index type for \p Ptr's address space, suitable for use as a GEP /
/// stride index.
Value *getIndex(Value *Ptr, uint64_t V) const {
return ConstantInt::get(getIndexType(Ptr), V);
}

/// Normalize the integer \p V to the data layout's pointer index type for
/// \p Ptr's address space, truncating or zero-extending as required.
/// Zero-extension is chosen because indices are treated as unsigned.
Value *castToIndexType(Value *Ptr, Value *V, IRBuilder<> &Builder) const {
  assert(isa<IntegerType>(V->getType()) &&
         "Attempted to cast non-integral type to integer index");
  // The incoming value (e.g. a stride operand) may be narrower or wider
  // than the index type; emit a zext/trunc only when the widths differ.
  IntegerType *IdxTy = getIndexType(Ptr);
  return Builder.CreateZExtOrTrunc(V, IdxTy, V->getName() + ".cast");
}

/// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
/// vectors.
MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
Expand All @@ -1304,6 +1322,7 @@ class LowerMatrixIntrinsics {
Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
Value *EltPtr = Ptr;
MatrixTy Result;
Stride = castToIndexType(Ptr, Stride, Builder);
for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
Value *GEP = computeVectorAddr(
EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
Expand All @@ -1325,14 +1344,14 @@ class LowerMatrixIntrinsics {
ShapeInfo ResultShape, Type *EltTy,
IRBuilder<> &Builder) {
Value *Offset = Builder.CreateAdd(
Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);

Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
ResultShape.NumColumns);

return loadMatrix(TileTy, TileStart, Align,
Builder.getInt64(MatrixShape.getStride()), IsVolatile,
getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,
ResultShape, Builder);
}

Expand Down Expand Up @@ -1363,14 +1382,15 @@ class LowerMatrixIntrinsics {
MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
Value *Offset = Builder.CreateAdd(
Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);

Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
StoreVal.getNumColumns());

storeMatrix(TileTy, StoreVal, TileStart, MAlign,
Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,
Builder);
}

/// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
Expand All @@ -1380,6 +1400,7 @@ class LowerMatrixIntrinsics {
IRBuilder<> &Builder) {
auto *VType = cast<FixedVectorType>(Ty);
Value *EltPtr = Ptr;
Stride = castToIndexType(Ptr, Stride, Builder);
for (auto Vec : enumerate(StoreVal.vectors())) {
Value *GEP = computeVectorAddr(
EltPtr,
Expand Down Expand Up @@ -2011,18 +2032,17 @@ class LowerMatrixIntrinsics {
const unsigned TileM = std::min(M - K, unsigned(TileSize));
MatrixTy A =
loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
LShape, Builder.getInt64(I), Builder.getInt64(K),
LShape, getIndex(APtr, I), getIndex(APtr, K),
{TileR, TileM}, EltType, Builder);
MatrixTy B =
loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
RShape, Builder.getInt64(K), Builder.getInt64(J),
RShape, getIndex(BPtr, K), getIndex(BPtr, J),
{TileM, TileC}, EltType, Builder);
emitMatrixMultiply(Res, A, B, Builder, true, false,
getFastMathFlags(MatMul));
}
storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
Builder.getInt64(I), Builder.getInt64(J), EltType,
Builder);
getIndex(CPtr, I), getIndex(CPtr, J), EltType, Builder);
}
}

Expand Down Expand Up @@ -2254,15 +2274,14 @@ class LowerMatrixIntrinsics {
/// Lower load instructions.
MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
IRBuilder<> &Builder) {
return LowerLoad(Inst, Ptr, Inst->getAlign(),
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
Builder);
return LowerLoad(Inst, Ptr, Inst->getAlign(), getIndex(Ptr, SI.getStride()),
Inst->isVolatile(), SI, Builder);
}

MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
Value *Ptr, IRBuilder<> &Builder) {
return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
getIndex(Ptr, SI.getStride()), Inst->isVolatile(), SI,
Builder);
}

Expand Down
Loading