Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions dotnet/Directory.Build.targets
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
<Project>

<!-- Direct all packages under 'dotnet' to get versions from Directory.Packages.props -->
<!-- using Central Package Management feature -->
<!-- https://learn.microsoft.com/en-us/nuget/consume-packages/Central-Package-Management -->
<Sdk Name="Microsoft.Build.CentralPackageVersions" Version="2.1.3" />

<!-- Only run 'dotnet format' on dev machines, Release builds. Skip on GitHub Actions -->
<!-- as this runs in its own Actions job. -->
<Target Name="DotnetFormatOnBuild" BeforeTargets="Build"
Condition=" '$(Configuration)' == 'Release' AND '$(GITHUB_ACTIONS)' == '' ">
<Message Text="Running dotnet format" Importance="high" />
<Exec Command="dotnet format --no-restore -v diag $(ProjectFileName)" />
</Target>

<PropertyGroup>
<!-- IsAotCompatible implies IsTrimmable, but only with the .NET 8+ SDK. -->
<!-- Once we're only building with the .NET 8+, this can be removed. -->
<IsTrimmable Condition="'$(IsAotCompatible)' == 'true'">true</IsTrimmable>
</PropertyGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<!-- THIS PROPERTY GROUP MUST COME FIRST -->
<AssemblyName>Microsoft.SemanticKernel.Connectors.AI.HuggingFace</AssemblyName>
<RootNamespace>$(AssemblyName)</RootNamespace>
<TargetFramework>netstandard2.0</TargetFramework>
<IsAotCompatible>true</IsAotCompatible>
</PropertyGroup>

<!-- IMPORT NUGET PACKAGE SHARED PROPERTIES -->
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Text.Json.Serialization;
using Microsoft.SemanticKernel.Connectors.AI.HuggingFace.TextCompletion;
using Microsoft.SemanticKernel.Connectors.AI.HuggingFace.TextEmbedding;

namespace Microsoft.SemanticKernel.Text;

