Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,31 @@ public void Opens_Connection()
var options = new OptionsWrapper<RabbitMQOptions>(new RabbitMQOptions());
var mockServiceFactory = new Mock<IRabbitMQServiceFactory>();
var config = new RabbitMQExtensionConfigProvider(options, new Mock<INameResolver>().Object, mockServiceFactory.Object, new LoggerFactory(), EmptyConfig, DrainModeManager);
mockServiceFactory.Setup(m => m.CreateService(It.IsAny<string>(), false)).Returns(new Mock<IRabbitMQService>().Object);
mockServiceFactory.Setup(m => m.CreateService(It.IsAny<string>(), false, It.IsAny<ILogger>())).Returns(new Mock<IRabbitMQService>().Object);
RabbitMQAttribute attr = GetTestAttribute();

var clientBuilder = new RabbitMQClientBuilder(config, options);
IModel model = clientBuilder.Convert(attr);
IRabbitMQService service = clientBuilder.Convert(attr);

mockServiceFactory.Verify(m => m.CreateService(It.IsAny<string>(), false), Times.Exactly(1));
mockServiceFactory.Verify(m => m.CreateService(It.IsAny<string>(), false, It.IsAny<ILogger>()), Times.Exactly(1));
}

[Fact]
public void TestWhetherConnectionIsPooled()
{
var options = new OptionsWrapper<RabbitMQOptions>(new RabbitMQOptions());
var mockServiceFactory = new Mock<IRabbitMQServiceFactory>();
mockServiceFactory.SetupSequence(m => m.CreateService(It.IsAny<string>(), false))
mockServiceFactory.SetupSequence(m => m.CreateService(It.IsAny<string>(), false, It.IsAny<ILogger>()))
.Returns(GetRabbitMQService());
var config = new RabbitMQExtensionConfigProvider(options, new Mock<INameResolver>().Object, mockServiceFactory.Object, new LoggerFactory(), EmptyConfig, DrainModeManager);
RabbitMQAttribute attr = GetTestAttribute();

var clientBuilder = new RabbitMQClientBuilder(config, options);

IModel model = clientBuilder.Convert(attr);
IModel model2 = clientBuilder.Convert(attr);
IRabbitMQService service = clientBuilder.Convert(attr);
IRabbitMQService service2 = clientBuilder.Convert(attr);

Assert.Equal(model, model2);
Assert.Equal(service, service2);
}

