Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 43 additions & 27 deletions src/Plugins/RpcServer/RpcServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ public partial class RpcServer : IDisposable
private const string HttpMethodGet = "GET";
private const string HttpMethodPost = "POST";

private readonly Dictionary<string, Delegate> _methods = new();
internal record struct RpcParameter(string Name, Type Type, bool Required, object? DefaultValue);

private record struct RpcMethod(Delegate Delegate, RpcParameter[] Parameters);

private readonly Dictionary<string, RpcMethod> _methods = new();

private IWebHost? host;
private RpcServersSettings settings;
Expand Down Expand Up @@ -324,9 +328,9 @@ public async Task ProcessAsync(HttpContext context)
{
(CheckAuth(context) && !settings.DisabledMethods.Contains(method)).True_Or(RpcError.AccessDenied);

if (_methods.TryGetValue(method, out var func))
if (_methods.TryGetValue(method, out var rpcMethod))
{
response["result"] = ProcessParamsMethod(jsonParameters, func) switch
response["result"] = ProcessParamsMethod(jsonParameters, rpcMethod) switch
{
JToken result => result,
Task<JToken> task => await task,
Expand Down Expand Up @@ -366,25 +370,24 @@ public async Task ProcessAsync(HttpContext context)
}
}

private object? ProcessParamsMethod(JArray arguments, Delegate func)
private object? ProcessParamsMethod(JArray arguments, RpcMethod rpcMethod)
{
var parameterInfos = func.Method.GetParameters();
var args = new object?[parameterInfos.Length];
var args = new object?[rpcMethod.Parameters.Length];

// If the method has only one parameter of type JArray, invoke the method directly with the arguments
if (parameterInfos.Length == 1 && parameterInfos[0].ParameterType == typeof(JArray))
if (rpcMethod.Parameters.Length == 1 && rpcMethod.Parameters[0].Type == typeof(JArray))
{
return func.DynamicInvoke(arguments);
return rpcMethod.Delegate.DynamicInvoke(arguments);
}

for (var i = 0; i < parameterInfos.Length; i++)
for (var i = 0; i < rpcMethod.Parameters.Length; i++)
{
var param = parameterInfos[i];
var param = rpcMethod.Parameters[i];
if (arguments.Count > i && arguments[i] is not null) // Donot parse null values
{
try
{
args[i] = ParameterConverter.AsParameter(arguments[i]!, param.ParameterType);
args[i] = ParameterConverter.AsParameter(arguments[i]!, param.Type);
}
catch (Exception e) when (e is not RpcException)
{
Expand All @@ -393,22 +396,13 @@ public async Task ProcessAsync(HttpContext context)
}
else
{
if (param.IsOptional)
{
args[i] = param.DefaultValue;
}
else if (param.ParameterType.IsValueType && Nullable.GetUnderlyingType(param.ParameterType) == null)
{
if (param.Required)
throw new ArgumentException($"Required parameter '{param.Name}' is missing");
}
else
{
args[i] = null;
}
args[i] = param.DefaultValue;
}
}

return func.DynamicInvoke(args);
return rpcMethod.Delegate.DynamicInvoke(args);
}

public void RegisterMethods(object handler)
Expand All @@ -420,11 +414,33 @@ public void RegisterMethods(object handler)
if (rpcMethod is null) continue;

var name = string.IsNullOrEmpty(rpcMethod.Name) ? method.Name.ToLowerInvariant() : rpcMethod.Name;
var parameters = method.GetParameters().Select(p => p.ParameterType).ToArray();
var delegateType = Expression.GetDelegateType(parameters.Concat([method.ReturnType]).ToArray());

_methods[name] = Delegate.CreateDelegate(delegateType, handler, method);
var delegateParams = method.GetParameters()
.Select(p => p.ParameterType)
.Concat([method.ReturnType])
.ToArray();
var delegateType = Expression.GetDelegateType(delegateParams);

_methods[name] = new RpcMethod(
Delegate.CreateDelegate(delegateType, handler, method),
method.GetParameters().Select(AsRpcParameter).ToArray()
);
}
}

