Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.WebJobs.Host;
using Microsoft.Azure.WebJobs.Host.Executors;
using Microsoft.Extensions.Logging;
using Moq;
using RabbitMQ.Client;
using RabbitMQ.Client.Events;
using Xunit;
Expand All @@ -23,6 +30,7 @@ public void Verify_BindingDataContract_Types()
["RoutingKey"] = typeof(string),
["BasicProperties"] = typeof(IBasicProperties),
["Body"] = typeof(ReadOnlyMemory<byte>),
["MessageActions"] = typeof(RabbitMQMessageActions),
};

IReadOnlyDictionary<string, Type> actualContract = RabbitMQTriggerBinding.CreateBindingDataContract();
Expand All @@ -44,6 +52,7 @@ public void Verify_BindingDataContract_Values()

ReadOnlyMemory<byte> body = buffer;
var eventArgs = new BasicDeliverEventArgs("ConsumerName", deliveryTag, false, "n/a", "QueueName", null, body);
var messageActions = new RabbitMQMessageActions(Mock.Of<IModel>());

var data = new Dictionary<string, object>(StringComparer.OrdinalIgnoreCase)
{
Expand All @@ -54,13 +63,109 @@ public void Verify_BindingDataContract_Values()
["Body"] = body,
["Exchange"] = eventArgs.Exchange,
["BasicProperties"] = eventArgs.BasicProperties,
["MessageActions"] = messageActions,
};

IReadOnlyDictionary<string, object> actualContract = RabbitMQTriggerBinding.CreateBindingData(eventArgs);
IReadOnlyDictionary<string, object> actualContract = RabbitMQTriggerBinding.CreateBindingData(eventArgs, messageActions);

foreach (KeyValuePair<string, object> item in actualContract)
{
Assert.Equal(data[item.Key], item.Value);
}
}

[Fact]
public void RabbitMQTriggerAttribute_ManualAck_DefaultsToFalse()
{
var attribute = new RabbitMQTriggerAttribute("test-queue");
bool manualAck = attribute.ManualAck;
Assert.False(manualAck, "ManualAck should default to false.");
}

