Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions dbos/system_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -1440,7 +1440,8 @@ func (s *sysDB) getWorkflowSteps(ctx context.Context, workflowID string) ([]Step

type sleepInput struct {
duration time.Duration // Duration to sleep
skipSleep bool // If true, the function will not actually sleep (useful for testing)
skipSleep bool // If true, the function will not actually sleep and just return the remaining sleep duration
stepID *int // Optional step ID to use instead of generating a new one (for internal use)
}

// Sleep is a special type of step that sleeps for a specified duration
Expand All @@ -1461,7 +1462,13 @@ func (s *sysDB) sleep(ctx context.Context, input sleepInput) (time.Duration, err
return 0, newStepExecutionError(wfState.workflowID, functionName, "cannot call Sleep within a step")
}

stepID := wfState.NextStepID()
// Determine step ID
var stepID int
if input.stepID != nil && *input.stepID >= 0 {
stepID = *input.stepID
} else {
stepID = wfState.NextStepID()
}

// Check if operation was already executed
checkInput := checkOperationExecutionDBInput{
Expand Down Expand Up @@ -1700,6 +1707,7 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) {
}

stepID := wfState.NextStepID()
sleepStepID := wfState.NextStepID() // We will use a sleep step to implement the timeout
destinationID := wfState.workflowID

// Set default topic if not provided
Expand All @@ -1719,10 +1727,7 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) {
return nil, err
}
if recordedResult != nil {
if recordedResult.output != nil {
return recordedResult.output, nil
}
return nil, fmt.Errorf("no output recorded in the last recv")
return recordedResult.output, nil
}

// First check if there's already a receiver for this workflow/topic to avoid unnecessary database load
Expand Down Expand Up @@ -1762,6 +1767,7 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) {
timeout, err := s.sleep(ctx, sleepInput{
duration: input.Timeout,
skipSleep: true,
stepID: &sleepStepID,
})
if err != nil {
return nil, fmt.Errorf("failed to sleep before recv timeout: %w", err)
Expand All @@ -1772,6 +1778,9 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) {
s.logger.Debug("Received notification on condition variable", "payload", payload)
case <-time.After(timeout):
s.logger.Warn("Recv() timeout reached", "payload", payload, "timeout", input.Timeout)
case <-ctx.Done():
s.logger.Warn("Recv() context cancelled", "payload", payload, "cause", context.Cause(ctx))
return nil, ctx.Err()
}
}

Expand Down Expand Up @@ -1808,7 +1817,7 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) {

// Deserialize the message
var message any
if messageString != nil { // nil message should never happen because they'd cause an error on the send() path
if messageString != nil { // nil message can happen on the timeout path only
message, err = deserialize(messageString)
if err != nil {
return nil, fmt.Errorf("failed to deserialize message: %w", err)
Expand Down Expand Up @@ -1923,6 +1932,7 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)
// Get workflow state from context (optional for GetEvent as we can get an event from outside a workflow)
wfState, ok := ctx.Value(workflowStateKey).(*workflowState)
var stepID int
var sleepStepID int
var isInWorkflow bool

if ok && wfState != nil {
Expand All @@ -1931,6 +1941,7 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)
return nil, newStepExecutionError(wfState.workflowID, functionName, "cannot call GetEvent within a step")
}
stepID = wfState.NextStepID()
sleepStepID = wfState.NextStepID() // We will use a sleep step to implement the timeout

// Check if operation was already executed (only if in workflow)
checkInput := checkOperationExecutionDBInput{
Expand Down Expand Up @@ -1967,7 +1978,6 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)

// Check if the event already exists in the database
query := `SELECT value FROM dbos.workflow_events WHERE workflow_uuid = $1 AND key = $2`
var value any
var valueString *string

row := s.pool.QueryRow(ctx, query, input.TargetWorkflowID, input.Key)
Expand All @@ -1976,7 +1986,7 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)
return nil, fmt.Errorf("failed to query workflow event: %w", err)
}

if err == pgx.ErrNoRows || valueString == nil { // valueString should never be `nil`
if err == pgx.ErrNoRows {
// Wait for notification with timeout using condition variable
done := make(chan struct{})
go func() {
Expand All @@ -1991,6 +2001,7 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)
timeout, err = s.sleep(ctx, sleepInput{
duration: input.Timeout,
skipSleep: true,
stepID: &sleepStepID,
})
if err != nil {
return nil, fmt.Errorf("failed to sleep before getEvent timeout: %w", err)
Expand All @@ -2003,22 +2014,20 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)
case <-time.After(timeout):
s.logger.Warn("GetEvent() timeout reached", "target_workflow_id", input.TargetWorkflowID, "key", input.Key, "timeout", input.Timeout)
case <-ctx.Done():
return nil, fmt.Errorf("context cancelled while waiting for event: %w", ctx.Err())
s.logger.Warn("GetEvent() context cancelled", "target_workflow_id", input.TargetWorkflowID, "key", input.Key, "cause", context.Cause(ctx))
return nil, ctx.Err()
}

// Query the database again after waiting
row = s.pool.QueryRow(ctx, query, input.TargetWorkflowID, input.Key)
err = row.Scan(&valueString)
if err != nil {
if err == pgx.ErrNoRows {
value = nil // Event still doesn't exist
} else {
return nil, fmt.Errorf("failed to query workflow event after wait: %w", err)
}
Comment on lines -2012 to -2017
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

simplify: the zero value of an any value is already nil

if err != nil && err != pgx.ErrNoRows {
return nil, fmt.Errorf("failed to query workflow event after wait: %w", err)
}
}

// Deserialize the value if it exists
var value any
if valueString != nil {
value, err = deserialize(valueString)
if err != nil {
Expand Down
Loading
Loading