Skip to content

Commit d670b8a

Browse files
authored
[AUTOREVERT] add notification to shared issue (#7200)
* Adds a notification to a shared issue for autorevert * addresses comments on #7198
1 parent c0cada7 commit d670b8a

File tree

5 files changed

+225
-52
lines changed

5 files changed

+225
-52
lines changed

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/__main__.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717

1818

1919
DEFAULT_WORKFLOWS = ["Lint", "trunk", "pull", "inductor"]
20+
DEFAULT_REPO_FULL_NAME = "pytorch/pytorch"
21+
DEFAULT_HOURS = 16
22+
DEFAULT_COMMENT_ISSUE_NUMBER = (
23+
163650 # https://github.com/pytorch/pytorch/issues/163650
24+
)
2025
# Special constant to indicate --hud-html was passed as a flag (without a value)
2126
HUD_HTML_NO_VALUE_FLAG = object()
2227

@@ -106,11 +111,14 @@ def get_opts() -> argparse.Namespace:
106111
+ ' list (e.g., "pull" or "pull,trunk,inductor")',
107112
)
108113
workflow_parser.add_argument(
109-
"--hours", type=int, default=48, help="Lookback window in hours (default: 48)"
114+
"--hours",
115+
type=int,
116+
default=DEFAULT_HOURS,
117+
help=f"Lookback window in hours (default: {DEFAULT_HOURS})",
110118
)
111119
workflow_parser.add_argument(
112120
"--repo-full-name",
113-
default=os.environ.get("REPO_FULL_NAME", "pytorch/pytorch"),
121+
default=os.environ.get("REPO_FULL_NAME", DEFAULT_REPO_FULL_NAME),
114122
help="Full repo name to filter by (owner/repo).",
115123
)
116124
workflow_parser.add_argument(
@@ -140,6 +148,14 @@ def get_opts() -> argparse.Namespace:
140148
"If set, write the run state to HUD HTML; omit a value to use the run timestamp as the filename."
141149
),
142150
)
151+
workflow_parser.add_argument(
152+
"--notify-issue-number",
153+
type=int,
154+
default=int(
155+
os.environ.get("NOTIFY_ISSUE_NUMBER", DEFAULT_COMMENT_ISSUE_NUMBER)
156+
),
157+
help=f"Issue number to notify (default: {DEFAULT_COMMENT_ISSUE_NUMBER})",
158+
)
143159