[JsonSerializable(typeof(List<TextCompletionResponse>))]
[JsonSerializable(typeof(TextCompletionRequest))]
[JsonSerializable(typeof(TextEmbeddingRequest))]
[JsonSerializable(typeof(TextEmbeddingResponse))]
internal sealed partial class SourceGenerationContext : JsonSerializerContext
{
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.TextCompletion;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Text;

namespace Microsoft.SemanticKernel.Connectors.AI.HuggingFace.TextCompletion;

Expand Down Expand Up @@ -161,7 +162,7 @@ private async Task<IReadOnlyList<ITextStreamingResult>> ExecuteGetCompletionsAsy
{
Method = HttpMethod.Post,
RequestUri = this.GetRequestUri(),
Content = new StringContent(JsonSerializer.Serialize(completionRequest)),
Content = new StringContent(JsonSerializer.Serialize(completionRequest, SourceGenerationContext.Default.TextCompletionRequest)),
};

httpRequestMessage.Headers.Add("User-Agent", HttpUserAgent);
Expand All @@ -175,7 +176,7 @@ private async Task<IReadOnlyList<ITextStreamingResult>> ExecuteGetCompletionsAsy

var body = await response.Content.ReadAsStringAsync().ConfigureAwait(false);

List<TextCompletionResponse>? completionResponse = JsonSerializer.Deserialize<List<TextCompletionResponse>>(body);
List<TextCompletionResponse>? completionResponse = JsonSerializer.Deserialize(body, SourceGenerationContext.Default.ListTextCompletionResponse);

if (completionResponse is null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Text;

namespace Microsoft.SemanticKernel.Connectors.AI.HuggingFace.TextEmbedding;

Expand Down Expand Up @@ -141,15 +142,15 @@ private async Task<IList<Embedding<float>>> ExecuteEmbeddingRequestAsync(IList<s
{
Method = HttpMethod.Post,
RequestUri = this.GetRequestUri(),
Content = new StringContent(JsonSerializer.Serialize(embeddingRequest)),
Content = new StringContent(JsonSerializer.Serialize(embeddingRequest, SourceGenerationContext.Default.TextEmbeddingRequest)),
};

httpRequestMessage.Headers.Add("User-Agent", HttpUserAgent);

var response = await this._httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);
var body = await response.Content.ReadAsStringAsync().ConfigureAwait(false);

var embeddingResponse = JsonSerializer.Deserialize<TextEmbeddingResponse>(body);
var embeddingResponse = JsonSerializer.Deserialize<TextEmbeddingResponse>(body, SourceGenerationContext.Default.TextEmbeddingResponse);

return embeddingResponse?.Embeddings?.Select(l => new Embedding<float>(l.Embedding!, transferOwnership: true)).ToList()!;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
<TargetFramework>netstandard2.0</TargetFramework>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<NoWarn>$(NoWarn);NU5104</NoWarn>
<IsAotCompatible>true</IsAotCompatible>
</PropertyGroup>

<!-- IMPORT NUGET PACKAGE SHARED PROPERTIES -->
Expand All @@ -28,16 +29,8 @@
</ItemGroup>

<ItemGroup>
<None Remove="Tokenizers\Settings\encoder.json" />
<Content Include="Tokenizers\Settings\encoder.json">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<PackageCopyToOutput>true</PackageCopyToOutput>
</Content>
<None Remove="Tokenizers\Settings\vocab.bpe" />
<Content Include="Tokenizers\Settings\vocab.bpe">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<PackageCopyToOutput>true</PackageCopyToOutput>
</Content>
<EmbeddedResource Include="Tokenizers\Settings\encoder.json" LogicalName="encoder.json" />
<EmbeddedResource Include="Tokenizers\Settings\vocab.bpe" LogicalName="vocab.bpe" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.ImageGeneration;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextEmbedding;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Text;

Expand Down Expand Up @@ -53,7 +53,7 @@ private protected async Task<IList<Embedding<float>>> ExecuteTextEmbeddingReques
string requestBody,
CancellationToken cancellationToken = default)
{
var result = await this.ExecutePostRequestAsync<TextEmbeddingResponse>(url, requestBody, cancellationToken).ConfigureAwait(false);
var result = await this.ExecutePostRequestAsync(url, requestBody, SourceGenerationContext.Default.TextEmbeddingResponse, cancellationToken).ConfigureAwait(false);
if (result.Embeddings is not { Count: >= 1 })
{
throw new AIException(
Expand All @@ -79,15 +79,15 @@ private protected async Task<IList<string>> ExecuteImageGenerationRequestAsync(
Func<ImageGenerationResponse.Image, string> extractResponseFunc,
CancellationToken cancellationToken = default)
{
var result = await this.ExecutePostRequestAsync<ImageGenerationResponse>(url, requestBody, cancellationToken).ConfigureAwait(false);
var result = await this.ExecutePostRequestAsync(url, requestBody, SourceGenerationContext.Default.ImageGenerationResponse, cancellationToken).ConfigureAwait(false);
return result.Images.Select(extractResponseFunc).ToList();
}

private protected virtual string? GetErrorMessageFromResponse(string jsonResponsePayload)
{
try
{
JsonNode? root = JsonSerializer.Deserialize<JsonNode>(jsonResponsePayload);
JsonNode? root = JsonSerializer.Deserialize(jsonResponsePayload, SourceGenerationContext.Default.JsonNode);

return root?["error"]?["message"]?.GetValue<string>();
}
Expand All @@ -114,14 +114,14 @@ private protected async Task<IList<string>> ExecuteImageGenerationRequestAsync(
/// </summary>
private readonly HttpClient _httpClient;

private protected async Task<T> ExecutePostRequestAsync<T>(string url, string requestBody, CancellationToken cancellationToken = default)
private protected async Task<T> ExecutePostRequestAsync<T>(string url, string requestBody, JsonTypeInfo<T> jsonTypeInfo, CancellationToken cancellationToken = default)
{
try
{
using var content = new StringContent(requestBody, Encoding.UTF8, "application/json");
using var response = await this.ExecuteRequestAsync(url, HttpMethod.Post, content, cancellationToken).ConfigureAwait(false);
string responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
T result = this.JsonDeserialize<T>(responseJson);
T result = this.JsonDeserialize<T>(responseJson, jsonTypeInfo);
return result;
}
catch (Exception e) when (e is not AIException)
Expand All @@ -132,9 +132,9 @@ private protected async Task<T> ExecutePostRequestAsync<T>(string url, string re
}
}

private protected T JsonDeserialize<T>(string responseJson)
private protected T JsonDeserialize<T>(string responseJson, JsonTypeInfo<T> jsonTypeInfo)
{
var result = Json.Deserialize<T>(responseJson);
var result = JsonSerializer.Deserialize(responseJson, jsonTypeInfo);
if (result is null)
{
throw new AIException(AIException.ErrorCodes.InvalidResponseContent, "Response JSON parse error");
Expand All @@ -161,86 +161,91 @@ private protected async Task<HttpResponseMessage> ExecuteRequestAsync(string url

this._log.LogTrace("HTTP response: {0} {1}", (int)response.StatusCode, response.StatusCode.ToString("G"));

if (response.IsSuccessStatusCode)
{
return response;
}

string responseJson = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
string? errorDetail = this.GetErrorMessageFromResponse(responseJson);
switch ((HttpStatusCodeType)response.StatusCode)

if (!response.IsSuccessStatusCode)
{
case HttpStatusCodeType.BadRequest:
case HttpStatusCodeType.MethodNotAllowed:
case HttpStatusCodeType.NotFound:
case HttpStatusCodeType.NotAcceptable:
case HttpStatusCodeType.Conflict:
case HttpStatusCodeType.Gone:
case HttpStatusCodeType.LengthRequired:
case HttpStatusCodeType.PreconditionFailed:
case HttpStatusCodeType.RequestEntityTooLarge:
case HttpStatusCodeType.RequestUriTooLong:
case HttpStatusCodeType.UnsupportedMediaType:
case HttpStatusCodeType.RequestedRangeNotSatisfiable:
case HttpStatusCodeType.ExpectationFailed:
case HttpStatusCodeType.HttpVersionNotSupported:
case HttpStatusCodeType.UpgradeRequired:
case HttpStatusCodeType.MisdirectedRequest:
case HttpStatusCodeType.UnprocessableEntity:
case HttpStatusCodeType.Locked:
case HttpStatusCodeType.FailedDependency:
case HttpStatusCodeType.PreconditionRequired:
case HttpStatusCodeType.RequestHeaderFieldsTooLarge:
throw new AIException(
AIException.ErrorCodes.InvalidRequest,
$"The request is not valid, HTTP status: {response.StatusCode:G}",
errorDetail);

case HttpStatusCodeType.Unauthorized:
case HttpStatusCodeType.Forbidden:
case HttpStatusCodeType.ProxyAuthenticationRequired:
case HttpStatusCodeType.UnavailableForLegalReasons:
case HttpStatusCodeType.NetworkAuthenticationRequired:
throw new AIException(
AIException.ErrorCodes.AccessDenied,
$"The request is not authorized, HTTP status: {response.StatusCode:G}",
errorDetail);

case HttpStatusCodeType.RequestTimeout:
throw new AIException(
AIException.ErrorCodes.RequestTimeout,
$"The request timed out, HTTP status: {response.StatusCode:G}");

case HttpStatusCodeType.TooManyRequests:
throw new AIException(
AIException.ErrorCodes.Throttling,
$"Too many requests, HTTP status: {response.StatusCode:G}",
errorDetail);

case HttpStatusCodeType.InternalServerError:
case HttpStatusCodeType.NotImplemented:
case HttpStatusCodeType.BadGateway:
case HttpStatusCodeType.ServiceUnavailable:
case HttpStatusCodeType.GatewayTimeout:
case HttpStatusCodeType.InsufficientStorage:
throw new AIException(
AIException.ErrorCodes.ServiceError,
$"The service failed to process the request, HTTP status: {response.StatusCode:G}",
errorDetail);

default:
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Unexpected HTTP response, status: {response.StatusCode:G}",
errorDetail);
switch ((HttpStatusCodeType)response.StatusCode)
{
case HttpStatusCodeType.BadRequest:
case HttpStatusCodeType.MethodNotAllowed:
case HttpStatusCodeType.NotFound:
case HttpStatusCodeType.NotAcceptable:
case HttpStatusCodeType.Conflict:
case HttpStatusCodeType.Gone:
case HttpStatusCodeType.LengthRequired:
case HttpStatusCodeType.PreconditionFailed:
case HttpStatusCodeType.RequestEntityTooLarge:
case HttpStatusCodeType.RequestUriTooLong:
case HttpStatusCodeType.UnsupportedMediaType:
case HttpStatusCodeType.RequestedRangeNotSatisfiable:
case HttpStatusCodeType.ExpectationFailed:
case HttpStatusCodeType.HttpVersionNotSupported:
case HttpStatusCodeType.UpgradeRequired:
case HttpStatusCodeType.MisdirectedRequest:
case HttpStatusCodeType.UnprocessableEntity:
case HttpStatusCodeType.Locked:
case HttpStatusCodeType.FailedDependency:
case HttpStatusCodeType.PreconditionRequired:
case HttpStatusCodeType.RequestHeaderFieldsTooLarge:
throw new AIException(
AIException.ErrorCodes.InvalidRequest,
$"The request is not valid, HTTP status: {response.StatusCode:G}",
errorDetail);

case HttpStatusCodeType.Unauthorized:
case HttpStatusCodeType.Forbidden:
case HttpStatusCodeType.ProxyAuthenticationRequired:
case HttpStatusCodeType.UnavailableForLegalReasons:
case HttpStatusCodeType.NetworkAuthenticationRequired:
throw new AIException(
AIException.ErrorCodes.AccessDenied,
$"The request is not authorized, HTTP status: {response.StatusCode:G}",
errorDetail);

case HttpStatusCodeType.RequestTimeout:
throw new AIException(
AIException.ErrorCodes.RequestTimeout,
$"The request timed out, HTTP status: {response.StatusCode:G}");

case HttpStatusCodeType.TooManyRequests:
throw new AIException(
AIException.ErrorCodes.Throttling,
$"Too many requests, HTTP status: {response.StatusCode:G}",
errorDetail);

case HttpStatusCodeType.InternalServerError:
case HttpStatusCodeType.NotImplemented:
case HttpStatusCodeType.BadGateway:
case HttpStatusCodeType.ServiceUnavailable:
case HttpStatusCodeType.GatewayTimeout:
case HttpStatusCodeType.InsufficientStorage:
throw new AIException(
AIException.ErrorCodes.ServiceError,
$"The service failed to process the request, HTTP status: {response.StatusCode:G}",
errorDetail);

default:
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Unexpected HTTP response, status: {response.StatusCode:G}",
errorDetail);
}
}

return response;
}
catch (Exception e) when (e is not AIException)
{
throw new AIException(
AIException.ErrorCodes.UnknownError,
$"Something went wrong: {e.Message}", e);
}
finally
{
response?.Dispose();
}
}

#endregion
Expand Down
Loading