private static RabbitMQAttribute GetTestAttribute()
Expand All @@ -60,7 +60,6 @@ private static RabbitMQAttribute GetTestAttribute()
private static IRabbitMQService GetRabbitMQService()
{
var mockService = new Mock<IRabbitMQService>();
mockService.Setup(a => a.Model).Returns(new Mock<IModel>().Object);
return mockService.Object;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using Microsoft.Azure.WebJobs.Host;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Moq;
Expand All @@ -18,11 +19,11 @@ public void TestConnectionPooling()
var rabbitmqServiceFactory = new Mock<IRabbitMQServiceFactory>();

rabbitmqServiceFactory
.SetupSequence(a => a.CreateService(It.IsAny<string>(), It.IsAny<string>(), false))
.SetupSequence(a => a.CreateService(It.IsAny<string>(), It.IsAny<string>(), false, It.IsAny<ILogger>()))
.Returns(new Mock<IRabbitMQService>().Object);

rabbitmqServiceFactory
.SetupSequence(a => a.CreateService(It.IsAny<string>(), false))
.SetupSequence(a => a.CreateService(It.IsAny<string>(), false, It.IsAny<ILogger>()))
.Returns(new Mock<IRabbitMQService>().Object);

var extensionConfigProvider = new RabbitMQExtensionConfigProvider(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.WebJobs.Host;
using Microsoft.Azure.WebJobs.Host.Executors;
using Microsoft.Extensions.Logging;
using Moq;
using RabbitMQ.Client;
using RabbitMQ.Client.Events;
using Xunit;
Expand All @@ -23,6 +30,7 @@ public void Verify_BindingDataContract_Types()
["RoutingKey"] = typeof(string),
["BasicProperties"] = typeof(IBasicProperties),
["Body"] = typeof(ReadOnlyMemory<byte>),
["MessageActions"] = typeof(RabbitMQMessageActions),
};

IReadOnlyDictionary<string, Type> actualContract = RabbitMQTriggerBinding.CreateBindingDataContract();
Expand All @@ -44,6 +52,7 @@ public void Verify_BindingDataContract_Values()

ReadOnlyMemory<byte> body = buffer;
var eventArgs = new BasicDeliverEventArgs("ConsumerName", deliveryTag, false, "n/a", "QueueName", null, body);
var messageActions = new RabbitMQMessageActions(Mock.Of<IRabbitMQService>(), eventArgs);

var data = new Dictionary<string, object>(StringComparer.OrdinalIgnoreCase)
{
Expand All @@ -54,13 +63,84 @@ public void Verify_BindingDataContract_Values()
["Body"] = body,
["Exchange"] = eventArgs.Exchange,
["BasicProperties"] = eventArgs.BasicProperties,
["MessageActions"] = messageActions,
};

IReadOnlyDictionary<string, object> actualContract = RabbitMQTriggerBinding.CreateBindingData(eventArgs);
IReadOnlyDictionary<string, object> actualContract = RabbitMQTriggerBinding.CreateBindingData(eventArgs, messageActions);

foreach (KeyValuePair<string, object> item in actualContract)
{
Assert.Equal(data[item.Key], item.Value);
}
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task RabbitMQTrigger_ManualAck_BasicAckBehavior(bool disableAck)
{
// Arrange
var mockservice = new Mock<IRabbitMQService>();
var mockModel = new Mock<IModel>();
mockservice.Setup(a => a.CreateConsumer()).Returns(new AsyncEventingBasicConsumer(mockModel.Object));

var mockExecutor = new Mock<ITriggeredFunctionExecutor>();
var mockLogger = new Mock<ILogger>();
var mockDrainModeManager = new Mock<IDrainModeManager>();
var mockBasicProperties = new Mock<IBasicProperties>();

// Simulate successful function execution
mockExecutor
.Setup(executor => executor.TryExecuteAsync(It.IsAny<TriggeredFunctionData>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(new FunctionResult(true));

var listener = new RabbitMQListener(
mockservice.Object,
mockExecutor.Object,
mockLogger.Object,
functionId: "test-function",
queueName: "test-queue",
disableAck: disableAck,
prefetchCount: 10,
drainModeManager: mockDrainModeManager.Object);

var eventArgs = new BasicDeliverEventArgs
{
DeliveryTag = 1,
Body = new ReadOnlyMemory<byte>([0x01, 0x02, 0x03]),
BasicProperties = mockBasicProperties.Object,
};

// Act
await listener.StartAsync(CancellationToken.None);

// Find the Consumer instance passed to RabbitMQService.Consume method
IInvocation consumeInvocation = mockservice.Invocations
.FirstOrDefault(invocation => invocation.Method.Name == "Consume");

Assert.NotNull(consumeInvocation);

var consumer = consumeInvocation.Arguments[2] as AsyncEventingBasicConsumer;
Assert.NotNull(consumer);

// Simulate message delivery
await consumer.HandleBasicDeliver(
consumerTag: "ctag",
deliveryTag: eventArgs.DeliveryTag,
redelivered: false,
exchange: string.Empty,
routingKey: string.Empty,
properties: eventArgs.BasicProperties,
body: eventArgs.Body.ToArray());

// Assert
if (disableAck)
{
mockservice.Verify(channel => channel.Acknowledge(It.IsAny<ulong>(), It.IsAny<bool>(), It.IsAny<string>()), Times.Never, "BasicAck should not be called when DisableAck is true.");
}
else
{
mockservice.Verify(channel => channel.Acknowledge(It.IsAny<ulong>(), It.IsAny<bool>(), It.IsAny<string>()), Times.Once, "BasicAck should be called when DisableAck is false.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,12 @@ public void ScaleMonitorGetScaleStatus_CountNotIncreasingOrDecreasing_ReturnsNon
private static RabbitMQListener GetScaleMonitor(string functionId, string queueName)
{
return new RabbitMQListener(
Mock.Of<IModel>(),
Mock.Of<IRabbitMQService>(),
Mock.Of<ITriggeredFunctionExecutor>(),
Mock.Of<ILogger>(),
functionId,
queueName,
false,
7357,
DrainModeManager);
}
Expand All @@ -219,11 +220,12 @@ private static (IScaleMonitor<RabbitMQTriggerMetrics> Monitor, List<string> LogM
(Mock<ILogger> mockLogger, List<string> logMessages) = CreateMockLogger();

IScaleMonitor<RabbitMQTriggerMetrics> monitor = new RabbitMQListener(
Mock.Of<IModel>(),
Mock.Of<IRabbitMQService>(),
Mock.Of<ITriggeredFunctionExecutor>(),
mockLogger.Object,
"testFunctionId",
"testQueueName",
false,
7357,
DrainModeManager);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@

namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ;

internal class RabbitMQClientBuilder(RabbitMQExtensionConfigProvider configProvider, IOptions<RabbitMQOptions> options) : IConverter<RabbitMQAttribute, IModel>
internal class RabbitMQClientBuilder(RabbitMQExtensionConfigProvider configProvider, IOptions<RabbitMQOptions> options) : IConverter<RabbitMQAttribute, IRabbitMQService>
{
private readonly RabbitMQExtensionConfigProvider configProvider = configProvider;
private readonly IOptions<RabbitMQOptions> options = options;

public IModel Convert(RabbitMQAttribute attribute)
public IRabbitMQService Convert(RabbitMQAttribute attribute)
{
return this.CreateModelFromAttribute(attribute);
}

private IModel CreateModelFromAttribute(RabbitMQAttribute attribute)
private IRabbitMQService CreateModelFromAttribute(RabbitMQAttribute attribute)
{
if (attribute == null)
{
Expand All @@ -27,8 +27,6 @@ private IModel CreateModelFromAttribute(RabbitMQAttribute attribute)
string resolvedConnectionString = Utility.FirstOrDefault(attribute.ConnectionStringSetting, this.options.Value.ConnectionString);
bool resolvedDisableCertificateValidation = Utility.FirstOrDefault(attribute.DisableCertificateValidation, this.options.Value.DisableCertificateValidation);

IRabbitMQService service = this.configProvider.GetService(resolvedConnectionString, resolvedDisableCertificateValidation);

return service.Model;
return this.configProvider.GetService(resolvedConnectionString, resolvedDisableCertificateValidation);
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

using Microsoft.Extensions.Logging;

namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ;

internal class DefaultRabbitMQServiceFactory : IRabbitMQServiceFactory
{
public IRabbitMQService CreateService(string connectionString, string queueName, bool disableCertificateValidation)
public IRabbitMQService CreateService(string connectionString, string queueName, bool disableCertificateValidation, ILogger logger)
{
return new RabbitMQService(connectionString, queueName, disableCertificateValidation);
return new RabbitMQService(connectionString, queueName, disableCertificateValidation, logger);
}

public IRabbitMQService CreateService(string connectionString, bool disableCertificateValidation)
public IRabbitMQService CreateService(string connectionString, bool disableCertificateValidation, ILogger logger)
{
return new RabbitMQService(connectionString, disableCertificateValidation);
return new RabbitMQService(connectionString, disableCertificateValidation, logger);
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

using Microsoft.Extensions.Logging;

namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ;

public interface IRabbitMQServiceFactory
{
IRabbitMQService CreateService(string connectionString, string queueName, bool disableCertificateValidation);
IRabbitMQService CreateService(string connectionString, string queueName, bool disableCertificateValidation, ILogger logger);

IRabbitMQService CreateService(string connectionString, bool disableCertificateValidation);
IRabbitMQService CreateService(string connectionString, bool disableCertificateValidation, ILogger logger);
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

using System;
Expand Down Expand Up @@ -84,7 +84,7 @@ internal IRabbitMQService GetService(string connectionString, string queueName,
string[] keyArray =
[connectionString, queueName, disableCertificateValidation.ToString()];
string key = string.Join(",", keyArray);
return this.connectionParametersToService.GetOrAdd(key, _ => this.rabbitMQServiceFactory.CreateService(connectionString, queueName, disableCertificateValidation));
return this.connectionParametersToService.GetOrAdd(key, _ => this.rabbitMQServiceFactory.CreateService(connectionString, queueName, disableCertificateValidation, this.logger));
}

// Overloaded method used only for getting the RabbitMQ client.
Expand All @@ -93,6 +93,6 @@ internal IRabbitMQService GetService(string connectionString, bool disableCertif
string[] keyArray =
[connectionString, disableCertificateValidation.ToString()];
string key = string.Join(",", keyArray);
return this.connectionParametersToService.GetOrAdd(key, _ => this.rabbitMQServiceFactory.CreateService(connectionString, disableCertificateValidation));
return this.connectionParametersToService.GetOrAdd(key, _ => this.rabbitMQServiceFactory.CreateService(connectionString, disableCertificateValidation, this.logger));
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,37 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

using System;
using RabbitMQ.Client;
using RabbitMQ.Client.Events;

namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ;

public interface IRabbitMQService
{
IModel Model { get; }

IBasicPublishBatch BasicPublishBatch { get; }

object PublishBatchLock { get; }

void ResetPublishBatch();

void ConfigureQos(uint prefetchSize, ushort prefetchCount, bool global);

QueueDeclareOk GetQueueInfo(string queueName);

AsyncEventingBasicConsumer CreateConsumer();

string Consume(string queue, bool autoAck, AsyncEventingBasicConsumer consumer);

void OnMessageConsumed(string consumerTag, ulong deliveryTag);

void Acknowledge(ulong deliveryTag, bool multiple, string logDetails);

void Reject(ulong deliveryTag, bool requeue, string logDetails);

void Publish(string exchange, string routingKey, IBasicProperties basicProperties, ReadOnlyMemory<byte> body);

void Cancel(string consumerTag);

void Close();
}
Loading
Loading