[Fact]
public void RabbitMQTriggerBinding_CreateBindingData_IncludesRabbitMQMessageActions()
{
ulong deliveryTag = 1;

var rand = new Random();
byte[] buffer = new byte[10];
rand.NextBytes(buffer);

ReadOnlyMemory<byte> body = buffer;
var eventArgs = new BasicDeliverEventArgs("ConsumerName", deliveryTag, false, "n/a", "QueueName", null, body);
var messageActions = new RabbitMQMessageActions(Mock.Of<IModel>());

IReadOnlyDictionary<string, object> bindingData = RabbitMQTriggerBinding.CreateBindingData(eventArgs, messageActions);

// Assert
Assert.True(bindingData.ContainsKey("MessageActions"), "Binding data should include MessageActions.");
Assert.Equal(messageActions, bindingData["MessageActions"]);
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task RabbitMQTrigger_ManualAck_BasicAckBehavior(bool manualAck)
{
// Arrange
var mockChannel = new Mock<IModel>();
var mockExecutor = new Mock<ITriggeredFunctionExecutor>();
var mockLogger = new Mock<ILogger>();
var mockDrainModeManager = new Mock<IDrainModeManager>();
var mockBasicProperties = new Mock<IBasicProperties>();

// Simulate successful function execution
mockExecutor
.Setup(executor => executor.TryExecuteAsync(It.IsAny<TriggeredFunctionData>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(new FunctionResult(true));

var listener = new RabbitMQListener(
mockChannel.Object,
mockExecutor.Object,
mockLogger.Object,
functionId: "test-function",
queueName: "test-queue",
manualAck: manualAck,
prefetchCount: 10,
drainModeManager: mockDrainModeManager.Object);

var eventArgs = new BasicDeliverEventArgs
{
DeliveryTag = 1,
Body = new ReadOnlyMemory<byte>([0x01, 0x02, 0x03]),
BasicProperties = mockBasicProperties.Object,
};

// Act
await listener.StartAsync(CancellationToken.None);

// Find the IBasicConsumer passed to BasicConsume
IInvocation basicConsumeInvocation = mockChannel.Invocations
.FirstOrDefault(invocation => invocation.Method.Name == "BasicConsume");

Assert.NotNull(basicConsumeInvocation);

var consumer = basicConsumeInvocation.Arguments[6] as AsyncEventingBasicConsumer;
Assert.NotNull(consumer);

// Simulate message delivery
await consumer.HandleBasicDeliver(
consumerTag: "ctag",
deliveryTag: eventArgs.DeliveryTag,
redelivered: false,
exchange: string.Empty,
routingKey: string.Empty,
properties: eventArgs.BasicProperties,
body: eventArgs.Body.ToArray());

// Assert
if (manualAck)
{
mockChannel.Verify(channel => channel.BasicAck(It.IsAny<ulong>(), It.IsAny<bool>()), Times.Never, "BasicAck should not be called when ManualAck is true.");
}
else
{
mockChannel.Verify(channel => channel.BasicAck(It.IsAny<ulong>(), It.IsAny<bool>()), Times.AtLeastOnce, "BasicAck should be called when ManualAck is false.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ private static RabbitMQListener GetScaleMonitor(string functionId, string queueN
Mock.Of<ILogger>(),
functionId,
queueName,
false,
7357,
DrainModeManager);
}
Expand All @@ -224,6 +225,7 @@ private static (IScaleMonitor<RabbitMQTriggerMetrics> Monitor, List<string> LogM
mockLogger.Object,
"testFunctionId",
"testQueueName",
false,
7357,
DrainModeManager);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ internal sealed class RabbitMQListener : IListener, IScaleMonitor<RabbitMQTrigge
private readonly ILogger logger;
private readonly string queueName;
private readonly ushort prefetchCount;
private readonly bool manualAck;
private readonly string logDetails;
private readonly IDrainModeManager drainModeManager;

Expand All @@ -46,13 +47,15 @@ public RabbitMQListener(
ILogger logger,
string functionId,
string queueName,
bool manualAck,
ushort prefetchCount,
IDrainModeManager drainModeManager)
{
this.channel = channel ?? throw new ArgumentNullException(nameof(channel));
this.executor = executor ?? throw new ArgumentNullException(nameof(executor));
this.logger = logger ?? throw new ArgumentNullException(nameof(logger));
this.queueName = !string.IsNullOrWhiteSpace(queueName) ? queueName : throw new ArgumentNullException(nameof(queueName));
this.manualAck = manualAck;
this.prefetchCount = prefetchCount;
this.drainModeManager = drainModeManager;
this.listenerCancellationTokenSource = new CancellationTokenSource();
Expand All @@ -62,6 +65,12 @@ public RabbitMQListener(
// Do not convert the scale-monitor ID to lower-case string since RabbitMQ queue names are case-sensitive.
this.Descriptor = new ScaleMonitorDescriptor($"{functionId}-RabbitMQTrigger-{queueName}", functionId);
this.logDetails = $"function: '{functionId}', queue: '{queueName}'";

// Add a handler to log any errors that occur on the channel.
channel.ModelShutdown += (sender, args) =>
{
logger.LogError($"[!] Channel closed due to error: {args.Exception?.Message}");
};
}

public ScaleMonitorDescriptor Descriptor { get; }
Expand Down Expand Up @@ -133,10 +142,20 @@ async Task ReceivedHandler(object model, BasicDeliverEventArgs args)
// We cannot call BasicReject() on the message with requeue = true since that would not enable a fixed
// number of retry attempts. See: https://stackoverflow.com/q/23158310.
this.channel.BasicPublish(exchange: string.Empty, routingKey: this.queueName, args.BasicProperties, args.Body);
}

// Acknowledge the existing message only after the message (in case of failure) is re-published.
this.channel.BasicAck(args.DeliveryTag, multiple: false);
// Acknowledge the existing message after the message is re-published.
this.channel.BasicAck(args.DeliveryTag, multiple: false);
}
else if (!this.manualAck)
{
// Acknowledge the existing message if manualAck is not set and function execution was successful.
this.channel.BasicAck(args.DeliveryTag, multiple: false);
}
else
{
// Do not acknowledge the message if manualAck is set and function execution was successful.
this.logger.LogDebug($"Not acknowledging message for {this.logDetails} since manualAck is set.");
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Threading;
using Microsoft.Azure.WebJobs.Host.Scale;
using RabbitMQ.Client;

namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ;

public class RabbitMQMessageActions
{
private readonly IModel channel;

internal RabbitMQMessageActions(IModel channel)
{
this.channel = channel;
}

public void BasicReject(ulong deliveryTag, bool requeue = false)
{
this.channel.BasicReject(deliveryTag, requeue: requeue);
}

public void BasicAck(ulong deliveryTag, bool multiple = false)
{
this.channel.BasicAck(deliveryTag, multiple: multiple);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,9 @@ public sealed class RabbitMQTriggerAttribute(string queueName) : Attribute
/// production. Does not apply when SSL is disabled.
/// </summary>
public bool DisableCertificateValidation { get; set; }

/// <summary>
/// Gets or sets a value indicating whether message acknowledgements would be done manually.
/// </summary>
public bool ManualAck { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ public Task<ITriggerBinding> TryCreateAsync(TriggerBindingProviderContext contex
string connectionString = Utility.ResolveConnectionString(attribute.ConnectionStringSetting, this.options.Value.ConnectionString, this.configuration);
string queueName = this.Resolve(attribute.QueueName) ?? throw new InvalidOperationException("RabbitMQ queue name is missing");
bool disableCertificateValidation = attribute.DisableCertificateValidation || this.options.Value.DisableCertificateValidation;
bool manualAck = attribute.ManualAck;

IRabbitMQService service = this.provider.GetService(connectionString, queueName, disableCertificateValidation);

return Task.FromResult<ITriggerBinding>(new RabbitMQTriggerBinding(service, queueName, this.logger, parameter.ParameterType, this.options.Value.PrefetchCount, this.drainModeManager));
return Task.FromResult<ITriggerBinding>(new RabbitMQTriggerBinding(service, queueName, manualAck, this.logger, parameter.ParameterType, this.options.Value.PrefetchCount, this.drainModeManager));
}

private string Resolve(string name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@

namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ;

internal class RabbitMQTriggerBinding(IRabbitMQService service, string queueName, ILogger logger, Type parameterType, ushort prefetchCount, IDrainModeManager drainModeManager) : ITriggerBinding
internal class RabbitMQTriggerBinding(IRabbitMQService service, string queueName, bool manualAck, ILogger logger, Type parameterType, ushort prefetchCount, IDrainModeManager drainModeManager) : ITriggerBinding
{
private readonly IRabbitMQService service = service;
private readonly ILogger logger = logger;
private readonly Type parameterType = parameterType;
private readonly string queueName = queueName;
private readonly bool manualAck = manualAck;
private readonly ushort prefetchCount = prefetchCount;
private readonly IDrainModeManager drainModeManager = drainModeManager;
private readonly RabbitMQMessageActions messageActions = new(service.Model);

public Type TriggerValueType => typeof(BasicDeliverEventArgs);

Expand All @@ -31,7 +33,7 @@ internal class RabbitMQTriggerBinding(IRabbitMQService service, string queueName
public Task<ITriggerData> BindAsync(object value, ValueBindingContext context)
{
var message = (BasicDeliverEventArgs)value;
IReadOnlyDictionary<string, object> bindingData = CreateBindingData(message);
IReadOnlyDictionary<string, object> bindingData = CreateBindingData(message, this.messageActions);

return Task.FromResult<ITriggerData>(new TriggerData(new BasicDeliverEventArgsValueProvider(message, this.parameterType), bindingData));
}
Expand All @@ -46,6 +48,7 @@ public Task<IListener> CreateListenerAsync(ListenerFactoryContext context)
this.logger,
context.Descriptor.Id,
this.queueName,
this.manualAck,
this.prefetchCount,
this.drainModeManager));
}
Expand All @@ -69,12 +72,13 @@ internal static IReadOnlyDictionary<string, Type> CreateBindingDataContract()
["RoutingKey"] = typeof(string),
["BasicProperties"] = typeof(IBasicProperties),
["Body"] = typeof(ReadOnlyMemory<byte>),
["MessageActions"] = typeof(RabbitMQMessageActions),
};

return contract;
}

internal static IReadOnlyDictionary<string, object> CreateBindingData(BasicDeliverEventArgs value)
internal static IReadOnlyDictionary<string, object> CreateBindingData(BasicDeliverEventArgs value, RabbitMQMessageActions messageActions)
{
var bindingData = new Dictionary<string, object>(StringComparer.OrdinalIgnoreCase);

Expand All @@ -85,6 +89,7 @@ internal static IReadOnlyDictionary<string, object> CreateBindingData(BasicDeliv
SafeAddValue(() => bindingData.Add(nameof(value.RoutingKey), value.RoutingKey));
SafeAddValue(() => bindingData.Add(nameof(value.BasicProperties), value.BasicProperties));
SafeAddValue(() => bindingData.Add(nameof(value.Body), value.Body));
SafeAddValue(() => bindingData.Add(nameof(messageActions), messageActions));

return bindingData;
}
Expand Down