144160
# workflow-restart-checker subcommand
145161
workflow_restart_parser = subparsers.add_parser(
@@ -219,17 +235,21 @@ def main(*args, **kwargs) -> None:
219235

220236
if opts.subcommand is None:
221237
autorevert_v2(
222-
os.environ.get("WORKFLOWS", "Lint,trunk,pull,inductor").split(","),
223-
hours=int(os.environ.get("HOURS", 16)),
224-
repo_full_name=os.environ.get("REPO_FULL_NAME", "pytorch/pytorch"),
238+
os.environ.get("WORKFLOWS", ",".join(DEFAULT_WORKFLOWS)).split(","),
239+
hours=int(os.environ.get("HOURS", DEFAULT_HOURS)),
240+
notify_issue_number=int(
241+
os.environ.get("NOTIFY_ISSUE_NUMBER", DEFAULT_COMMENT_ISSUE_NUMBER)
242+
),
243+
repo_full_name=os.environ.get("REPO_FULL_NAME", DEFAULT_REPO_FULL_NAME),
225244
restart_action=(RestartAction.LOG if opts.dry_run else RestartAction.RUN),
226-
revert_action=(RevertAction.LOG if opts.dry_run else RevertAction.RUN_LOG),
245+
revert_action=(RevertAction.LOG if opts.dry_run else RevertAction.RUN_NOTIFY),
227246
)
228247
elif opts.subcommand == "autorevert-checker":
229248
# New default behavior under the same subcommand
230-
_signals, _pairs, state_json = autorevert_v2(
249+
_, _, state_json = autorevert_v2(
231250
opts.workflows,
232251
hours=opts.hours,
252+
notify_issue_number=opts.notify_issue_number,
233253
repo_full_name=opts.repo_full_name,
234254
restart_action=(RestartAction.LOG if opts.dry_run else opts.restart_action),
235255
revert_action=(RevertAction.LOG if opts.dry_run else opts.revert_action),

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_actions.py

Lines changed: 112 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,10 @@ def execute_revert(
226226
)
227227
return False
228228

229-
if not dry_run:
230-
self._comment_pr_notify_revert(commit_sha, sources, ctx)
229+
if ctx.revert_action == RevertAction.RUN_REVERT:
230+
self._comment_pr_do_revert(commit_sha, sources, ctx)
231+
elif ctx.revert_action == RevertAction.RUN_NOTIFY:
232+
self._comment_issue_notify(commit_sha, sources, ctx)
231233

232234
self._logger.insert_event(
233235
repo=ctx.repo_full_name,
@@ -323,6 +325,32 @@ def execute_restart(
323325
)
324326
return True
325327

328+
def _commit_message_check_pr_is_revert(
329+
self, commit_message: str, ctx: RunContext
330+
) -> Optional[int]:
331+
# Look for "Reverted #XXXXX" - indicates a revert action
332+
revert_matches = re.findall(
333+
f"Reverted https://github.com/{ctx.repo_full_name}/pull/(\\d+)",
334+
commit_message,
335+
)
336+
if revert_matches:
337+
pr_number = int(revert_matches[-1])
338+
return pr_number
339+
return None
340+
341+
def _commit_message_check_pr_is_merge(
342+
self, commit_message: str, ctx: RunContext
343+
) -> Optional[int]:
344+
# Look for "Pull Request resolved: #XXXXX" - indicates a merge action
345+
merge_matches = re.findall(
346+
f"Pull Request resolved: https://github.com/{ctx.repo_full_name}/pull/(\\d+)",
347+
commit_message,
348+
)
349+
if merge_matches:
350+
pr_number = int(merge_matches[-1])
351+
return pr_number
352+
return None
353+
326354
def _find_pr_by_sha(
327355
self, commit_sha: str, ctx: RunContext
328356
) -> Optional[Tuple[CommitPRSourceAction, github.PullRequest.PullRequest]]:
@@ -348,13 +376,8 @@ def _find_pr_by_sha(
348376
# This is the most reliable way to determine the pytorchbot action
349377
# Use findall to get all matches and pick the last one (pytorchbot appends at the end)
350378

351-
# Look for "Reverted #XXXXX" - indicates a revert action
352-
revert_matches = re.findall(
353-
r"Reverted https://github.com/pytorch/pytorch/pull/(\d+)",
354-
commit_message,
355-
)
356-
if revert_matches:
357-
pr_number = int(revert_matches[-1]) # Use the last match
379+
pr_number = self._commit_message_check_pr_is_revert(commit_message, ctx)
380+
if pr_number is not None:
358381
try:
359382
pr = repo.get_pull(pr_number)
360383
logging.info(
@@ -370,13 +393,8 @@ def _find_pr_by_sha(
370393
str(e),
371394
)
372395

373-
# Look for "Pull Request resolved: #XXXXX" - indicates a merge action
374-
pr_resolved_matches = re.findall(
375-
r"Pull Request resolved: https://github.com/pytorch/pytorch/pull/(\d+)",
376-
commit_message,
377-
)
378-
if pr_resolved_matches:
379-
pr_number = int(pr_resolved_matches[-1]) # Use the last match
396+
pr_number = self._commit_message_check_pr_is_merge(commit_message, ctx)
397+
if pr_number is not None:
380398
try:
381399
pr = repo.get_pull(pr_number)
382400
logging.info(
@@ -434,13 +452,65 @@ def _find_pr_by_sha(
434452
)
435453
return None
436454

437-
def _comment_pr_notify_revert(
455+
def _comment_issue_notify(
456+
self, commit_sha: str, sources: List[SignalMetadata], ctx: RunContext
457+
) -> bool:
458+
"""Comment on the issue to notify interested stakeholders about the detected autorevert"""
459+
460+
if ctx.revert_action != RevertAction.RUN_NOTIFY:
461+
return False
462+
463+
logging.debug(
464+
"[v2][action] notify for sha %s: finding the issue and notifying stakeholders on issue %s",
465+
commit_sha[:8],
466+
ctx.notify_issue_number,
467+
)
468+
469+
# find the PR from commit_sha on main
470+
pr_result = self._find_pr_by_sha(commit_sha, ctx)
471+
if not pr_result:
472+
logging.error(
473+
"[v2][action] revert for sha %s: no PR found!", commit_sha[:8]
474+
)
475+
return False
476+
477+
try:
478+
issue = (
479+
GHClientFactory()
480+
.client.get_repo(ctx.repo_full_name)
481+
.get_issue(number=ctx.notify_issue_number)
482+
)
483+
action_type, pr = pr_result
484+
issue.create_comment(
485+
f"Autorevert detected a possible offender: {commit_sha[:8]} from PR #{pr.number}.\n"
486+
+ (
487+
"The commit is a revert"
488+
if action_type == CommitPRSourceAction.REVERT
489+
else "The commit is a PR merge"
490+
)
491+
+ "\n"
492+
+ "This commit is breaking the following workflows:\n"
493+
+ "- {}".format("\n- ".join(source.workflow_name for source in sources))
494+
+ "\n"
495+
)
496+
except Exception:
497+
logging.exception(
498+
"[v2][action] revert for sha %s: failed to comment on issue #%d",
499+
commit_sha[:8],
500+
ctx.notify_issue_number,
501+
)
502+
return False
503+
504+
def _comment_pr_do_revert(
438505
self, commit_sha: str, sources: List[SignalMetadata], ctx: RunContext
439506
) -> bool:
440-
"""Comment on the pull request to notify the author about that their PR is breaking signals."""
507+
"""Comment on the pull request to pytorchbot to revert it."""
508+
509+
if ctx.revert_action != RevertAction.RUN_REVERT:
510+
return False
441511

442512
logging.debug(
443-
"[v2][action] revert for sha %s: finding the PR andnotifying author",
513+
"[v2][action] revert for sha %s: finding the PR and notifying author",
444514
commit_sha[:8],
445515
)
446516

@@ -461,35 +531,33 @@ def _comment_pr_notify_revert(
461531
)
462532
return False
463533

464-
# Comment on the PR to notify the author about the revert
465-
comment_body = (
466-
"This PR is breaking the following workflows:\n"
467-
+ "- {}".format("\n- ".join(source.workflow_name for source in sources))
468-
+ "\n\nPlease investigate and fix the issues."
469-
)
470-
471-
pr.create_issue_comment(comment_body)
472-
logging.warning(
473-
"[v2][action] revert for sha %s: notified author in PR #%d",
474-
commit_sha[:8],
475-
pr.number,
476-
)
534+
# TODO Add autorevert cause for pytorchbot OR decide if we need to use
535+
# other causes like weird
477536

478-
if ctx.revert_action == RevertAction.RUN_REVERT:
479-
# TODO Add autorevert cause for pytorchbot OR decide if we need to use
480-
# other causes like weird
481-
482-
# TODO check if the tag `autorevert:disable` is present and don't do the revert
483-
# comment, instead limiting to poke the author
484-
comment_body = (
485-
"XXXX revert -m \"Reverted automatically by pytorch's autorevert, "
486-
+ 'to avoid this behaviour add the tag autorevert:disable" -c autorevert'
537+
# TODO check if the tag `autorevert:disable` is present and don't do the revert
538+
# comment, instead limiting to poke the author
539+
try:
540+
pr.create_issue_comment(
541+
"@pytorchbot revert -m \"Reverted automatically by pytorch's autorevert, "
542+
+ 'to avoid this behaviour add the tag autorevert:disable" -c autorevert\n'
543+
+ "\n"
544+
+ "This PR is breaking the following workflows:\n"
545+
+ "- {}".format("\n- ".join(source.workflow_name for source in sources))
546+
+ "\n\nPlease investigate and fix the issues."
487547
)
488-
pr.create_issue_comment(comment_body)
489-
logging.warning(
490-
"[v2][action] revert for sha %s: requested pytorchbot revert in PR #%d",
548+
except Exception as e:
549+
logging.error( # noqa: G200
550+
"[v2][action] revert for sha %s: failed to comment on PR #%d: %s",
491551
commit_sha[:8],
492552
pr.number,
553+
str(e),
493554
)
555+
return False
556+
557+
logging.warning(
558+
"[v2][action] revert for sha %s: requested pytorchbot revert in PR #%d",
559+
commit_sha[:8],
560+
pr.number,
561+
)
494562

495563
return True

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
@dataclass(frozen=True)
3333
class RunContext:
3434
lookback_hours: int
35+
notify_issue_number: int
3536
repo_full_name: str
3637
restart_action: RestartAction
3738
revert_action: RevertAction

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/testers/autorevert_v2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
def autorevert_v2(
1414
workflows: Iterable[str],
1515
*,
16+
notify_issue_number: int,
1617
hours: int = 24,
1718
repo_full_name: str = "pytorch/pytorch",
1819
restart_action: RestartAction = RestartAction.RUN,
@@ -32,12 +33,13 @@ def autorevert_v2(
3233
ts = datetime.now(timezone.utc)
3334

3435
logging.info(
35-
"[v2] Start: workflows=%s hours=%s repo=%s restart_action=%s revert_action=%s",
36+
"[v2] Start: workflows=%s hours=%s repo=%s restart_action=%s revert_action=%s notify_issue_number=%s",
3637
",".join(workflows),
3738
hours,
3839
repo_full_name,
3940
restart_action,
4041
revert_action,
42+
notify_issue_number,
4143
)
4244
logging.info("[v2] Run timestamp (CH log ts) = %s", ts.isoformat())
4345

@@ -59,6 +61,7 @@ def autorevert_v2(
5961
# Build run context
6062
run_ctx = RunContext(
6163
lookback_hours=hours,
64+
notify_issue_number=notify_issue_number,
6265
repo_full_name=repo_full_name,
6366
restart_action=restart_action,
6467
revert_action=revert_action,

0 commit comments

Comments
 (0)