Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,50 +18,40 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.cloudwatchevents.CloudWatchEventsClient;
import software.amazon.awssdk.services.ecs.EcsClient;
import software.amazon.awssdk.services.emr.EmrClient;
import software.amazon.awssdk.services.emr.model.ListClustersResponse;
import software.amazon.awssdk.services.emrserverless.EmrServerlessClient;

import sleeper.clients.deploy.PauseSystem;
import sleeper.clients.util.EmrUtils;
import sleeper.core.properties.SleeperProperties;
import sleeper.core.properties.SleeperProperty;
import sleeper.core.properties.instance.InstanceProperties;
import sleeper.core.util.StaticRateLimit;
import sleeper.core.util.ThreadSleep;

import java.util.List;
import java.util.function.Consumer;

import static sleeper.core.properties.instance.CdkDefinedInstanceProperty.BULK_EXPORT_CLUSTER;
import static sleeper.core.properties.instance.CdkDefinedInstanceProperty.COMPACTION_CLUSTER;
import static sleeper.core.properties.instance.CdkDefinedInstanceProperty.INGEST_CLUSTER;
import static sleeper.core.properties.instance.CommonProperty.ID;
import static sleeper.core.util.RateLimitUtils.sleepForSustainedRatePerSecond;

public class ShutdownSystemProcesses {

private static final Logger LOGGER = LoggerFactory.getLogger(ShutdownSystemProcesses.class);

private final CloudWatchEventsClient cloudWatch;
private final EcsClient ecs;
private final EmrClient emrClient;
private final EmrServerlessClient emrServerlessClient;
private final StaticRateLimit<ListClustersResponse> listActiveClustersLimit;
private final ThreadSleep threadSleep;

public ShutdownSystemProcesses(TearDownClients clients) {
this(clients.getCloudWatch(), clients.getEcs(), clients.getEmr(), clients.getEmrServerless(), EmrUtils.LIST_ACTIVE_CLUSTERS_LIMIT, Thread::sleep);
this(clients.getCloudWatch(), clients.getEmr(), clients.getEmrServerless(), EmrUtils.LIST_ACTIVE_CLUSTERS_LIMIT, Thread::sleep);
}

public ShutdownSystemProcesses(
CloudWatchEventsClient cloudWatch, EcsClient ecs,
CloudWatchEventsClient cloudWatch,
EmrClient emrClient, EmrServerlessClient emrServerlessClient,
StaticRateLimit<ListClustersResponse> listActiveClustersLimit,
ThreadSleep threadSleep) {
this.cloudWatch = cloudWatch;
this.ecs = ecs;
this.emrClient = emrClient;
this.emrServerlessClient = emrServerlessClient;
this.listActiveClustersLimit = listActiveClustersLimit;
Expand All @@ -72,18 +62,10 @@ public void shutdown(InstanceProperties instanceProperties, List<String> extraEC
LOGGER.info("Shutting down system processes for instance {}", instanceProperties.get(ID));
LOGGER.info("Pausing the system");
PauseSystem.pause(cloudWatch, instanceProperties);
stopECSTasks(instanceProperties, extraECSClusters);
stopEMRClusters(instanceProperties);
stopEMRServerlessApplication(instanceProperties);
}

/**
 * Stops the tasks on the ECS clusters referenced by the instance properties, followed by the
 * tasks on each additionally named cluster.
 *
 * @param instanceProperties properties read for the ingest, compaction and bulk export cluster names
 * @param extraClusters      names of further ECS clusters whose tasks are stopped afterwards
 */
private void stopECSTasks(InstanceProperties instanceProperties, List<String> extraClusters) {
    // Clusters named by instance properties; each call is a no-op when its property is unset.
    stopTasks(ecs, instanceProperties, INGEST_CLUSTER);
    stopTasks(ecs, instanceProperties, COMPACTION_CLUSTER);
    stopTasks(ecs, instanceProperties, BULK_EXPORT_CLUSTER);

    // Clusters supplied directly by name.
    for (String extraCluster : extraClusters) {
        stopTasks(ecs, extraCluster);
    }
}

/**
 * Builds a {@link TerminateEMRClusters} for the instance ID held in the given properties and runs it.
 *
 * @param  properties           properties the instance ID is read from
 * @throws InterruptedException propagated from running the termination
 */
private void stopEMRClusters(InstanceProperties properties) throws InterruptedException {
    TerminateEMRClusters terminateClusters = new TerminateEMRClusters(
            emrClient, properties.get(ID), listActiveClustersLimit, threadSleep);
    terminateClusters.run();
}
Expand All @@ -92,29 +74,4 @@ private void stopEMRServerlessApplication(InstanceProperties properties) throws
new TerminateEMRServerlessApplications(emrServerlessClient, properties).run();
}

/**
 * Stops the tasks on the ECS cluster whose name is stored under the given property.
 * Does nothing when the property is not set.
 *
 * @param <T>        the type of property held by the properties object
 * @param ecs        the ECS client used to list and stop tasks
 * @param properties the properties to read the cluster name from
 * @param property   the property holding the cluster name
 */
public static <T extends SleeperProperty> void stopTasks(EcsClient ecs, SleeperProperties<T> properties, T property) {
    if (properties.isSet(property)) {
        stopTasks(ecs, properties.get(property));
    }
}

/**
 * Stops every task found on the named ECS cluster, pausing before each stop request.
 *
 * @param ecs         the ECS client used to list and stop tasks
 * @param clusterName the name of the ECS cluster
 */
private static void stopTasks(EcsClient ecs, String clusterName) {
    LOGGER.info("Stopping tasks for ECS cluster {}", clusterName);
    Consumer<String> stopOneTask = taskArn -> {
        // Rate limit for ECS StopTask is 100 burst, 40 sustained:
        // https://docs.aws.amazon.com/AmazonECS/latest/APIReference/request-throttling.html
        sleepForSustainedRatePerSecond(30);
        ecs.stopTask(request -> request
                .cluster(clusterName)
                .task(taskArn)
                .reason("Cleaning up before cdk destroy"));
    };
    forEachTaskArn(ecs, clusterName, stopOneTask);
}

/**
 * Passes every task ARN listed for the named ECS cluster to the consumer, one page of
 * results at a time, logging how many tasks each page contained.
 *
 * @param ecs         the ECS client used to list tasks
 * @param clusterName the name of the ECS cluster
 * @param consumer    receives each task ARN in the order it is listed
 */
private static void forEachTaskArn(EcsClient ecs, String clusterName, Consumer<String> consumer) {
    ecs.listTasksPaginator(request -> request.cluster(clusterName))
            .stream()
            .forEach(page -> {
                // Log the page size before handing over that page's ARNs.
                LOGGER.info("Found {} tasks", page.taskArns().size());
                page.taskArns().forEach(consumer);
            });
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import software.amazon.awssdk.services.cloudformation.CloudFormationClient;
import software.amazon.awssdk.services.cloudwatchevents.CloudWatchEventsClient;
import software.amazon.awssdk.services.ecr.EcrClient;
import software.amazon.awssdk.services.ecs.EcsClient;
import software.amazon.awssdk.services.emr.EmrClient;
import software.amazon.awssdk.services.emrserverless.EmrServerlessClient;
import software.amazon.awssdk.services.s3.S3Client;
Expand All @@ -31,7 +30,6 @@ public class TearDownClients {

private final S3Client s3;
private final CloudWatchEventsClient cloudWatch;
private final EcsClient ecs;
private final EcrClient ecr;
private final EmrClient emr;
private final EmrServerlessClient emrServerless;
Expand All @@ -40,7 +38,6 @@ public class TearDownClients {
private TearDownClients(Builder builder) {
s3 = Objects.requireNonNull(builder.s3, "s3v2 must not be null");
cloudWatch = Objects.requireNonNull(builder.cloudWatch, "cloudWatch must not be null");
ecs = Objects.requireNonNull(builder.ecs, "ecs must not be null");
ecr = Objects.requireNonNull(builder.ecr, "ecr must not be null");
emr = Objects.requireNonNull(builder.emr, "emr must not be null");
emrServerless = Objects.requireNonNull(builder.emrServerless, "emrServerless must not be null");
Expand All @@ -51,14 +48,12 @@ public static void withDefaults(TearDownOperation operation) throws IOException,
try (S3Client s3Client = S3Client.create();
CloudWatchEventsClient cloudWatchClient = CloudWatchEventsClient.create();
EcrClient ecrClient = EcrClient.create();
EcsClient ecsClient = EcsClient.create();
EmrClient emrClient = EmrClient.create();
EmrServerlessClient emrServerless = EmrServerlessClient.create();
CloudFormationClient cloudFormationClient = CloudFormationClient.create()) {
TearDownClients clients = builder()
.s3(s3Client)
.cloudWatch(cloudWatchClient)
.ecs(ecsClient)
.ecr(ecrClient)
.emr(emrClient)
.emrServerless(emrServerless)
Expand All @@ -80,10 +75,6 @@ public CloudWatchEventsClient getCloudWatch() {
return cloudWatch;
}

/**
 * Returns the ECS client held by this object.
 *
 * @return the ECS client
 */
public EcsClient getEcs() {
    return this.ecs;
}

/**
 * Returns the ECR client held by this object.
 *
 * @return the ECR client
 */
public EcrClient getEcr() {
    return this.ecr;
}
Expand All @@ -103,7 +94,6 @@ public CloudFormationClient getCloudFormation() {
public static final class Builder {
private S3Client s3;
private CloudWatchEventsClient cloudWatch;
private EcsClient ecs;
private EcrClient ecr;
private EmrClient emr;
private EmrServerlessClient emrServerless;
Expand All @@ -122,11 +112,6 @@ public Builder cloudWatch(CloudWatchEventsClient cloudWatch) {
return this;
}

/**
 * Sets the ECS client on this builder.
 *
 * @param  ecs the ECS client
 * @return     this builder, to allow chaining
 */
public Builder ecs(EcsClient ecs) {
this.ecs = ecs;
return this;
}

public Builder ecr(EcrClient ecr) {
this.ecr = ecr;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,15 @@
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.anyRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
import static com.github.tomakehurst.wiremock.stubbing.Scenario.STARTED;
import static sleeper.clients.testutil.ClientWiremockTestHelper.OPERATION_HEADER;
import static sleeper.clients.testutil.ClientWiremockTestHelper.wiremockCloudWatchClient;
import static sleeper.clients.testutil.ClientWiremockTestHelper.wiremockEcsClient;
import static sleeper.clients.testutil.ClientWiremockTestHelper.wiremockEmrClient;
import static sleeper.clients.testutil.ClientWiremockTestHelper.wiremockEmrServerlessClient;
import static sleeper.clients.testutil.WiremockCloudWatchTestHelper.anyRequestedForCloudWatchEvents;
import static sleeper.clients.testutil.WiremockCloudWatchTestHelper.disableRuleRequest;
import static sleeper.clients.testutil.WiremockCloudWatchTestHelper.disableRuleRequestedFor;
import static sleeper.clients.testutil.WiremockEcsTestHelper.MATCHING_LIST_TASKS_OPERATION;
import static sleeper.clients.testutil.WiremockEcsTestHelper.MATCHING_STOP_TASK_OPERATION;
import static sleeper.clients.testutil.WiremockEcsTestHelper.anyRequestedForEcs;
import static sleeper.clients.testutil.WiremockEcsTestHelper.listTasksRequestedFor;
import static sleeper.clients.testutil.WiremockEcsTestHelper.stopTaskRequestedFor;
import static sleeper.clients.testutil.WiremockEmrServerlessTestHelper.aResponseWithApplicationWithNameAndState;
import static sleeper.clients.testutil.WiremockEmrServerlessTestHelper.aResponseWithApplicationWithState;
import static sleeper.clients.testutil.WiremockEmrServerlessTestHelper.aResponseWithJobRunWithState;
Expand All @@ -75,12 +67,10 @@
import static sleeper.clients.testutil.WiremockEmrTestHelper.terminateJobFlowsRequestWithJobIdCount;
import static sleeper.clients.testutil.WiremockEmrTestHelper.terminateJobFlowsRequestedFor;
import static sleeper.clients.testutil.WiremockEmrTestHelper.terminateJobFlowsRequestedWithJobIdsCount;
import static sleeper.core.properties.instance.CdkDefinedInstanceProperty.COMPACTION_CLUSTER;
import static sleeper.core.properties.instance.CdkDefinedInstanceProperty.COMPACTION_JOB_CREATION_CLOUDWATCH_RULE;
import static sleeper.core.properties.instance.CdkDefinedInstanceProperty.COMPACTION_TASK_CREATION_CLOUDWATCH_RULE;
import static sleeper.core.properties.instance.CdkDefinedInstanceProperty.GARBAGE_COLLECTOR_CLOUDWATCH_RULE;
import static sleeper.core.properties.instance.CdkDefinedInstanceProperty.INGEST_CLOUDWATCH_RULE;
import static sleeper.core.properties.instance.CdkDefinedInstanceProperty.INGEST_CLUSTER;
import static sleeper.core.properties.instance.CdkDefinedInstanceProperty.PARTITION_SPLITTING_CLOUDWATCH_RULE;
import static sleeper.core.properties.instance.CdkDefinedInstanceProperty.TABLE_METRICS_RULE;
import static sleeper.core.properties.instance.CommonProperty.ID;
Expand All @@ -95,7 +85,7 @@ class ShutdownSystemProcessesIT {

@BeforeEach
void setUp(WireMockRuntimeInfo runtimeInfo) {
shutdown = new ShutdownSystemProcesses(wiremockCloudWatchClient(runtimeInfo), wiremockEcsClient(runtimeInfo),
shutdown = new ShutdownSystemProcesses(wiremockCloudWatchClient(runtimeInfo),
wiremockEmrClient(runtimeInfo), wiremockEmrServerlessClient(runtimeInfo), StaticRateLimit.none(), noWaits());
}

Expand Down Expand Up @@ -163,59 +153,6 @@ void shouldShutdownCloudWatchRulesWhenSet() throws Exception {
}
}

@Nested
@DisplayName("Terminate running ECS tasks")
class TerminateECSTasks {

    @BeforeEach
    void setup() {
        // Every test here has an ingest cluster set; the EMR and EMR Serverless listings are
        // stubbed with the "no clusters" / "no applications" responses.
        properties.set(INGEST_CLUSTER, "test-ingest-cluster");
        stubFor(listActiveEmrClustersRequest().willReturn(aResponseWithNoClusters()));
        stubFor(listActiveEmrApplicationsRequest().willReturn(aResponseWithNoApplications()));
    }

    @Test
    void shouldLookForECSTasksWhenClustersSet() throws Exception {
        // Given
        properties.set(COMPACTION_CLUSTER, "test-compaction-cluster");
        stubListTasksResponse("{\"nextToken\":null,\"taskArns\":[]}");

        // When
        shutdownWithExtraEcsClusters(List.of("test-system-test-cluster"));

        // Then
        verify(3, anyRequestedForEcs());
        verify(1, listTasksRequestedFor("test-ingest-cluster"));
        verify(1, listTasksRequestedFor("test-compaction-cluster"));
        verify(1, listTasksRequestedFor("test-system-test-cluster"));
    }

    @Test
    void shouldStopECSTaskWhenOneIsFound() throws Exception {
        // Given
        stubListTasksResponse("{\"nextToken\":null,\"taskArns\":[\"test-task\"]}");
        stubFor(post("/")
                .withHeader(OPERATION_HEADER, MATCHING_STOP_TASK_OPERATION)
                .willReturn(aResponse().withStatus(200)));

        // When
        shutdown();

        // Then
        verify(2, anyRequestedForEcs());
        verify(1, listTasksRequestedFor("test-ingest-cluster"));
        verify(1, stopTaskRequestedFor("test-ingest-cluster", "test-task"));
    }

    /**
     * Stubs the ECS list tasks operation to answer with status 200 and the given body.
     *
     * @param responseBody the JSON body to return
     */
    private void stubListTasksResponse(String responseBody) {
        stubFor(post("/")
                .withHeader(OPERATION_HEADER, MATCHING_LIST_TASKS_OPERATION)
                .willReturn(aResponse().withStatus(200).withBody(responseBody)));
    }
}

@Nested
@DisplayName("Terminate running EMR clusters")
class TerminateEMRClusters {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ public static void tearDown(
instance.waitForStackToDelete();
}
for (TearDownSystemTestDeployment deployment : tearDownSystemTestDeployments) {
deployment.shutdownSystemProcesses();
deployment.deleteStack();
}
for (TearDownInstance instance : tearDownStandaloneInstances) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import sleeper.clients.teardown.RemoveECRRepositories;
import sleeper.clients.teardown.RemoveJarsBucket;
import sleeper.clients.teardown.ShutdownSystemProcesses;
import sleeper.clients.teardown.TearDownClients;
import sleeper.clients.teardown.WaitForStackToDelete;
import sleeper.core.deploy.PopulateInstanceProperties;
Expand All @@ -31,7 +30,6 @@
import java.io.IOException;
import java.util.List;

import static sleeper.systemtest.configuration.SystemTestProperty.SYSTEM_TEST_CLUSTER_NAME;
import static sleeper.systemtest.configuration.SystemTestProperty.SYSTEM_TEST_ID;
import static sleeper.systemtest.configuration.SystemTestProperty.SYSTEM_TEST_JARS_BUCKET;
import static sleeper.systemtest.configuration.SystemTestProperty.SYSTEM_TEST_REPO;
Expand Down Expand Up @@ -65,10 +63,6 @@ public void waitForStackToDelete() throws InterruptedException {
WaitForStackToDelete.from(clients.getCloudFormation(), properties.get(SYSTEM_TEST_ID)).pollUntilFinished();
}

/**
 * Delegates to {@code ShutdownSystemProcesses.stopTasks}, passing this deployment's ECS client,
 * its properties and the {@code SYSTEM_TEST_CLUSTER_NAME} property.
 */
public void shutdownSystemProcesses() throws InterruptedException {
ShutdownSystemProcesses.stopTasks(clients.getEcs(), properties, SYSTEM_TEST_CLUSTER_NAME);
}

public void cleanupAfterAllInstancesAndStackDeleted() throws InterruptedException, IOException {
LOGGER.info("Removing the Jars bucket and docker containers");
RemoveJarsBucket.remove(clients.getS3(), properties.get(SYSTEM_TEST_JARS_BUCKET));
Expand Down