Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -111,6 +109,18 @@ void shouldOverrideGrpcPortIfExists() {
verify(builder).withPropertyOverride(Properties.GRPC_PORT, String.valueOf(grpcPort));
}

@Test
@DisplayName("Should override API token if it exists")
void shouldOverrideApiTokenIfExists() {
String apiToken = "token";

when(connectionDetails.getApiToken()).thenReturn(apiToken);

configuration.daprClientBuilder(connectionDetails);

verify(builder).withPropertyOverride(Properties.API_TOKEN, apiToken);
}

@Test
@DisplayName("Should override HTTP endpoint in properties if it exists")
void shouldOverrideHttpEndpointInPropertiesIfExists() {
Expand Down Expand Up @@ -159,6 +169,18 @@ void shouldOverrideGrpcPortPropertiesIfExists() {
assertThat(result.getValue(Properties.GRPC_PORT)).isEqualTo(grpcPort);
}

@Test
@DisplayName("Should override API token in properties if it exists")
void shouldOverrideApiTokenPropertiesIfExists() {
String apiToken = "token";

when(connectionDetails.getApiToken()).thenReturn(apiToken);

Properties result = configuration.createPropertiesFromConnectionDetails(connectionDetails);

assertThat(result.getValue(Properties.API_TOKEN)).isEqualTo(apiToken);
}

private static class TestDaprClientAutoConfiguration extends DaprClientAutoConfiguration {

private final DaprClientBuilder daprClientBuilder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@

package io.dapr.workflows;

import org.slf4j.Logger;

public interface WorkflowActivityContext {

Logger getLogger();

String getName();

String getTaskExecutionId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@

import io.dapr.durabletask.TaskActivityContext;
import io.dapr.workflows.WorkflowActivityContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Wrapper for Durable Task Framework {@link TaskActivityContext}.
*/
class DefaultWorkflowActivityContext implements WorkflowActivityContext {
private final TaskActivityContext innerContext;
private final Logger logger;

/**
* Constructor for WorkflowActivityContext.
Expand All @@ -29,10 +32,36 @@ class DefaultWorkflowActivityContext implements WorkflowActivityContext {
* @throws IllegalArgumentException if context is null
*/
public DefaultWorkflowActivityContext(TaskActivityContext context) throws IllegalArgumentException {
this(context, LoggerFactory.getLogger(WorkflowActivityContext.class));
}

/**
* Constructor for WorkflowActivityContext.
*
* @param context TaskActivityContext
* @throws IllegalArgumentException if context is null
*/
public DefaultWorkflowActivityContext(TaskActivityContext context, Logger logger) throws IllegalArgumentException {
if (context == null) {
throw new IllegalArgumentException("Context cannot be null");
}

if (logger == null) {
throw new IllegalArgumentException("Logger cannot be null");
}

this.innerContext = context;
this.logger = logger;
}

/**
* Gets the logger for the current activity.
*
* @return the logger for the current activity
*/
@Override
public Logger getLogger() {
return this.logger;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public DefaultWorkflowContext(TaskOrchestrationContext context, Logger logger)
if (context == null) {
throw new IllegalArgumentException("Context cannot be null");
}

if (logger == null) {
throw new IllegalArgumentException("Logger cannot be null");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.dapr.workflows.runtime;

import io.dapr.durabletask.TaskActivityContext;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class DefaultWorkflowActivityContextTest {

@Test
@DisplayName("Should successfully create context and return correct values for all methods")
void shouldSuccessfullyCreateContextAndReturnCorrectValuesForAllMethods() {
TaskActivityContext mockInnerContext = mock(TaskActivityContext.class);
DefaultWorkflowActivityContext context = new DefaultWorkflowActivityContext(mockInnerContext);

when(mockInnerContext.getName()).thenReturn("TestActivity");
when(mockInnerContext.getInput(any())).thenReturn("TestInput");
when(mockInnerContext.getTaskExecutionId()).thenReturn("TestExecutionId");

assertNotNull(context.getLogger());
assertEquals("TestActivity", context.getName());

String input = context.getInput(String.class);

assertEquals("TestInput", input);
assertEquals("TestExecutionId", context.getTaskExecutionId());
}

@Test
@DisplayName("Should throw IllegalArgumentException when context parameter is null")
void shouldThrowIllegalArgumentExceptionWhenContextParameterIsNull() {
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> {
new DefaultWorkflowActivityContext(null);
});
assertEquals("Context cannot be null", exception.getMessage());
}

@Test
@DisplayName("Should throw IllegalArgumentException when logger parameter is null")
void shouldThrowIllegalArgumentExceptionWhenLoggerParameterIsNull() {
TaskActivityContext mockInnerContext = mock(TaskActivityContext.class);

IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> {
new DefaultWorkflowActivityContext(mockInnerContext, null);
});
assertEquals("Logger cannot be null", exception.getMessage());
}
}