Skip to content

Commit 1567ec5

Browse files
wprazuchAWarno
authored andcommitted
(feat) Configure request method for progress tracking requests (#213)
Signed-off-by: Wojciech Prazuch <[email protected]> Signed-off-by: Anna Warno <[email protected]>
1 parent 6c2b912 commit 1567ec5

File tree

3 files changed

+94
-15
lines changed

3 files changed

+94
-15
lines changed

packages/nemo-evaluator/src/nemo_evaluator/adapters/adapter_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def from_legacy_config(
537537
"progress_tracking_interval": legacy_config[
538538
"progress_tracking_interval"
539539
],
540+
"request_method": "POST", # Legacy method uses POST
540541
"output_dir": cls._get_default_output_dir(legacy_config, run_config),
541542
}
542543
if legacy_config["progress_tracking_url"] is not None:

packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/progress_tracking_interceptor.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Progress tracking interceptor that tracks number of samples processed via webhook."""
1717

18+
import os
1819
import pathlib
1920
import threading
2021
from typing import Optional, final
@@ -45,12 +46,16 @@ class Params(BaseLoggingParams):
4546

4647
progress_tracking_url: Optional[str] = Field(
4748
default="http://localhost:8000",
48-
description="URL to post the number of processed samples to.",
49+
description="URL to post the number of processed samples to. Supports expansion of shell variables if present.",
4950
)
5051
progress_tracking_interval: int = Field(
5152
default=1,
5253
description="How often (every how many samples) to send a progress information.",
5354
)
55+
request_method: str = Field(
56+
default="PATCH",
57+
description="Request method to use for updating the evaluation progress.",
58+
)
5459
output_dir: Optional[str] = Field(
5560
default=None,
5661
description="Evaluation output directory. If provided, the progress tracking will be saved to a file in this directory.",
@@ -59,6 +64,7 @@ class Params(BaseLoggingParams):
5964
progress_tracking_url: Optional[str]
6065
progress_tracking_interval: int
6166
progress_filepath: Optional[pathlib.Path]
67+
request_method: str
6268

6369
def __init__(self, params: Params):
6470
"""
@@ -67,8 +73,9 @@ def __init__(self, params: Params):
6773
Args:
6874
params: Configuration parameters
6975
"""
70-
self.progress_tracking_url = params.progress_tracking_url
76+
self.progress_tracking_url = os.path.expandvars(params.progress_tracking_url)
7177
self.progress_tracking_interval = params.progress_tracking_interval
78+
self.request_method = params.request_method
7279
if params.output_dir is not None:
7380
output_dir = pathlib.Path(params.output_dir)
7481
output_dir.mkdir(parents=True, exist_ok=True)
@@ -111,20 +118,32 @@ def _write_progress(self, num_samples: int):
111118
samples_processed=num_samples,
112119
)
113120

114-
def _send_progress(self, num_samples: int):
121+
def _send_progress(self, num_samples: int) -> requests.Response:
115122
self.logger.debug(
116123
"Sending progress to tracking server",
117124
url=self.progress_tracking_url,
125+
method=self.request_method,
118126
samples_processed=num_samples,
119127
)
128+
body = {"samples_processed": num_samples}
120129
try:
121-
requests.post(
130+
resp = requests.request(
131+
self.request_method,
122132
self.progress_tracking_url,
123-
json={"samples_processed": num_samples},
124-
)
125-
self.logger.debug(
126-
"Progress sent successfully", samples_processed=num_samples
133+
json=body,
127134
)
135+
if resp.status_code >= 200 and resp.status_code < 300:
136+
self.logger.debug(
137+
"Progress sent successfully", samples_processed=num_samples
138+
)
139+
else:
140+
self.logger.warning(
141+
"Failed to update job progress",
142+
body=body,
143+
status_code=resp.status_code,
144+
response_text=resp.text,
145+
)
146+
return resp
128147
except requests.exceptions.RequestException as e:
129148
self.logger.error(
130149
"Failed to communicate with progress tracking server",

packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_progress_tracking_interceptor.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import os
1617
import threading
1718
import time
1819
from typing import List
@@ -34,13 +35,13 @@
3435
class FakeProgressTrackingServer:
3536
"""Test server to receive progress tracking webhooks."""
3637

37-
def __init__(self, port: int = 8000):
38+
def __init__(self, port: int = 8000, request_method="PATCH"):
3839
self.port = port
3940
self.app = Flask(__name__)
4041
self.received_updates: List[dict] = []
4142
self.lock = threading.Lock()
4243

43-
@self.app.route("/", methods=["POST"])
44+
@self.app.route("/", methods=[request_method])
4445
def progress_webhook():
4546
"""Receive progress updates."""
4647
data = request.get_json()
@@ -93,6 +94,18 @@ def test_init_with_custom_params(self):
9394
assert interceptor.progress_tracking_url == "http://test-server:9000"
9495
assert interceptor.progress_tracking_interval == 5
9596

97+
@patch.dict(os.environ, {"NEMO_JOB_ID": "job-1234"})
98+
def test_init_url_with_env_expansion(self):
99+
"""Test initialization with URL with env variable is expanded."""
100+
params = ProgressTrackingInterceptor.Params(
101+
progress_tracking_url="http://test-server:8000/jobs/${NEMO_JOB_ID}/status-details"
102+
)
103+
interceptor = ProgressTrackingInterceptor(params)
104+
assert (
105+
interceptor.progress_tracking_url
106+
== "http://test-server:8000/jobs/job-1234/status-details"
107+
)
108+
96109
def test_intercept_response_sends_progress(self):
97110
"""Test that intercept_response sends progress updates."""
98111
# Start test server
@@ -213,11 +226,11 @@ def process_samples():
213226
finally:
214227
server.stop()
215228

216-
@patch("requests.post")
217-
def test_network_error_handling(self, mock_post):
229+
@patch("requests.request")
230+
def test_network_error_handling(self, mock_request):
218231
"""Test that network errors are handled gracefully."""
219-
# Mock requests.post to raise an exception
220-
mock_post.side_effect = requests.exceptions.RequestException(
232+
# Mock requests.patch to raise an exception
233+
mock_request.side_effect = requests.exceptions.RequestException(
221234
"Connection failed"
222235
)
223236

@@ -240,7 +253,7 @@ def test_network_error_handling(self, mock_post):
240253
assert result == mock_response
241254

242255
# Verify that the request was attempted
243-
mock_post.assert_called_once()
256+
mock_request.assert_called_once()
244257

245258
def test_interval_configuration(self):
246259
"""Test different interval configurations."""
@@ -308,6 +321,52 @@ def test_json_payload_format(self):
308321
finally:
309322
server.stop()
310323

324+
def test_configured_method(self):
325+
"""Test that the JSON payload has the correct format."""
326+
# Start test server
327+
server = FakeProgressTrackingServer(port=8006, request_method="POST")
328+
server.start()
329+
330+
try:
331+
# Create interceptor
332+
params = ProgressTrackingInterceptor.Params(
333+
progress_tracking_url="http://localhost:8006",
334+
progress_tracking_interval=1,
335+
request_method="POST",
336+
)
337+
interceptor = ProgressTrackingInterceptor(params)
338+
339+
mock_response = AdapterResponse(
340+
r=requests.Response(),
341+
rctx=AdapterRequestContext(),
342+
)
343+
context = AdapterGlobalContext(output_dir="/tmp", url="http://test")
344+
345+
# Process one sample
346+
interceptor.intercept_response(mock_response, context)
347+
348+
# Check the payload format
349+
updates = server.get_updates()
350+
assert len(updates) == 1
351+
assert "samples_processed" in updates[0]
352+
assert updates[0]["samples_processed"] == 1
353+
assert isinstance(updates[0]["samples_processed"], int)
354+
355+
# Verify PATCH does not update the server
356+
params = ProgressTrackingInterceptor.Params(
357+
progress_tracking_url="http://localhost:8006",
358+
progress_tracking_interval=1,
359+
request_method="PATCH",
360+
)
361+
interceptor = ProgressTrackingInterceptor(params)
362+
interceptor.intercept_response(mock_response, context)
363+
assert updates == server.get_updates(), (
364+
"server should not update with misconfigured method"
365+
)
366+
367+
finally:
368+
server.stop()
369+
311370

312371
if __name__ == "__main__":
313372
# Simple test runner for manual testing

0 commit comments

Comments
 (0)