From 79c66c3727c2557cf3a25c4ad350dd37c47c6027 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Wed, 22 Sep 2021 17:21:50 -0700 Subject: [PATCH 1/5] Add retry to Retriers when passed to constructor and add catch to Catchers when passed to constructor --- src/stepfunctions/steps/states.py | 42 +++++++++++++++----- tests/unit/test_steps.py | 64 ++++++++++++++++++++++++++++++- 2 files changed, 96 insertions(+), 10 deletions(-) diff --git a/src/stepfunctions/steps/states.py b/src/stepfunctions/steps/states.py index 9669a73..822819b 100644 --- a/src/stepfunctions/steps/states.py +++ b/src/stepfunctions/steps/states.py @@ -254,25 +254,25 @@ def accept(self, visitor): def add_retry(self, retry): """ - Add a Retry block to the tail end of the list of retriers for the state. + Add a Retry block or a list of Retry blocks to the tail end of the list of retriers for the state. Args: - retry (Retry): Retry block to add. + retry (Retry or list(Retry)): Retry block(s) to add. """ if Field.Retry in self.allowed_fields(): - self.retries.append(retry) + self.retries.extend(retry) if isinstance(retry, list) else self.retries.append(retry) else: raise ValueError("{state_type} state does not support retry field. ".format(state_type=type(self).__name__)) def add_catch(self, catch): """ - Add a Catch block to the tail end of the list of catchers for the state. + Add a Catch block or a list of Catch blocks to the tail end of the list of catchers for the state. Args: - catch (Catch): Catch block to add. + catch (Catch or list(Catch): Catch block(s) to add. """ if Field.Catch in self.allowed_fields(): - self.catches.append(catch) + self.catches.extend(catch) if isinstance(catch, list) else self.catches.append(catch) else: raise ValueError("{state_type} state does not support catch field. ".format(state_type=type(self).__name__)) @@ -487,10 +487,12 @@ class Parallel(State): A Parallel state causes the interpreter to execute each branch as concurrently as possible, and wait until each branch terminates (reaches a terminal state) before processing the next state in the Chain. """ - def __init__(self, state_id, **kwargs): + def __init__(self, state_id, retry=None, catch=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. + retry (Retry or list(Retry), optional): Retry block(s) to add to list of Retriers that define a retry policy in case the state encounters runtime errors + catch (Catch or list(Catch), optional): Catch block(s) to add to list of Catchers that define a fallback state that is executed if the state encounters runtime errors and its retry policy is exhausted or isn't defined comment (str, optional): Human-readable comment or description. (default: None) input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$') parameters (dict, optional): The value of this field becomes the effective input for the state. @@ -500,6 +502,12 @@ def __init__(self, state_id, **kwargs): super(Parallel, self).__init__(state_id, 'Parallel', **kwargs) self.branches = [] + if retry: + self.add_retry(retry) + + if catch: + self.add_catch(catch) + def allowed_fields(self): return [ Field.Comment, @@ -536,11 +544,13 @@ class Map(State): A Map state can accept an input with a list of items, execute a state or chain for each item in the list, and return a list, with all corresponding results of each execution, as its output. """ - def __init__(self, state_id, **kwargs): + def __init__(self, state_id, retry=None, catch=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. iterator (State or Chain): State or chain to execute for each of the items in `items_path`. + retry (Retry or list(Retry), optional): Retry block(s) to add to list of Retriers that define a retry policy in case the state encounters runtime errors + catch (Catch or list(Catch), optional): Catch block(s) to add to list of Catchers that define a fallback state that is executed if the state encounters runtime errors and its retry policy is exhausted or isn't defined items_path (str, optional): Path in the input for items to iterate over. (default: '$') max_concurrency (int, optional): Maximum number of iterations to have running at any given point in time. (default: 0) comment (str, optional): Human-readable comment or description. (default: None) @@ -551,6 +561,12 @@ def __init__(self, state_id, **kwargs): """ super(Map, self).__init__(state_id, 'Map', **kwargs) + if retry: + self.add_retry(retry) + + if catch: + self.add_catch(catch) + def attach_iterator(self, iterator): """ Attach `State` or `Chain` as iterator to the Map state, that will execute for each of the items in `items_path`. If an iterator was attached previously with the Map state, it will be replaced. @@ -586,10 +602,12 @@ class Task(State): Task State causes the interpreter to execute the work identified by the state’s `resource` field. """ - def __init__(self, state_id, **kwargs): + def __init__(self, state_id, retry=None, catch=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. + retry (Retry or list(Retry), optional): Retry block(s) to add to list of Retriers that define a retry policy in case the state encounters runtime errors + catch (Catch or list(Catch), optional): Catch block(s) to add to list of Catchers that define a fallback state that is executed if the state encounters runtime errors and its retry policy is exhausted or isn't defined resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI. timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. @@ -608,6 +626,12 @@ def __init__(self, state_id, **kwargs): if self.heartbeat_seconds is not None and self.heartbeat_seconds_path is not None: raise ValueError("Only one of 'heartbeat_seconds' or 'heartbeat_seconds_path' can be provided.") + if retry: + self.add_retry(retry) + + if catch: + self.add_catch(catch) + def allowed_fields(self): return [ Field.Comment, diff --git a/tests/unit/test_steps.py b/tests/unit/test_steps.py index 5c86279..833bf8f 100644 --- a/tests/unit/test_steps.py +++ b/tests/unit/test_steps.py @@ -469,4 +469,66 @@ def test_default_paths_not_converted_to_null(): assert '"OutputPath": null' not in task_state.to_json() - +RETRY = Retry(error_equals=['ErrorA', 'ErrorB'], interval_seconds=1, max_attempts=2, backoff_rate=2) +RETRIES = [RETRY, Retry(error_equals=['ErrorC'], interval_seconds=5)] +EXPECTED_RETRY = [{'ErrorEquals': ['ErrorA', 'ErrorB'], 'IntervalSeconds': 1, 'BackoffRate': 2, 'MaxAttempts': 2}] +EXPECTED_RETRIES = EXPECTED_RETRY + [{'ErrorEquals': ['ErrorC'], 'IntervalSeconds': 5}] + +CATCH = Catch(error_equals=['States.ALL'], next_step=Pass('End State')) +CATCHES = [CATCH, Catch(error_equals=['States.TaskFailed'], next_step=Pass('Next State'))] +EXPECTED_CATCH = [{'ErrorEquals': ['States.ALL'], 'Next': 'End State'}] +EXPECTED_CATCHES = EXPECTED_CATCH + [{'ErrorEquals': ['States.TaskFailed'], 'Next': 'Next State'}] + + +@pytest.mark.parametrize("state, state_id, extra_args, retry, expected_retry", [ + (Parallel, 'Parallel', {}, RETRY, EXPECTED_RETRY), + (Parallel, 'Parallel', {}, RETRIES, EXPECTED_RETRIES), + (Map, 'Map', {'iterator': Pass('Iterator')}, RETRY, EXPECTED_RETRY), + (Map, 'Map', {'iterator': Pass('Iterator')}, RETRIES, EXPECTED_RETRIES), + (Task, 'Task', {}, RETRY, EXPECTED_RETRY), + (Task, 'Task', {}, RETRIES, EXPECTED_RETRIES) +]) +def test_state_creation_with_retry(state, state_id, extra_args, retry, expected_retry): + step = state(state_id, retry=retry, **extra_args) + assert step.to_dict()['Retry'] == expected_retry + + +@pytest.mark.parametrize("state, state_id, extra_args, catch, expected_catch", [ + (Parallel, 'Parallel', {}, CATCH, EXPECTED_CATCH), + (Parallel, 'Parallel', {}, CATCHES, EXPECTED_CATCHES), + (Map, 'Map', {'iterator': Pass('Iterator')}, CATCH, EXPECTED_CATCH), + (Map, 'Map', {'iterator': Pass('Iterator')}, CATCHES, EXPECTED_CATCHES), + (Task, 'Task', {}, CATCH, EXPECTED_CATCH), + (Task, 'Task', {}, CATCHES, EXPECTED_CATCHES) +]) +def test_state_creation_with_catch(state, state_id, extra_args, catch, expected_catch): + step = state(state_id, catch=catch, **extra_args) + assert step.to_dict()['Catch'] == expected_catch + + +@pytest.mark.parametrize("state, state_id, extra_args, retry, expected_retry", [ + (Parallel, 'Parallel', {}, RETRY, EXPECTED_RETRY), + (Parallel, 'Parallel', {}, RETRIES, EXPECTED_RETRIES), + (Map, 'Map', {'iterator': Pass('Iterator')}, RETRIES, EXPECTED_RETRIES), + (Map, 'Map', {'iterator': Pass('Iterator')}, RETRY, EXPECTED_RETRY), + (Task, 'Task', {}, RETRY, EXPECTED_RETRY), + (Task, 'Task', {}, RETRIES, EXPECTED_RETRIES) +]) +def test_state_with_added_retry(state, state_id, extra_args, retry, expected_retry): + step = state(state_id, **extra_args) + step.add_retry(retry) + assert step.to_dict()['Retry'] == expected_retry + + +@pytest.mark.parametrize("state, state_id, extra_args, catch, expected_catch", [ + (Parallel, 'Parallel', {}, CATCH, EXPECTED_CATCH), + (Parallel, 'Parallel', {}, CATCHES, EXPECTED_CATCHES), + (Map, 'Map', {'iterator': Pass('Iterator')}, CATCH, EXPECTED_CATCH), + (Map, 'Map', {'iterator': Pass('Iterator')}, CATCHES, EXPECTED_CATCHES), + (Task, 'Task', {}, CATCHES, EXPECTED_CATCHES), + (Task, 'Task', {}, CATCHES, EXPECTED_CATCHES) +]) +def test_state_with_added_catch(state, state_id, extra_args, catch, expected_catch): + step = state(state_id, **extra_args) + step.add_catch(catch) + assert step.to_dict()['Catch'] == expected_catch From 7f0ceff62f0ba88ef4e06af748510e069bcb6d5c Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Wed, 22 Sep 2021 17:53:00 -0700 Subject: [PATCH 2/5] Added integ tests --- tests/integ/test_state_machine_definition.py | 107 ++++++++++++++++++- 1 file changed, 106 insertions(+), 1 deletion(-) diff --git a/tests/integ/test_state_machine_definition.py b/tests/integ/test_state_machine_definition.py index d21e59b..01a0f32 100644 --- a/tests/integ/test_state_machine_definition.py +++ b/tests/integ/test_state_machine_definition.py @@ -479,6 +479,59 @@ def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_par workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, catch_state_result) +def test_state_machine_creation_with_catch_in_constructor(sfn_client, sfn_role_arn, training_job_parameters): + catch_state_name = "TaskWithCatchState" + all_fail_error = "States.ALL" + all_error_state_name = "Catch All End" + catch_state_result = "Catch Result" + task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" + + # change the parameters to cause task state to fail + training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" + + asl_state_machine_definition = { + "StartAt": catch_state_name, + "States": { + catch_state_name: { + "Resource": task_resource, + "Parameters": training_job_parameters, + "Type": "Task", + "End": True, + "Catch": [ + { + "ErrorEquals": [ + all_fail_error + ], + "Next": all_error_state_name + } + ] + }, + all_error_state_name: { + "Type": "Pass", + "Result": catch_state_result, + "End": True + } + } + } + task = steps.Task( + catch_state_name, + parameters=training_job_parameters, + resource=task_resource, + catch=steps.Catch( + error_equals=[all_fail_error], + next_step=steps.Pass(all_error_state_name, result=catch_state_result) + ) + ) + + workflow = Workflow( + unique_name_from_base('Test_Catch_In_Constructor_Workflow'), + definition=task, + role=sfn_role_arn + ) + + workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, catch_state_result) + + def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters): retry_state_name = "RetryStateName" all_fail_error = "Starts.ALL" @@ -531,4 +584,56 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par role=sfn_role_arn ) - workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None) \ No newline at end of file + workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None) + + +def test_state_machine_creation_with_retry_in_constructor(sfn_client, sfn_role_arn, training_job_parameters): + retry_state_name = "RetryStateName" + all_fail_error = "Starts.ALL" + interval_seconds = 1 + max_attempts = 2 + backoff_rate = 2 + task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" + + # change the parameters to cause task state to fail + training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" + + asl_state_machine_definition = { + "StartAt": retry_state_name, + "States": { + retry_state_name: { + "Resource": task_resource, + "Parameters": training_job_parameters, + "Type": "Task", + "End": True, + "Retry": [ + { + "ErrorEquals": [all_fail_error], + "IntervalSeconds": interval_seconds, + "MaxAttempts": max_attempts, + "BackoffRate": backoff_rate + } + ] + } + } + } + + task = steps.Task( + retry_state_name, + parameters=training_job_parameters, + resource=task_resource, + retry=steps.Retry( + error_equals=[all_fail_error], + interval_seconds=interval_seconds, + max_attempts=max_attempts, + backoff_rate=backoff_rate + ) + ) + + workflow = Workflow( + unique_name_from_base('Test_Retry_In_Constructor_Workflow'), + definition=task, + role=sfn_role_arn + ) + + workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None) From 30e8da934c7b386129944b1e949b749b5eaa44a8 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Wed, 22 Sep 2021 18:12:19 -0700 Subject: [PATCH 3/5] Use fstring in raised exception --- src/stepfunctions/steps/states.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stepfunctions/steps/states.py b/src/stepfunctions/steps/states.py index 822819b..40c71da 100644 --- a/src/stepfunctions/steps/states.py +++ b/src/stepfunctions/steps/states.py @@ -262,7 +262,7 @@ def add_retry(self, retry): if Field.Retry in self.allowed_fields(): self.retries.extend(retry) if isinstance(retry, list) else self.retries.append(retry) else: - raise ValueError("{state_type} state does not support retry field. ".format(state_type=type(self).__name__)) + raise ValueError(f"{type(self).__name__} state does not support retry field. ") def add_catch(self, catch): """ @@ -274,7 +274,7 @@ def add_catch(self, catch): if Field.Catch in self.allowed_fields(): self.catches.extend(catch) if isinstance(catch, list) else self.catches.append(catch) else: - raise ValueError("{state_type} state does not support catch field. ".format(state_type=type(self).__name__)) + raise ValueError(f"{type(self).__name__} state does not support catch field. ") def to_dict(self): result = super(State, self).to_dict() From abf2656702afe4f3b7c0003287d2522021e967c3 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Wed, 6 Oct 2021 23:07:24 -0700 Subject: [PATCH 4/5] Write unit test for each state and update docstrings to use ASL retrier and catcher terms --- src/stepfunctions/steps/states.py | 22 ++--- tests/unit/test_steps.py | 138 +++++++++++++++++++++--------- 2 files changed, 111 insertions(+), 49 deletions(-) diff --git a/src/stepfunctions/steps/states.py b/src/stepfunctions/steps/states.py index 40c71da..8396e69 100644 --- a/src/stepfunctions/steps/states.py +++ b/src/stepfunctions/steps/states.py @@ -254,10 +254,11 @@ def accept(self, visitor): def add_retry(self, retry): """ - Add a Retry block or a list of Retry blocks to the tail end of the list of retriers for the state. + Add a retrier or a list of retriers to the tail end of the list of retriers for the state. + See `Error handling in Step Functions `_ for more details. Args: - retry (Retry or list(Retry)): Retry block(s) to add. + retry (Retry or list(Retry)): A retrier or list of retriers to add. """ if Field.Retry in self.allowed_fields(): self.retries.extend(retry) if isinstance(retry, list) else self.retries.append(retry) @@ -266,10 +267,11 @@ def add_retry(self, retry): def add_catch(self, catch): """ - Add a Catch block or a list of Catch blocks to the tail end of the list of catchers for the state. + Add a catcher or a list of catchers to the tail end of the list of catchers for the state. + See `Error handling in Step Functions `_ for more details. Args: - catch (Catch or list(Catch): Catch block(s) to add. + catch (Catch or list(Catch): catcher or list of catchers to add. """ if Field.Catch in self.allowed_fields(): self.catches.extend(catch) if isinstance(catch, list) else self.catches.append(catch) @@ -491,8 +493,8 @@ def __init__(self, state_id, retry=None, catch=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. - retry (Retry or list(Retry), optional): Retry block(s) to add to list of Retriers that define a retry policy in case the state encounters runtime errors - catch (Catch or list(Catch), optional): Catch block(s) to add to list of Catchers that define a fallback state that is executed if the state encounters runtime errors and its retry policy is exhausted or isn't defined + retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions `_ for more details. + catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions `_ for more details. comment (str, optional): Human-readable comment or description. (default: None) input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$') parameters (dict, optional): The value of this field becomes the effective input for the state. @@ -549,8 +551,8 @@ def __init__(self, state_id, retry=None, catch=None, **kwargs): Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. iterator (State or Chain): State or chain to execute for each of the items in `items_path`. - retry (Retry or list(Retry), optional): Retry block(s) to add to list of Retriers that define a retry policy in case the state encounters runtime errors - catch (Catch or list(Catch), optional): Catch block(s) to add to list of Catchers that define a fallback state that is executed if the state encounters runtime errors and its retry policy is exhausted or isn't defined + retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions `_ for more details. + catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions `_ for more details. items_path (str, optional): Path in the input for items to iterate over. (default: '$') max_concurrency (int, optional): Maximum number of iterations to have running at any given point in time. (default: 0) comment (str, optional): Human-readable comment or description. (default: None) @@ -606,8 +608,8 @@ def __init__(self, state_id, retry=None, catch=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. - retry (Retry or list(Retry), optional): Retry block(s) to add to list of Retriers that define a retry policy in case the state encounters runtime errors - catch (Catch or list(Catch), optional): Catch block(s) to add to list of Catchers that define a fallback state that is executed if the state encounters runtime errors and its retry policy is exhausted or isn't defined + retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions `_ for more details. + catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions `_ for more details. resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI. timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. diff --git a/tests/unit/test_steps.py b/tests/unit/test_steps.py index 833bf8f..3d34ee8 100644 --- a/tests/unit/test_steps.py +++ b/tests/unit/test_steps.py @@ -474,61 +474,121 @@ def test_default_paths_not_converted_to_null(): EXPECTED_RETRY = [{'ErrorEquals': ['ErrorA', 'ErrorB'], 'IntervalSeconds': 1, 'BackoffRate': 2, 'MaxAttempts': 2}] EXPECTED_RETRIES = EXPECTED_RETRY + [{'ErrorEquals': ['ErrorC'], 'IntervalSeconds': 5}] + +@pytest.mark.parametrize("retry, expected_retry", [ + (RETRY, EXPECTED_RETRY), + (RETRIES, EXPECTED_RETRIES), +]) +def test_parallel_state_constructor_with_retry_adds_retrier_to_retriers(retry, expected_retry): + step = Parallel('Parallel', retry=retry) + assert step.to_dict()['Retry'] == expected_retry + + +@pytest.mark.parametrize("retry, expected_retry", [ + (RETRY, EXPECTED_RETRY), + (RETRIES, EXPECTED_RETRIES), +]) +def test_parallel_state_add_retry_adds_retrier_to_retriers(retry, expected_retry): + step = Parallel('Parallel') + step.add_retry(retry) + assert step.to_dict()['Retry'] == expected_retry + + +@pytest.mark.parametrize("retry, expected_retry", [ + (RETRY, EXPECTED_RETRY), + (RETRIES, EXPECTED_RETRIES), +]) +def test_map_state_constructor_with_retry_adds_retrier_to_retriers(retry, expected_retry): + step = Map('Map', retry=retry, iterator=Pass('Iterator')) + assert step.to_dict()['Retry'] == expected_retry + + +@pytest.mark.parametrize("retry, expected_retry", [ + (RETRIES, EXPECTED_RETRIES), + (RETRY, EXPECTED_RETRY), +]) +def test_map_state_add_retry_adds_retrier_to_retriers(retry, expected_retry): + step = Map('Map', iterator=Pass('Iterator')) + step.add_retry(retry) + assert step.to_dict()['Retry'] == expected_retry + + +@pytest.mark.parametrize("retry, expected_retry", [ + (RETRY, EXPECTED_RETRY), + (RETRIES, EXPECTED_RETRIES) +]) +def test_task_state_constructor_with_retry_adds_retrier_to_retriers(retry, expected_retry): + step = Task('Task', retry=retry) + assert step.to_dict()['Retry'] == expected_retry + + +@pytest.mark.parametrize("retry, expected_retry", [ + (RETRY, EXPECTED_RETRY), + (RETRIES, EXPECTED_RETRIES) +]) +def test_task_state_add_retry_adds_retrier_to_retriers(retry, expected_retry): + step = Task('Task') + step.add_retry(retry) + assert step.to_dict()['Retry'] == expected_retry + + CATCH = Catch(error_equals=['States.ALL'], next_step=Pass('End State')) CATCHES = [CATCH, Catch(error_equals=['States.TaskFailed'], next_step=Pass('Next State'))] EXPECTED_CATCH = [{'ErrorEquals': ['States.ALL'], 'Next': 'End State'}] EXPECTED_CATCHES = EXPECTED_CATCH + [{'ErrorEquals': ['States.TaskFailed'], 'Next': 'Next State'}] -@pytest.mark.parametrize("state, state_id, extra_args, retry, expected_retry", [ - (Parallel, 'Parallel', {}, RETRY, EXPECTED_RETRY), - (Parallel, 'Parallel', {}, RETRIES, EXPECTED_RETRIES), - (Map, 'Map', {'iterator': Pass('Iterator')}, RETRY, EXPECTED_RETRY), - (Map, 'Map', {'iterator': Pass('Iterator')}, RETRIES, EXPECTED_RETRIES), - (Task, 'Task', {}, RETRY, EXPECTED_RETRY), - (Task, 'Task', {}, RETRIES, EXPECTED_RETRIES) +@pytest.mark.parametrize("catch, expected_catch", [ + (CATCH, EXPECTED_CATCH), + (CATCHES, EXPECTED_CATCHES) ]) -def test_state_creation_with_retry(state, state_id, extra_args, retry, expected_retry): - step = state(state_id, retry=retry, **extra_args) - assert step.to_dict()['Retry'] == expected_retry +def test_parallel_state_constructor_with_catch_adds_catcher_to_catchers(catch, expected_catch): + step = Parallel('Parallel', catch=catch) + assert step.to_dict()['Catch'] == expected_catch + +@pytest.mark.parametrize("catch, expected_catch", [ + (CATCH, EXPECTED_CATCH), + (CATCHES, EXPECTED_CATCHES) +]) +def test_parallel_state_add_catch_adds_catcher_to_catchers(catch, expected_catch): + step = Parallel('Parallel') + step.add_catch(catch) + assert step.to_dict()['Catch'] == expected_catch -@pytest.mark.parametrize("state, state_id, extra_args, catch, expected_catch", [ - (Parallel, 'Parallel', {}, CATCH, EXPECTED_CATCH), - (Parallel, 'Parallel', {}, CATCHES, EXPECTED_CATCHES), - (Map, 'Map', {'iterator': Pass('Iterator')}, CATCH, EXPECTED_CATCH), - (Map, 'Map', {'iterator': Pass('Iterator')}, CATCHES, EXPECTED_CATCHES), - (Task, 'Task', {}, CATCH, EXPECTED_CATCH), - (Task, 'Task', {}, CATCHES, EXPECTED_CATCHES) +@pytest.mark.parametrize("catch, expected_catch", [ + (CATCH, EXPECTED_CATCH), + (CATCHES, EXPECTED_CATCHES) ]) -def test_state_creation_with_catch(state, state_id, extra_args, catch, expected_catch): - step = state(state_id, catch=catch, **extra_args) +def test_map_state_constructor_with_catch_adds_catcher_to_catchers(catch, expected_catch): + step = Map('Map', catch=catch, iterator=Pass('Iterator')) assert step.to_dict()['Catch'] == expected_catch -@pytest.mark.parametrize("state, state_id, extra_args, retry, expected_retry", [ - (Parallel, 'Parallel', {}, RETRY, EXPECTED_RETRY), - (Parallel, 'Parallel', {}, RETRIES, EXPECTED_RETRIES), - (Map, 'Map', {'iterator': Pass('Iterator')}, RETRIES, EXPECTED_RETRIES), - (Map, 'Map', {'iterator': Pass('Iterator')}, RETRY, EXPECTED_RETRY), - (Task, 'Task', {}, RETRY, EXPECTED_RETRY), - (Task, 'Task', {}, RETRIES, EXPECTED_RETRIES) +@pytest.mark.parametrize("catch, expected_catch", [ + (CATCH, EXPECTED_CATCH), + (CATCHES, EXPECTED_CATCHES) ]) -def test_state_with_added_retry(state, state_id, extra_args, retry, expected_retry): - step = state(state_id, **extra_args) - step.add_retry(retry) - assert step.to_dict()['Retry'] == expected_retry +def test_map_state_add_catch_adds_catcher_to_catchers(catch, expected_catch): + step = Map('Map', iterator=Pass('Iterator')) + step.add_catch(catch) + assert step.to_dict()['Catch'] == expected_catch + + +@pytest.mark.parametrize("catch, expected_catch", [ + (CATCH, EXPECTED_CATCH), + (CATCHES, EXPECTED_CATCHES) +]) +def test_task_state_constructor_with_catch_adds_catcher_to_catchers(catch, expected_catch): + step = Task('Task', catch=catch) + assert step.to_dict()['Catch'] == expected_catch -@pytest.mark.parametrize("state, state_id, extra_args, catch, expected_catch", [ - (Parallel, 'Parallel', {}, CATCH, EXPECTED_CATCH), - (Parallel, 'Parallel', {}, CATCHES, EXPECTED_CATCHES), - (Map, 'Map', {'iterator': Pass('Iterator')}, CATCH, EXPECTED_CATCH), - (Map, 'Map', {'iterator': Pass('Iterator')}, CATCHES, EXPECTED_CATCHES), - (Task, 'Task', {}, CATCHES, EXPECTED_CATCHES), - (Task, 'Task', {}, CATCHES, EXPECTED_CATCHES) +@pytest.mark.parametrize("catch, expected_catch", [ + (CATCH, EXPECTED_CATCH), + (CATCHES, EXPECTED_CATCHES) ]) -def test_state_with_added_catch(state, state_id, extra_args, catch, expected_catch): - step = state(state_id, **extra_args) +def test_task_state_add_catch_adds_catcher_to_catchers(catch, expected_catch): + step = Task('Task') step.add_catch(catch) assert step.to_dict()['Catch'] == expected_catch From cd1fc9e30bdf891df563536e9ad9ee0af0854742 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Thu, 7 Oct 2021 12:36:21 -0700 Subject: [PATCH 5/5] Removed integ test for retry and catch in constructor and updated existing tests --- tests/integ/test_state_machine_definition.py | 166 +++++-------------- 1 file changed, 44 insertions(+), 122 deletions(-) diff --git a/tests/integ/test_state_machine_definition.py b/tests/integ/test_state_machine_definition.py index 01a0f32..4881b75 100644 --- a/tests/integ/test_state_machine_definition.py +++ b/tests/integ/test_state_machine_definition.py @@ -422,51 +422,29 @@ def test_task_state_machine_creation(sfn_client, sfn_role_arn, training_job_para def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters): catch_state_name = "TaskWithCatchState" - custom_error = "CustomError" task_failed_error = "States.TaskFailed" - all_fail_error = "States.ALL" - custom_error_state_name = "Custom Error End" - task_failed_state_name = "Task Failed End" - all_error_state_name = "Catch All End" + timeout_error = "States.Timeout" + task_failed_state_name = "Catch Task Failed End" + timeout_state_name = "Catch Timeout End" catch_state_result = "Catch Result" task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" - # change the parameters to cause task state to fail + # Provide invalid TrainingImage to cause States.TaskFailed error training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" - asl_state_machine_definition = { - "StartAt": catch_state_name, - "States": { - catch_state_name: { - "Resource": task_resource, - "Parameters": training_job_parameters, - "Type": "Task", - "End": True, - "Catch": [ - { - "ErrorEquals": [ - all_fail_error - ], - "Next": all_error_state_name - } - ] - }, - all_error_state_name: { - "Type": "Pass", - "Result": catch_state_result, - "End": True - } - } - } task = steps.Task( catch_state_name, parameters=training_job_parameters, - resource=task_resource + resource=task_resource, + catch=steps.Catch( + error_equals=[timeout_error], + next_step=steps.Pass(timeout_state_name, result=catch_state_result) + ) ) task.add_catch( steps.Catch( - error_equals=[all_fail_error], - next_step=steps.Pass(all_error_state_name, result=catch_state_result) + error_equals=[task_failed_error], + next_step=steps.Pass(task_failed_state_name, result=catch_state_result) ) ) @@ -476,19 +454,6 @@ def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_par role=sfn_role_arn ) - workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, catch_state_result) - - -def test_state_machine_creation_with_catch_in_constructor(sfn_client, sfn_role_arn, training_job_parameters): - catch_state_name = "TaskWithCatchState" - all_fail_error = "States.ALL" - all_error_state_name = "Catch All End" - catch_state_result = "Catch Result" - task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" - - # change the parameters to cause task state to fail - training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" - asl_state_machine_definition = { "StartAt": catch_state_name, "States": { @@ -500,78 +465,61 @@ def test_state_machine_creation_with_catch_in_constructor(sfn_client, sfn_role_a "Catch": [ { "ErrorEquals": [ - all_fail_error + timeout_error + ], + "Next": timeout_state_name + }, + { + "ErrorEquals": [ + task_failed_error ], - "Next": all_error_state_name + "Next": task_failed_state_name } ] }, - all_error_state_name: { + task_failed_state_name: { "Type": "Pass", "Result": catch_state_result, "End": True - } + }, + timeout_state_name: { + "Type": "Pass", + "Result": catch_state_result, + "End": True + }, } } - task = steps.Task( - catch_state_name, - parameters=training_job_parameters, - resource=task_resource, - catch=steps.Catch( - error_equals=[all_fail_error], - next_step=steps.Pass(all_error_state_name, result=catch_state_result) - ) - ) - - workflow = Workflow( - unique_name_from_base('Test_Catch_In_Constructor_Workflow'), - definition=task, - role=sfn_role_arn - ) workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, catch_state_result) def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters): retry_state_name = "RetryStateName" - all_fail_error = "Starts.ALL" + task_failed_error = "States.TaskFailed" + timeout_error = "States.Timeout" interval_seconds = 1 max_attempts = 2 backoff_rate = 2 task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" - # change the parameters to cause task state to fail + # Provide invalid TrainingImage to cause States.TaskFailed error training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" - asl_state_machine_definition = { - "StartAt": retry_state_name, - "States": { - retry_state_name: { - "Resource": task_resource, - "Parameters": training_job_parameters, - "Type": "Task", - "End": True, - "Retry": [ - { - "ErrorEquals": [all_fail_error], - "IntervalSeconds": interval_seconds, - "MaxAttempts": max_attempts, - "BackoffRate": backoff_rate - } - ] - } - } - } - task = steps.Task( retry_state_name, parameters=training_job_parameters, - resource=task_resource + resource=task_resource, + retry=steps.Retry( + error_equals=[timeout_error], + interval_seconds=interval_seconds, + max_attempts=max_attempts, + backoff_rate=backoff_rate + ) ) task.add_retry( steps.Retry( - error_equals=[all_fail_error], + error_equals=[task_failed_error], interval_seconds=interval_seconds, max_attempts=max_attempts, backoff_rate=backoff_rate @@ -584,20 +532,6 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par role=sfn_role_arn ) - workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None) - - -def test_state_machine_creation_with_retry_in_constructor(sfn_client, sfn_role_arn, training_job_parameters): - retry_state_name = "RetryStateName" - all_fail_error = "Starts.ALL" - interval_seconds = 1 - max_attempts = 2 - backoff_rate = 2 - task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" - - # change the parameters to cause task state to fail - training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" - asl_state_machine_definition = { "StartAt": retry_state_name, "States": { @@ -608,7 +542,13 @@ def test_state_machine_creation_with_retry_in_constructor(sfn_client, sfn_role_a "End": True, "Retry": [ { - "ErrorEquals": [all_fail_error], + "ErrorEquals": [timeout_error], + "IntervalSeconds": interval_seconds, + "MaxAttempts": max_attempts, + "BackoffRate": backoff_rate + }, + { + "ErrorEquals": [task_failed_error], "IntervalSeconds": interval_seconds, "MaxAttempts": max_attempts, "BackoffRate": backoff_rate @@ -618,22 +558,4 @@ def test_state_machine_creation_with_retry_in_constructor(sfn_client, sfn_role_a } } - task = steps.Task( - retry_state_name, - parameters=training_job_parameters, - resource=task_resource, - retry=steps.Retry( - error_equals=[all_fail_error], - interval_seconds=interval_seconds, - max_attempts=max_attempts, - backoff_rate=backoff_rate - ) - ) - - workflow = Workflow( - unique_name_from_base('Test_Retry_In_Constructor_Workflow'), - definition=task, - role=sfn_role_arn - ) - workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None)