@@ -60,6 +60,41 @@ def filter_matrix_item(
6060 return True
6161
6262
def create_distributed_config(
    item: Dict[str, Any],
    *,
    runner: str = "linux.g4dn.12xlarge.nvidia.gpu",
    num_gpus: int = 2,
) -> Dict[str, Any]:
    """Derive a distributed-test matrix entry from a regular one.

    The returned dict carries every key of ``item`` plus three overrides:

    - ``validation_runner`` is set to ``runner`` (a multi-GPU instance),
    - ``num_gpus`` is set to ``num_gpus``,
    - ``config`` is set to the marker string ``"distributed"``.

    Args:
        item: A regular matrix entry. It is not modified.
        runner: Runner label to use for the distributed job.
        num_gpus: Number of GPUs recorded in the entry.

    Returns:
        A new dict. The copy is shallow: nested values, if any, are shared
        with ``item``.
    """
    # Function-scope import so this helper does not depend on the module's
    # top-level import list.
    import sys

    # Debug output goes to stderr; this function itself writes nothing to
    # stdout.
    print("[DEBUG] Creating distributed config from:", file=sys.stderr)
    print(f"[DEBUG] Python: {item.get('python_version')} ", file=sys.stderr)
    print(f"[DEBUG] CUDA: {item.get('desired_cuda')} ", file=sys.stderr)
    print(
        f"[DEBUG] Original runner: {item.get('validation_runner')} ",
        file=sys.stderr,
    )

    # Build the copy and apply the overrides in one step. Keys already in
    # ``item`` keep their position; new keys are appended in this order.
    dist_item: Dict[str, Any] = {
        **item,
        "validation_runner": runner,
        "num_gpus": num_gpus,
        "config": "distributed",
    }

    print(f"[DEBUG] New runner: {dist_item['validation_runner']} ", file=sys.stderr)
    print(f"[DEBUG] GPUs: {dist_item['num_gpus']} ", file=sys.stderr)

    return dist_item
96+
97+
6398def main (args : list [str ]) -> None :
6499 parser = argparse .ArgumentParser ()
65100 parser .add_argument (
@@ -99,16 +134,69 @@ def main(args: list[str]) -> None:
99134
100135 includes = matrix_dict ["include" ]
101136 filtered_includes = []
137+ distributed_includes = [] # NEW: separate list for distributed configs
138+
139+ print (f"[DEBUG] Processing { len (includes )} input configs" , file = sys .stderr )
102140
103141 for item in includes :
142+ py_ver = item .get ("python_version" , "unknown" )
143+ cuda_ver = item .get ("desired_cuda" , "unknown" )
144+
145+ print (f"[DEBUG] Checking config: py={ py_ver } , cuda={ cuda_ver } " , file = sys .stderr )
146+
104147 if filter_matrix_item (
105148 item ,
106149 options .jetpack == "true" ,
107150 options .limit_pr_builds == "true" ,
108151 ):
152+ print (f"[DEBUG] passed filter - adding to build matrix" , file = sys .stderr )
109153 filtered_includes .append (item )
110154
111- filtered_matrix_dict = {"include" : filtered_includes }
155+ # NEW: Create distributed variant for specific configs
156+ # Only Python 3.10 + CUDA 13.0 for now
157+ if item ["python_version" ] == "3.10" and item ["desired_cuda" ] == "cu130" :
158+ print (
159+ f"[DEBUG] Creating distributed config for py3.10+cu130" ,
160+ file = sys .stderr ,
161+ )
162+ distributed_includes .append (create_distributed_config (item ))
163+ else :
164+ print (f"[DEBUG] FILTERED OUT" , file = sys .stderr )
165+
166+ # Debug: Show summary
167+ print (f"[DEBUG] Final counts:" , file = sys .stderr )
168+ print (f"[DEBUG] Regular configs: { len (filtered_includes )} " , file = sys .stderr )
169+ print (
170+ f"[DEBUG] Distributed configs: { len (distributed_includes )} " , file = sys .stderr
171+ )
172+
173+ # Debug: Show which configs will be built
174+ print (
175+ f"[DEBUG] Configs that will be BUILT (in filtered_includes):" , file = sys .stderr
176+ )
177+ for item in filtered_includes :
178+ print (
179+ f"[DEBUG] - py={ item .get ('python_version' )} , cuda={ item .get ('desired_cuda' )} " ,
180+ file = sys .stderr ,
181+ )
182+
183+ print (
184+ f"[DEBUG] Configs for DISTRIBUTED TESTS (in distributed_includes):" ,
185+ file = sys .stderr ,
186+ )
187+ for item in distributed_includes :
188+ print (
189+ f"[DEBUG] - py={ item .get ('python_version' )} , cuda={ item .get ('desired_cuda' )} , gpus={ item .get ('num_gpus' )} " ,
190+ file = sys .stderr ,
191+ )
192+
193+ # NEW: Output both regular and distributed configs
194+ filtered_matrix_dict = {
195+ "include" : filtered_includes ,
196+ "distributed_include" : distributed_includes , # NEW field
197+ }
198+
199+ # Output to stdout (consumed by GitHub Actions)
112200 print (json .dumps (filtered_matrix_dict ))
113201
114202
0 commit comments