@@ -807,6 +807,200 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
  }
};

+/// Distribute a scattered store op. The offsets argument is required.
+/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
+/// The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to single-element vectors;
+/// **it is assumed that their producers will also be distributed**.
+/// The payload vector also has a fixed distribution:
+///   no chunk size -> vector of one element.
+///   chunk size    -> vector of the innermost dimension of the SG-payload.
+/// Example 1 (no chunk size):
+///    %mask = producer_op : vector<16xi1>
+///    %offset = producer_op : vector<16xindex>
+///    xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+///      memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///    %mask = producer_op : vector<1xi1>
+///    %offset = producer_op : vector<1xindex>
+///    xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///      memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+///    xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///      vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///    xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///      vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
+    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+    if (!storeScatterOp)
+      return failure();
+    auto offsets = storeScatterOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Store op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
+    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Expected 1D offsets and mask vector");
+    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
+    if (storeVecTy.getRank() > 2)
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Expected at most 2D result at SG level");
+
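+    // Look up the layout attribute attached to each vector operand
+    // (payload, offsets, mask) of the sg-level store.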
+    std::string layoutPayloadName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
+
+    xegpu::LayoutAttr layoutPayload =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
+    xegpu::LayoutAttr layoutOffsets =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
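+    // Compute the distributed (per-lane) vector type of each vector operand
+    // from its lane layout.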
+    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distStoreVecByWarpOpOrFailure) ||
+        failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      return rewriter.notifyMatchFailure(
+          storeScatterOp,
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
+    VectorType expectedPayloadTy = VectorType::get(
+        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
+
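+    // Yield the store operands from the warp op: the distributed payload is
+    // flattened to a 1D vector, the destination is passed through unchanged,
+    // and offsets/mask get their distributed types.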
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands = storeScatterOp->getOperands();
+    SmallVector<Type> operandTypesToYield = {
+        expectedPayloadTy, operands[1].getType(),
+        distOffsetsByWarpOpOrFailure.value(),
+        distMaskByWarpOpOrFailure.value()};
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
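+    // Recreate the scattered store after the warp op using the lane-level
+    // operands; the sg-level layout attributes no longer apply and are removed.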
+    rewriter.setInsertionPointAfter(newWarpOp);
+    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
+        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
+        storeScatterOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
+    rewriter.eraseOp(storeScatterOp);
+    return success();
+  }
+};
+
+/// Distribute a scattered load op. The logic and requirements are the same as
+/// for the scattered store distribution. The warpOp's payload vector is
+/// expected to be distributed by the load's result consumer.
+/// Example 1 (no chunk size):
+///    %mask = producer_op : vector<16xi1>
+///    %offset = producer_op : vector<16xindex>
+///    %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///      vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To
+///    %mask = producer_op : vector<1xi1>
+///    %offset = producer_op : vector<1xindex>
+///    %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///      vector<1xindex>, vector<1xi1> -> vector<1xf16>
+/// Example 2 (chunk size, same mask and offsets):
+///    %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///      memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To
+///    %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///      memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+struct LoadDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      // Check that the yield operand was produced by the *last* scattered
+      // load op, to avoid sinking it before barriers (maintain memory order).
+      return isa<xegpu::LoadGatherOp>(op) &&
+             warpOp.getTerminator()->getPrevNode() == op;
+    });
+    if (!producedByLastLoad)
+      return rewriter.notifyMatchFailure(
+          warpOp, "The last op is not xegpu::LoadGatherOp");
+
+    auto loadGatherOp =
+        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
+    auto offsets = loadGatherOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()) ||
+        !isa<VectorType>(loadGatherOp.getMask().getType()))
+      return rewriter.notifyMatchFailure(
+          loadGatherOp,
+          "Load op must have vector arguments for offsets and mask");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
+    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(loadGatherOp,
+                                         "Expected 1D offsets and mask vector");
+    // Assume offset and mask producers will be distributed as well.
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
+
+    xegpu::LayoutAttr layoutOffsets =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      return rewriter.notifyMatchFailure(
+          loadGatherOp,
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+
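+    // Yield the load's source together with the distributed offsets and mask
+    // from the warp op.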
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands = loadGatherOp->getOperands();
+    SmallVector<Type> operandTypesToYield = {
+        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
+        distMaskByWarpOpOrFailure.value()};
+
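+    // The distributed type of the load result is taken from the corresponding
+    // warp op result, i.e. it is dictated by the result's consumer.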
+    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
+    VectorType loadVecTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+
+    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
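+    // Recreate the gather load after the warp op with the distributed result
+    // type and operands; its result replaces the corresponding warp op result.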
+    rewriter.setInsertionPointAfter(newWarpOp);
+    xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
+        newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
+        loadGatherOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+    return success();
+  }
+};
+
} // namespace

namespace {
@@ -819,10 +1013,11 @@ struct XeGPUSubgroupDistributePass final

void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
-      patterns.getContext());
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
+           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
+          patterns.getContext());
}

void XeGPUSubgroupDistributePass::runOnOperation() {