Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ jobs:
- name: Install gotestsum
run: go install gotest.tools/gotestsum@latest

- name: Install mockery
run: go install github.com/vektra/mockery/v3@latest

- name: Generate mocks
run: go generate ./...
working-directory: ./dbos

- name: Run tests
run: go vet ./... && gotestsum --format github-action -- -race -v -count=1 ./...
working-directory: ./dbos
Expand Down
4 changes: 0 additions & 4 deletions cmd/dbos/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ var initCmd = &cobra.Command{
RunE: runInit,
}

var (
configOnly bool
)

type templateData struct {
ProjectName string
}
Expand Down
2 changes: 2 additions & 0 deletions dbos/dbos.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ func processConfig(inputConfig *Config) (*Config, error) {
return dbosConfig, nil
}

//go:generate mockery --config=mocks-tests-config.yaml

// DBOSContext represents a DBOS execution context that provides workflow orchestration capabilities.
// It extends the standard Go context.Context and adds methods for running workflows and steps,
// inter-workflow communication, and state management.
Expand Down
18 changes: 18 additions & 0 deletions dbos/mocks-tests-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
all: false
dir: './mocks'
filename: '{{.InterfaceName}}_mock.go'
force-file-write: true
formatter: goimports
include-auto-generated: false
log-level: info
structname: 'Mock{{.InterfaceName}}'
pkgname: 'mocks'
recursive: false
require-template-schema-exists: true
template: testify
template-schema: '{{.Template}}.schema.json'
packages:
github.com/dbos-inc/dbos-transact-golang/dbos:
interfaces:
DBOSContext:
WorkflowHandle:
202 changes: 202 additions & 0 deletions dbos/mocks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
package dbos_test

import (
"context"
"fmt"
"testing"
"time"

"github.com/dbos-inc/dbos-transact-golang/dbos"
"github.com/dbos-inc/dbos-transact-golang/dbos/mocks"
"github.com/stretchr/testify/mock"
)

func step(ctx context.Context) (int, error) {
return 1, nil
}

func childWorkflow(ctx dbos.DBOSContext, i int) (int, error) {
return i + 1, nil
}

func workflow(ctx dbos.DBOSContext, i int) (int, error) {
// Test RunAsStep
a, err := dbos.RunAsStep(ctx, step)
if err != nil {
return 0, err
}

// Child wf
ch, err := dbos.RunWorkflow(ctx, childWorkflow, i)
if err != nil {
return 0, err
}
b, err := ch.GetResult()
if err != nil {
return 0, err
}

// Test messaging operations
c, err := dbos.Recv[int](ctx, "chan1", 1*time.Second)
if err != nil {
return 0, err
}
d, err := dbos.GetEvent[int](ctx, "tgw", "event1", 1*time.Second)
if err != nil {
return 0, err
}
err = dbos.Send(ctx, "dst", 1, "topic")
if err != nil {
return 0, err
}

// Test SetEvent
err = dbos.SetEvent(ctx, "test_key", "test_value")
if err != nil {
return 0, err
}

// Test Sleep
_, err = dbos.Sleep(ctx, 100*time.Millisecond)
if err != nil {
return 0, err
}

// Test ID retrieval methods
workflowID, err := ctx.GetWorkflowID()
if err != nil {
return 0, err
}
stepID, err := ctx.GetStepID()
if err != nil {
return 0, err
}

// Test workflow management
_, err = dbos.RetrieveWorkflow[int](ctx, workflowID)
if err != nil {
return 0, err
}

_, err = dbos.Enqueue[int, int](ctx, "test_queue", "test_workflow", 42)
if err != nil {
return 0, err
}

err = dbos.CancelWorkflow(ctx, workflowID)
if err != nil {
return 0, err
}

_, err = dbos.ResumeWorkflow[int](ctx, workflowID)
if err != nil {
return 0, err
}

forkInput := dbos.ForkWorkflowInput{
OriginalWorkflowID: workflowID,
StartStep: uint(stepID),
}
_, err = dbos.ForkWorkflow[int](ctx, forkInput)
if err != nil {
return 0, err
}

_, err = dbos.ListWorkflows(ctx)
if err != nil {
return 0, err
}

_, err = dbos.GetWorkflowSteps(ctx, workflowID)
if err != nil {
return 0, err
}

// Test accessor methods
appVersion := ctx.GetApplicationVersion()
executorID := ctx.GetExecutorID()
appID := ctx.GetApplicationID()

// Use some values to avoid compiler warnings
_ = appVersion
_ = executorID
_ = appID

return a + b + c + d, nil
}

func aRealProgramFunction(dbosCtx dbos.DBOSContext) error {

dbos.RegisterWorkflow(dbosCtx, workflow)

err := dbosCtx.Launch()
if err != nil {
return err
}
defer dbosCtx.Shutdown(1 * time.Second)

res, err := workflow(dbosCtx, 2)
if err != nil {
return err
}
if res != 4 {
return fmt.Errorf("unexpected result: %v", res)
}
return nil
}