static internal RpcParameter AsRpcParameter(ParameterInfo param)
{
// Required if not optional and not nullable
// For reference types, if parameter has not default value and nullable is disabled, it is optional.
// For value types, if parameter has not default value, it is required.
var required = param.IsOptional ? false : NotNullParameter(param);
return new RpcParameter(param.Name ?? string.Empty, param.ParameterType, required, param.DefaultValue);
}

static private bool NotNullParameter(ParameterInfo param)
{
var context = new NullabilityInfoContext();
var nullabilityInfo = context.Create(param);
return nullabilityInfo.WriteState == NullabilityState.NotNull;
}
}
}
45 changes: 42 additions & 3 deletions tests/Neo.Plugins.RpcServer.Tests/UT_RpcServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,18 @@ public async Task TestProcessRequest_MixedBatch()

private class MockRpcMethods
{
#nullable enable
[RpcMethod]
internal JToken GetMockMethod() => "mock";
public JToken GetMockMethod(string info) => $"mock {info}";

public JToken NullContextMethod(string? info) => $"mock {info}";
#nullable restore

#nullable disable
public JToken NullableMethod(string info) => $"mock {info}";

public JToken OptionalMethod(string info = "default") => $"mock {info}";
#nullable restore
}

[TestMethod]
Expand All @@ -256,7 +266,7 @@ public async Task TestRegisterMethods()
// Request ProcessAsync with a valid request
var context = new DefaultHttpContext();
var body = """
{"jsonrpc": "2.0", "method": "getmockmethod", "params": [], "id": 1 }
{"jsonrpc": "2.0", "method": "getmockmethod", "params": ["test"], "id": 1 }
""";
context.Request.Method = "POST";
context.Request.Body = new MemoryStream(Encoding.UTF8.GetBytes(body));
Expand All @@ -276,10 +286,39 @@ public async Task TestRegisterMethods()
// Parse the JSON response and check the result
var responseJson = JToken.Parse(output);
Assert.IsNotNull(responseJson["result"]);
Assert.AreEqual("mock", responseJson["result"].AsString());
Assert.AreEqual("mock test", responseJson["result"].AsString());
Assert.AreEqual(200, context.Response.StatusCode);
}

[TestMethod]
public void TestNullableParameter()
{
var method = typeof(MockRpcMethods).GetMethod("GetMockMethod");
var parameter = RpcServer.AsRpcParameter(method.GetParameters()[0]);
Assert.IsTrue(parameter.Required);
Assert.AreEqual(typeof(string), parameter.Type);
Assert.AreEqual("info", parameter.Name);

method = typeof(MockRpcMethods).GetMethod("NullableMethod");
parameter = RpcServer.AsRpcParameter(method.GetParameters()[0]);
Assert.IsFalse(parameter.Required);
Assert.AreEqual(typeof(string), parameter.Type);
Assert.AreEqual("info", parameter.Name);

method = typeof(MockRpcMethods).GetMethod("NullContextMethod");
parameter = RpcServer.AsRpcParameter(method.GetParameters()[0]);
Assert.IsFalse(parameter.Required);
Assert.AreEqual(typeof(string), parameter.Type);
Assert.AreEqual("info", parameter.Name);

method = typeof(MockRpcMethods).GetMethod("OptionalMethod");
parameter = RpcServer.AsRpcParameter(method.GetParameters()[0]);
Assert.IsFalse(parameter.Required);
Assert.AreEqual(typeof(string), parameter.Type);
Assert.AreEqual("info", parameter.Name);
Assert.AreEqual("default", parameter.DefaultValue);
}

[TestMethod]
public void TestRpcServerSettings_Load()
{
Expand Down
Loading