1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16+ import os
1617import threading
1718import time
1819from typing import List
3435class FakeProgressTrackingServer :
3536 """Test server to receive progress tracking webhooks."""
3637
37- def __init__ (self , port : int = 8000 ):
38+ def __init__ (self , port : int = 8000 , request_method = "PATCH" ):
3839 self .port = port
3940 self .app = Flask (__name__ )
4041 self .received_updates : List [dict ] = []
4142 self .lock = threading .Lock ()
4243
43- @self .app .route ("/" , methods = ["POST" ])
44+ @self .app .route ("/" , methods = [request_method ])
4445 def progress_webhook ():
4546 """Receive progress updates."""
4647 data = request .get_json ()
@@ -93,6 +94,18 @@ def test_init_with_custom_params(self):
9394 assert interceptor .progress_tracking_url == "http://test-server:9000"
9495 assert interceptor .progress_tracking_interval == 5
9596
97+ @patch .dict (os .environ , {"NEMO_JOB_ID" : "job-1234" })
98+ def test_init_url_with_env_expansion (self ):
99+ """Test initialization with URL with env variable is expanded."""
100+ params = ProgressTrackingInterceptor .Params (
101+ progress_tracking_url = "http://test-server:8000/jobs/${NEMO_JOB_ID}/status-details"
102+ )
103+ interceptor = ProgressTrackingInterceptor (params )
104+ assert (
105+ interceptor .progress_tracking_url
106+ == "http://test-server:8000/jobs/job-1234/status-details"
107+ )
108+
96109 def test_intercept_response_sends_progress (self ):
97110 """Test that intercept_response sends progress updates."""
98111 # Start test server
@@ -213,11 +226,11 @@ def process_samples():
213226 finally :
214227 server .stop ()
215228
216- @patch ("requests.post " )
217- def test_network_error_handling (self , mock_post ):
229+ @patch ("requests.request " )
230+ def test_network_error_handling (self , mock_request ):
218231 """Test that network errors are handled gracefully."""
219- # Mock requests.post to raise an exception
220- mock_post .side_effect = requests .exceptions .RequestException (
232+ # Mock requests.patch to raise an exception
233+ mock_request .side_effect = requests .exceptions .RequestException (
221234 "Connection failed"
222235 )
223236
@@ -240,7 +253,7 @@ def test_network_error_handling(self, mock_post):
240253 assert result == mock_response
241254
242255 # Verify that the request was attempted
243- mock_post .assert_called_once ()
256+ mock_request .assert_called_once ()
244257
245258 def test_interval_configuration (self ):
246259 """Test different interval configurations."""
@@ -308,6 +321,52 @@ def test_json_payload_format(self):
308321 finally :
309322 server .stop ()
310323
324+ def test_configured_method (self ):
325+ """Test that the JSON payload has the correct format."""
326+ # Start test server
327+ server = FakeProgressTrackingServer (port = 8006 , request_method = "POST" )
328+ server .start ()
329+
330+ try :
331+ # Create interceptor
332+ params = ProgressTrackingInterceptor .Params (
333+ progress_tracking_url = "http://localhost:8006" ,
334+ progress_tracking_interval = 1 ,
335+ request_method = "POST" ,
336+ )
337+ interceptor = ProgressTrackingInterceptor (params )
338+
339+ mock_response = AdapterResponse (
340+ r = requests .Response (),
341+ rctx = AdapterRequestContext (),
342+ )
343+ context = AdapterGlobalContext (output_dir = "/tmp" , url = "http://test" )
344+
345+ # Process one sample
346+ interceptor .intercept_response (mock_response , context )
347+
348+ # Check the payload format
349+ updates = server .get_updates ()
350+ assert len (updates ) == 1
351+ assert "samples_processed" in updates [0 ]
352+ assert updates [0 ]["samples_processed" ] == 1
353+ assert isinstance (updates [0 ]["samples_processed" ], int )
354+
355+ # Verify PATCH does not update the server
356+ params = ProgressTrackingInterceptor .Params (
357+ progress_tracking_url = "http://localhost:8006" ,
358+ progress_tracking_interval = 1 ,
359+ request_method = "PATCH" ,
360+ )
361+ interceptor = ProgressTrackingInterceptor (params )
362+ interceptor .intercept_response (mock_response , context )
363+ assert updates == server .get_updates (), (
364+ "server should not update with misconfigured method"
365+ )
366+
367+ finally :
368+ server .stop ()
369+
311370
312371if __name__ == "__main__" :
313372 # Simple test runner for manual testing
0 commit comments