func TestMocks(t *testing.T) {
mockCtx := mocks.NewMockDBOSContext(t)

// Context lifecycle
mockCtx.On("Launch").Return(nil)
mockCtx.On("Shutdown", mock.Anything).Return()

// Basic workflow operations (existing)
mockCtx.On("RunAsStep", mockCtx, mock.Anything, mock.Anything).Return(1, nil)

// Child workflow
mockChildHandle := mocks.NewMockWorkflowHandle[any](t)
mockCtx.On("RunWorkflow", mockCtx, mock.Anything, 2, mock.Anything).Return(mockChildHandle, nil).Once()
mockChildHandle.On("GetResult").Return(1, nil)

// Messaging
mockCtx.On("Recv", mockCtx, "chan1", 1*time.Second).Return(1, nil)
mockCtx.On("GetEvent", mockCtx, "tgw", "event1", 1*time.Second).Return(1, nil)
mockCtx.On("Send", mockCtx, "dst", 1, "topic").Return(nil)
mockCtx.On("SetEvent", mockCtx, "test_key", "test_value").Return(nil)

mockCtx.On("Sleep", mockCtx, 100*time.Millisecond).Return(100*time.Millisecond, nil)

// ID retrieval methods
mockCtx.On("GetWorkflowID").Return("test-workflow-id", nil)
mockCtx.On("GetStepID").Return(1, nil)

// Workflow management
mockGenericHandle := mocks.NewMockWorkflowHandle[any](t)
mockGenericHandle.On("GetWorkflowID").Return("generic-workflow-id").Maybe()
mockGenericHandle.On("GetResult").Return(42, nil).Maybe()
mockGenericHandle.On("GetStatus").Return(dbos.WorkflowStatus{}, nil).Maybe()

mockCtx.On("RetrieveWorkflow", mockCtx, "test-workflow-id").Return(mockGenericHandle, nil)
mockCtx.On("Enqueue", mockCtx, "test_queue", "test_workflow", 42).Return(mockGenericHandle, nil)
mockCtx.On("CancelWorkflow", mockCtx, "test-workflow-id").Return(nil)
mockCtx.On("ResumeWorkflow", mockCtx, "test-workflow-id").Return(mockGenericHandle, nil)
mockCtx.On("ForkWorkflow", mockCtx, mock.Anything).Return(mockGenericHandle, nil)
mockCtx.On("ListWorkflows", mockCtx).Return([]dbos.WorkflowStatus{}, nil)
mockCtx.On("GetWorkflowSteps", mockCtx, "test-workflow-id").Return([]dbos.StepInfo{}, nil)

// Accessor methods
mockCtx.On("GetApplicationVersion").Return("test-version")
mockCtx.On("GetExecutorID").Return("test-executor")
mockCtx.On("GetApplicationID").Return("test-app-id")

err := aRealProgramFunction(mockCtx)
if err != nil {
t.Fatal(err)
}

mockCtx.AssertExpectations(t)
mockChildHandle.AssertExpectations(t)
// mockGenericHandle.AssertExpectations(t)
}
33 changes: 31 additions & 2 deletions dbos/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,35 @@ func (h *workflowPollingHandle[R]) GetResult() (R, error) {
return *new(R), err
}

// Wrapper handle -- useful for handling mocks in RunWorkflow
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this do? Why do we need it? Do users need to worry about it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The situation we have with RunWorkflow, is that only the interface methods are mockable. If you want to mock RunWorkflow and have it return a mock handle, only the "inner" call with get a mock. Then, in the package level method, we use information from that handle to return a new typed handle. This new handle, not being the mock, makes testing impossible.

What this does -- users don't need to worry -- is to fall back on a "proxy" handler if the interface RunWorkflow does not return the "normal" handlers (direct or polling). The proxy just calls in whatever was returned by RunWorkflow. This allow a mock handle to be passed through.

Copy link
Member

@kraftp kraftp Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we just return a polling handle (passing in the mocked context) instead of this new thing?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning a new polling handle would be returning a new object that is not the mock returned by DBOSContext.RunWorkflow. We need a way to pass through the mock while respecting the function signature (return a generic WorkflowHandle[R] interface). Unfortunately we cannot just "cast" the mock (SomeMockType[any]) to WorkflowHandle[R]: any and R are different type. This proxy type:

  • Let us pass through the mock
  • Let us returned a typed mock

type workflowHandleProxy[R any] struct {
wrappedHandle WorkflowHandle[any]
}

func (h *workflowHandleProxy[R]) GetResult() (R, error) {
result, err := h.wrappedHandle.GetResult()
if err != nil {
var zero R
return zero, err
}

// Convert from any to R
if typed, ok := result.(R); ok {
return typed, nil
}

var zero R
return zero, fmt.Errorf("cannot convert result of type %T to %T", result, zero)
}

func (h *workflowHandleProxy[R]) GetStatus() (WorkflowStatus, error) {
return h.wrappedHandle.GetStatus()
}

func (h *workflowHandleProxy[R]) GetWorkflowID() string {
return h.wrappedHandle.GetWorkflowID()
}

/**********************************/
/******* WORKFLOW REGISTRY *******/
/**********************************/
Expand Down Expand Up @@ -583,8 +612,8 @@ func RunWorkflow[P any, R any](ctx DBOSContext, fn Workflow[P, R], input P, opts
return typedHandle, nil
}

// Should never happen
return nil, fmt.Errorf("unexpected workflow handle type: %T", handle)
// Usually on a mocked path
return &workflowHandleProxy[R]{wrappedHandle: handle}, nil
}

func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ require (
github.com/spf13/afero v1.12.0 // indirect
github.com/spf13/cast v1.7.1 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4=
github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
Expand Down
Loading