@@ -807,6 +807,200 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a scattered store op. The offsets argument is required.
+/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
+/// The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to a single-element
+/// vector; **it is assumed that their producers will also be
+/// distributed**. The payload vector also has a fixed distribution:
+///   no chunk size -> vector of one element.
+///   chunk size    -> vector of the innermost dimension of the SG-payload.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+///     memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
+    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+    if (!storeScatterOp)
+      return failure();
+    auto offsets = storeScatterOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Store op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
+    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Expected 1D offsets and mask vector");
+    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
+    if (storeVecTy.getRank() > 2)
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Expected at most 2D result at SG level");
+
+    std::string layoutPayloadName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
+
+    xegpu::LayoutAttr layoutPayload =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
+    xegpu::LayoutAttr layoutOffsets =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distStoreVecByWarpOpOrFailure) ||
+        failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      return rewriter.notifyMatchFailure(
+          storeScatterOp,
+          "Some vector operands have no layouts, using defaults instead.");
+    }
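+    // The distributed payload may still be 2D (chunk-size case); the
+    // lane-level store takes a flat 1D vector, so collapse it by element
+    // count.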
+    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
+    VectorType expectedPayloadTy = VectorType::get(
+        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
+
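+    // Yield the store's operands from the warp op: the payload, offsets, and
+    // mask with their distributed types, the destination unchanged.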
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands = storeScatterOp->getOperands();
+    SmallVector<Type> operandTypesToYield = {
+        expectedPayloadTy, operands[1].getType(),
+        distOffsetsByWarpOpOrFailure.value(),
+        distMaskByWarpOpOrFailure.value()};
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
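+    // Recreate the store outside the warp op, consuming the per-lane values
+    // it yields; layout attributes are dropped from the lane-level op.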
+    rewriter.setInsertionPointAfter(newWarpOp);
+    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
+        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
+        storeScatterOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
+    rewriter.eraseOp(storeScatterOp);
+    return success();
+  }
+};
+
+/// Distribute a scattered load op. The logic and requirements are the same as
+/// for the scattered store distribution. The warpOp's payload vector is
+/// expected to be distributed by the load's result consumer.
+/// Example 1 (no chunk size):
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
+/// Example 2 (chunk size, same mask and offsets):
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+struct LoadDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      // Check that the yield operand was produced by the *last* scattered
+      // load op, to avoid sinking it before barriers (maintain memory order).
+      return isa<xegpu::LoadGatherOp>(op) &&
+             warpOp.getTerminator()->getPrevNode() == op;
+    });
+    if (!producedByLastLoad)
+      return rewriter.notifyMatchFailure(
+          warpOp, "The last op is not xegpu::LoadGatherOp");
+
+    auto loadGatherOp =
+        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
+    auto offsets = loadGatherOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()) ||
+        !isa<VectorType>(loadGatherOp.getMask().getType()))
+      return rewriter.notifyMatchFailure(
+          loadGatherOp,
+          "Load op must have vector arguments for offsets and mask");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
+    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(loadGatherOp,
+                                         "Expected 1D offsets and mask vector");
+    // Assume offset and mask producers will be distributed as well.
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
+
+    xegpu::LayoutAttr layoutOffsets =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      return rewriter.notifyMatchFailure(
+          loadGatherOp,
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+
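+    // Yield the source, offsets, and mask from the warp op; offsets and mask
+    // use their distributed types, the source type is unchanged.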
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands = loadGatherOp->getOperands();
+    SmallVector<Type> operandTypesToYield = {
+        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
+        distMaskByWarpOpOrFailure.value()};
+
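+    // The per-lane result type is read from the corresponding warp op result,
+    // which already carries the distributed shape.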
+    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
+    VectorType loadVecTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+
+    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    rewriter.setInsertionPointAfter(newWarpOp);
+    xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
+        newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
+        loadGatherOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
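+    // Route users of the old warp op result to the new lane-level load.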
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -819,10 +1013,11 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
-      patterns.getContext());
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
+           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
+          patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {