diff --git a/src/stepfunctions/steps/states.py b/src/stepfunctions/steps/states.py
index 9669a73..8396e69 100644
--- a/src/stepfunctions/steps/states.py
+++ b/src/stepfunctions/steps/states.py
@@ -254,27 +254,29 @@ def accept(self, visitor):
 
     def add_retry(self, retry):
         """
-        Add a Retry block to the tail end of the list of retriers for the state.
+        Add a retrier or a list of retriers to the tail end of the list of retriers for the state.
+        See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html>`_ for more details.
 
         Args:
-            retry (Retry): Retry block to add.
+            retry (Retry or list(Retry)): A retrier or list of retriers to add.
         """
         if Field.Retry in self.allowed_fields():
-            self.retries.append(retry)
+            self.retries.extend(retry) if isinstance(retry, list) else self.retries.append(retry)
         else:
-            raise ValueError("{state_type} state does not support retry field. ".format(state_type=type(self).__name__))
+            raise ValueError(f"{type(self).__name__} state does not support retry field.")
 
     def add_catch(self, catch):
         """
-        Add a Catch block to the tail end of the list of catchers for the state.
+        Add a catcher or a list of catchers to the tail end of the list of catchers for the state.
+        See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html>`_ for more details.
 
         Args:
-            catch (Catch): Catch block to add.
+            catch (Catch or list(Catch)): A catcher or list of catchers to add.
         """
         if Field.Catch in self.allowed_fields():
-            self.catches.append(catch)
+            self.catches.extend(catch) if isinstance(catch, list) else self.catches.append(catch)
         else:
-            raise ValueError("{state_type} state does not support catch field. ".format(state_type=type(self).__name__))
+            raise ValueError(f"{type(self).__name__} state does not support catch field.")
 
     def to_dict(self):
         result = super(State, self).to_dict()
@@ -487,10 +489,12 @@ class Parallel(State):
     A Parallel state causes the interpreter to execute each branch as concurrently as possible, and wait until each branch terminates (reaches a terminal state) before processing the next state in the Chain.
     """
 
-    def __init__(self, state_id, **kwargs):
+    def __init__(self, state_id, retry=None, catch=None, **kwargs):
         """
         Args:
             state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
+            retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html>`_ for more details.
+            catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html>`_ for more details.
             comment (str, optional): Human-readable comment or description. (default: None)
             input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
             parameters (dict, optional): The value of this field becomes the effective input for the state.
@@ -500,6 +504,12 @@ def __init__(self, state_id, **kwargs):
         super(Parallel, self).__init__(state_id, 'Parallel', **kwargs)
         self.branches = []
 
+        if retry:
+            self.add_retry(retry)
+
+        if catch:
+            self.add_catch(catch)
+
     def allowed_fields(self):
         return [
             Field.Comment,
@@ -536,11 +546,13 @@ class Map(State):
     A Map state can accept an input with a list of items, execute a state or chain for each item in the list, and return a list, with all corresponding results of each execution, as its output.
     """
 
-    def __init__(self, state_id, **kwargs):
+    def __init__(self, state_id, retry=None, catch=None, **kwargs):
         """
         Args:
             state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
             iterator (State or Chain): State or chain to execute for each of the items in `items_path`.
+            retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html>`_ for more details.
+            catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html>`_ for more details.
             items_path (str, optional): Path in the input for items to iterate over. (default: '$')
             max_concurrency (int, optional): Maximum number of iterations to have running at any given point in time. (default: 0)
             comment (str, optional): Human-readable comment or description. (default: None)
@@ -551,6 +563,12 @@
         """
         super(Map, self).__init__(state_id, 'Map', **kwargs)
 
+        if retry:
+            self.add_retry(retry)
+
+        if catch:
+            self.add_catch(catch)
+
     def attach_iterator(self, iterator):
         """
         Attach `State` or `Chain` as iterator to the Map state, that will execute for each of the items in `items_path`. If an iterator was attached previously with the Map state, it will be replaced.
@@ -586,10 +604,12 @@ class Task(State):
     Task State causes the interpreter to execute the work identified by the state’s `resource` field.
     """
 
-    def __init__(self, state_id, **kwargs):
+    def __init__(self, state_id, retry=None, catch=None, **kwargs):
         """
         Args:
             state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
+            retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html>`_ for more details.
+            catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html>`_ for more details.
             resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI.
             timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60)
             timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer.
@@ -608,6 +628,12 @@ def __init__(self, state_id, **kwargs):
         if self.heartbeat_seconds is not None and self.heartbeat_seconds_path is not None:
             raise ValueError("Only one of 'heartbeat_seconds' or 'heartbeat_seconds_path' can be provided.")
 
+        if retry:
+            self.add_retry(retry)
+
+        if catch:
+            self.add_catch(catch)
+
     def allowed_fields(self):
         return [
             Field.Comment,
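Taken together, the states.py changes let callers declare error handling at construction time or append it afterwards. A minimal usage sketch follows; the state names, resource ARN, parameters, and the custom error name are placeholders, while the keyword arguments and the list form of add_retry/add_catch mirror the tests in this diff.

```python
from stepfunctions import steps

# Illustrative only: placeholder state names, resource ARN, parameters and error names.
task = steps.Task(
    'TrainModel',
    resource='arn:aws:states:::sagemaker:createTrainingJob.sync',
    parameters={'TrainingJobName': 'example-job'},
    # New: a single retrier/catcher (or a list of them) can be passed at construction time.
    retry=steps.Retry(
        error_equals=['States.Timeout'],
        interval_seconds=5,
        max_attempts=3,
        backoff_rate=2
    ),
    catch=steps.Catch(
        error_equals=['States.TaskFailed'],
        next_step=steps.Pass('HandleFailure')
    )
)

# New: add_retry/add_catch also accept a list; entries are appended in order
# after anything supplied through the constructor.
task.add_retry([
    steps.Retry(error_equals=['CustomError'], interval_seconds=1, max_attempts=2, backoff_rate=2),
    steps.Retry(error_equals=['States.ALL'], interval_seconds=10),
])
```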
""" - def __init__(self, state_id, **kwargs): + def __init__(self, state_id, retry=None, catch=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. iterator (State or Chain): State or chain to execute for each of the items in `items_path`. + retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions `_ for more details. + catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions `_ for more details. items_path (str, optional): Path in the input for items to iterate over. (default: '$') max_concurrency (int, optional): Maximum number of iterations to have running at any given point in time. (default: 0) comment (str, optional): Human-readable comment or description. (default: None) @@ -551,6 +563,12 @@ def __init__(self, state_id, **kwargs): """ super(Map, self).__init__(state_id, 'Map', **kwargs) + if retry: + self.add_retry(retry) + + if catch: + self.add_catch(catch) + def attach_iterator(self, iterator): """ Attach `State` or `Chain` as iterator to the Map state, that will execute for each of the items in `items_path`. If an iterator was attached previously with the Map state, it will be replaced. @@ -586,10 +604,12 @@ class Task(State): Task State causes the interpreter to execute the work identified by the state’s `resource` field. """ - def __init__(self, state_id, **kwargs): + def __init__(self, state_id, retry=None, catch=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. + retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions `_ for more details. + catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions `_ for more details. resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI. timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer. 
 
 
 def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters):
     retry_state_name = "RetryStateName"
-    all_fail_error = "Starts.ALL"
+    task_failed_error = "States.TaskFailed"
+    timeout_error = "States.Timeout"
     interval_seconds = 1
     max_attempts = 2
     backoff_rate = 2
     task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync"
 
-    # change the parameters to cause task state to fail
+    # Provide invalid TrainingImage to cause States.TaskFailed error
     training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image"
 
-    asl_state_machine_definition = {
-        "StartAt": retry_state_name,
-        "States": {
-            retry_state_name: {
-                "Resource": task_resource,
-                "Parameters": training_job_parameters,
-                "Type": "Task",
-                "End": True,
-                "Retry": [
-                    {
-                        "ErrorEquals": [all_fail_error],
-                        "IntervalSeconds": interval_seconds,
-                        "MaxAttempts": max_attempts,
-                        "BackoffRate": backoff_rate
-                    }
-                ]
-            }
-        }
-    }
-
     task = steps.Task(
         retry_state_name,
         parameters=training_job_parameters,
-        resource=task_resource
+        resource=task_resource,
+        retry=steps.Retry(
+            error_equals=[timeout_error],
+            interval_seconds=interval_seconds,
+            max_attempts=max_attempts,
+            backoff_rate=backoff_rate
+        )
     )
     task.add_retry(
         steps.Retry(
-            error_equals=[all_fail_error],
+            error_equals=[task_failed_error],
             interval_seconds=interval_seconds,
             max_attempts=max_attempts,
             backoff_rate=backoff_rate
@@ -531,4 +532,30 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
         role=sfn_role_arn
     )
 
-    workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None)
\ No newline at end of file
+    asl_state_machine_definition = {
+        "StartAt": retry_state_name,
+        "States": {
+            retry_state_name: {
+                "Resource": task_resource,
+                "Parameters": training_job_parameters,
+                "Type": "Task",
+                "End": True,
+                "Retry": [
+                    {
+                        "ErrorEquals": [timeout_error],
+                        "IntervalSeconds": interval_seconds,
+                        "MaxAttempts": max_attempts,
+                        "BackoffRate": backoff_rate
+                    },
+                    {
+                        "ErrorEquals": [task_failed_error],
+                        "IntervalSeconds": interval_seconds,
+                        "MaxAttempts": max_attempts,
+                        "BackoffRate": backoff_rate
+                    }
+                ]
+            }
+        }
+    }
+
+    workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None)
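For the retry counterpart, a small sketch of how a single Retry serializes into the PascalCase ASL fields used in the expected definition above. The values mirror the test constants, and the to_dict shape matches the unit-test expectations later in this diff; the snippet itself is illustrative only.

```python
from stepfunctions.steps import Retry, Task

# snake_case Retry kwargs map onto the PascalCase ASL fields shown above.
step = Task(
    'RetryStateName',
    retry=Retry(
        error_equals=['States.TaskFailed'],
        interval_seconds=1,
        max_attempts=2,
        backoff_rate=2
    )
)

assert step.to_dict()['Retry'][0] == {
    'ErrorEquals': ['States.TaskFailed'],
    'IntervalSeconds': 1,
    'MaxAttempts': 2,
    'BackoffRate': 2
}
```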
diff --git a/tests/unit/test_steps.py b/tests/unit/test_steps.py
index 5c86279..3d34ee8 100644
--- a/tests/unit/test_steps.py
+++ b/tests/unit/test_steps.py
@@ -469,4 +469,126 @@ def test_default_paths_not_converted_to_null():
 
     assert '"OutputPath": null' not in task_state.to_json()
 
-
+RETRY = Retry(error_equals=['ErrorA', 'ErrorB'], interval_seconds=1, max_attempts=2, backoff_rate=2)
+RETRIES = [RETRY, Retry(error_equals=['ErrorC'], interval_seconds=5)]
+EXPECTED_RETRY = [{'ErrorEquals': ['ErrorA', 'ErrorB'], 'IntervalSeconds': 1, 'BackoffRate': 2, 'MaxAttempts': 2}]
+EXPECTED_RETRIES = EXPECTED_RETRY + [{'ErrorEquals': ['ErrorC'], 'IntervalSeconds': 5}]
+
+
+@pytest.mark.parametrize("retry, expected_retry", [
+    (RETRY, EXPECTED_RETRY),
+    (RETRIES, EXPECTED_RETRIES),
+])
+def test_parallel_state_constructor_with_retry_adds_retrier_to_retriers(retry, expected_retry):
+    step = Parallel('Parallel', retry=retry)
+    assert step.to_dict()['Retry'] == expected_retry
+
+
+@pytest.mark.parametrize("retry, expected_retry", [
+    (RETRY, EXPECTED_RETRY),
+    (RETRIES, EXPECTED_RETRIES),
+])
+def test_parallel_state_add_retry_adds_retrier_to_retriers(retry, expected_retry):
+    step = Parallel('Parallel')
+    step.add_retry(retry)
+    assert step.to_dict()['Retry'] == expected_retry
+
+
+@pytest.mark.parametrize("retry, expected_retry", [
+    (RETRY, EXPECTED_RETRY),
+    (RETRIES, EXPECTED_RETRIES),
+])
+def test_map_state_constructor_with_retry_adds_retrier_to_retriers(retry, expected_retry):
+    step = Map('Map', retry=retry, iterator=Pass('Iterator'))
+    assert step.to_dict()['Retry'] == expected_retry
+
+
+@pytest.mark.parametrize("retry, expected_retry", [
+    (RETRIES, EXPECTED_RETRIES),
+    (RETRY, EXPECTED_RETRY),
+])
+def test_map_state_add_retry_adds_retrier_to_retriers(retry, expected_retry):
+    step = Map('Map', iterator=Pass('Iterator'))
+    step.add_retry(retry)
+    assert step.to_dict()['Retry'] == expected_retry
+
+
+@pytest.mark.parametrize("retry, expected_retry", [
+    (RETRY, EXPECTED_RETRY),
+    (RETRIES, EXPECTED_RETRIES)
+])
+def test_task_state_constructor_with_retry_adds_retrier_to_retriers(retry, expected_retry):
+    step = Task('Task', retry=retry)
+    assert step.to_dict()['Retry'] == expected_retry
+
+
+@pytest.mark.parametrize("retry, expected_retry", [
+    (RETRY, EXPECTED_RETRY),
+    (RETRIES, EXPECTED_RETRIES)
+])
+def test_task_state_add_retry_adds_retrier_to_retriers(retry, expected_retry):
+    step = Task('Task')
+    step.add_retry(retry)
+    assert step.to_dict()['Retry'] == expected_retry
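The matrix above covers the happy path for retriers. The new ValueError branches in add_retry/add_catch are not exercised; a hypothetical companion test (not part of this change, and assuming Pass does not list Field.Retry among its allowed fields) could look like the sketch below.

```python
import pytest
from stepfunctions.steps import Pass, Retry

def test_pass_state_add_retry_raises_value_error():
    # Assumption: Pass does not support the Retry field, so add_retry should
    # hit the ValueError branch added in states.py above.
    step = Pass('Pass')
    with pytest.raises(ValueError):
        step.add_retry(Retry(error_equals=['States.ALL'], interval_seconds=1))
```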
+
+
+CATCH = Catch(error_equals=['States.ALL'], next_step=Pass('End State'))
+CATCHES = [CATCH, Catch(error_equals=['States.TaskFailed'], next_step=Pass('Next State'))]
+EXPECTED_CATCH = [{'ErrorEquals': ['States.ALL'], 'Next': 'End State'}]
+EXPECTED_CATCHES = EXPECTED_CATCH + [{'ErrorEquals': ['States.TaskFailed'], 'Next': 'Next State'}]
+
+
+@pytest.mark.parametrize("catch, expected_catch", [
+    (CATCH, EXPECTED_CATCH),
+    (CATCHES, EXPECTED_CATCHES)
+])
+def test_parallel_state_constructor_with_catch_adds_catcher_to_catchers(catch, expected_catch):
+    step = Parallel('Parallel', catch=catch)
+    assert step.to_dict()['Catch'] == expected_catch
+
+@pytest.mark.parametrize("catch, expected_catch", [
+    (CATCH, EXPECTED_CATCH),
+    (CATCHES, EXPECTED_CATCHES)
+])
+def test_parallel_state_add_catch_adds_catcher_to_catchers(catch, expected_catch):
+    step = Parallel('Parallel')
+    step.add_catch(catch)
+    assert step.to_dict()['Catch'] == expected_catch
+
+
+@pytest.mark.parametrize("catch, expected_catch", [
+    (CATCH, EXPECTED_CATCH),
+    (CATCHES, EXPECTED_CATCHES)
+])
+def test_map_state_constructor_with_catch_adds_catcher_to_catchers(catch, expected_catch):
+    step = Map('Map', catch=catch, iterator=Pass('Iterator'))
+    assert step.to_dict()['Catch'] == expected_catch
+
+
+@pytest.mark.parametrize("catch, expected_catch", [
+    (CATCH, EXPECTED_CATCH),
+    (CATCHES, EXPECTED_CATCHES)
+])
+def test_map_state_add_catch_adds_catcher_to_catchers(catch, expected_catch):
+    step = Map('Map', iterator=Pass('Iterator'))
+    step.add_catch(catch)
+    assert step.to_dict()['Catch'] == expected_catch
+
+
+@pytest.mark.parametrize("catch, expected_catch", [
+    (CATCH, EXPECTED_CATCH),
+    (CATCHES, EXPECTED_CATCHES)
+])
+def test_task_state_constructor_with_catch_adds_catcher_to_catchers(catch, expected_catch):
+    step = Task('Task', catch=catch)
+    assert step.to_dict()['Catch'] == expected_catch
+
+
+@pytest.mark.parametrize("catch, expected_catch", [
+    (CATCH, EXPECTED_CATCH),
+    (CATCHES, EXPECTED_CATCHES)
+])
+def test_task_state_add_catch_adds_catcher_to_catchers(catch, expected_catch):
+    step = Task('Task')
+    step.add_catch(catch)
+    assert step.to_dict()['Catch'] == expected_catch
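The twelve tests above are intentionally explicit, one per state type and entry point. If the duplication ever becomes a maintenance burden, an optional refactor could stack parametrize decorators over the state classes; the sketch below reuses the module-level RETRY/EXPECTED_RETRY constants defined above and is only a suggestion, not part of this change.

```python
import pytest
from stepfunctions.steps import Map, Parallel, Pass, Task

# Per-class extra kwargs: Map needs an iterator, the others do not.
STATE_CLASSES = [
    (Parallel, {}),
    (Map, {'iterator': Pass('Iterator')}),
    (Task, {}),
]

@pytest.mark.parametrize("state_class, extra_kwargs", STATE_CLASSES)
@pytest.mark.parametrize("retry, expected_retry", [
    (RETRY, EXPECTED_RETRY),
    (RETRIES, EXPECTED_RETRIES),
])
def test_state_constructor_with_retry(state_class, extra_kwargs, retry, expected_retry):
    step = state_class('State', retry=retry, **extra_kwargs)
    assert step.to_dict()['Retry'] == expected_retry
```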