Skip to content

Commit

Permalink
Update to latest M.E.AI
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub committed Nov 22, 2024
1 parent ef8251c commit e49b7e2
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 213 deletions.
8 changes: 4 additions & 4 deletions dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@
<PackageVersion Include="System.Net.Http" Version="4.3.4" />
<PackageVersion Include="System.Numerics.Tensors" Version="8.0.0" />
<PackageVersion Include="System.Text.Json" Version="8.0.5" />
<PackageVersion Include="OllamaSharp" Version="4.0.6" />
<PackageVersion Include="OllamaSharp" Version="4.0.8" />
<!-- Tokenizers -->
<PackageVersion Include="Microsoft.ML.Tokenizers" Version="1.0.0" />
<PackageVersion Include="Microsoft.DeepDev.TokenizerLib" Version="1.3.3" />
<PackageVersion Include="SharpToken" Version="2.0.3" />
<!-- Microsoft.Extensions.* -->
<PackageVersion Include="Microsoft.Extensions.AI" Version="9.0.0-preview.9.24556.5" />
<PackageVersion Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24556.5" />
<PackageVersion Include="Microsoft.Extensions.AI.AzureAIInference" Version="9.0.0-preview.9.24556.5" />
<PackageVersion Include="Microsoft.Extensions.AI" Version="9.0.1-preview.1.24570.5" />
<PackageVersion Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.1-preview.1.24570.5" />
<PackageVersion Include="Microsoft.Extensions.AI.AzureAIInference" Version="9.0.1-preview.1.24570.5" />
<PackageVersion Include="Microsoft.Extensions.Configuration" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Configuration.Binder" Version="8.0.2" />
<PackageVersion Include="Microsoft.Extensions.Configuration.EnvironmentVariables" Version="8.0.0" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Http;

namespace Microsoft.SemanticKernel;

Expand Down Expand Up @@ -38,34 +38,26 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(
{
Verify.NotNull(services);

services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
return services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
{
var chatClientBuilder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);

var logger = serviceProvider.GetService<ILoggerFactory>()?.CreateLogger<ChatCompletionsClient>();
if (logger is not null)
{
chatClientBuilder.UseLogging(logger);
}

var options = new AzureAIInferenceClientOptions();

httpClient ??= serviceProvider.GetService<HttpClient>();
if (httpClient is not null)
{
options.Transport = new HttpClientTransport(HttpClientProvider.GetHttpClient(httpClient, serviceProvider));
options.Transport = new HttpClientTransport(httpClient);
}

return
chatClientBuilder.Use(
new Microsoft.Extensions.AI.AzureAIInferenceChatClient(
modelId: modelId,
chatCompletionsClient: new Azure.AI.Inference.ChatCompletionsClient(endpoint, new Azure.AzureKeyCredential(apiKey ?? SingleSpace), options)
)
).AsChatCompletionService();
});
var loggerFactory = serviceProvider.GetService<ILoggerFactory>() ?? NullLoggerFactory.Instance;

return services;
return new Azure.AI.Inference.ChatCompletionsClient(endpoint, new Azure.AzureKeyCredential(apiKey ?? SingleSpace), options)
.AsChatClient(modelId)
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes)
.UseLogging(loggerFactory)
.Build(serviceProvider)
.AsChatCompletionService(serviceProvider);
});
}

/// <summary>
Expand All @@ -88,34 +80,26 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(
{
Verify.NotNull(services);

services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
return services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
{
var chatClientBuilder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);

var logger = serviceProvider.GetService<ILoggerFactory>()?.CreateLogger<ChatCompletionsClient>();
if (logger is not null)
{
chatClientBuilder.UseLogging(logger);
}

var options = new AzureAIInferenceClientOptions();

httpClient ??= serviceProvider.GetService<HttpClient>();
if (httpClient is not null)
{
options.Transport = new HttpClientTransport(HttpClientProvider.GetHttpClient(httpClient, serviceProvider));
options.Transport = new HttpClientTransport(httpClient);
}

return
chatClientBuilder.Use(
new Microsoft.Extensions.AI.AzureAIInferenceChatClient(
modelId: modelId,
chatCompletionsClient: new Azure.AI.Inference.ChatCompletionsClient(endpoint, credential, options)
)
).AsChatCompletionService();
});
var loggerFactory = serviceProvider.GetService<ILoggerFactory>() ?? NullLoggerFactory.Instance;

return services;
return new Azure.AI.Inference.ChatCompletionsClient(endpoint, credential, options)
.AsChatClient(modelId)
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes)
.UseLogging(loggerFactory)
.Build(serviceProvider)
.AsChatCompletionService(serviceProvider);
});
}

/// <summary>
Expand All @@ -133,26 +117,18 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(this IService
{
Verify.NotNull(services);

services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
return services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
{
chatClient ??= serviceProvider.GetRequiredService<ChatCompletionsClient>();

var chatClientBuilder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);

var logger = serviceProvider.GetService<ILoggerFactory>()?.CreateLogger<ChatCompletionsClient>();
if (logger is not null)
{
chatClientBuilder.UseLogging(logger);
}

return chatClientBuilder
.Use(new Microsoft.Extensions.AI.AzureAIInferenceChatClient(chatClient, modelId))
.AsChatCompletionService();
var loggerFactory = serviceProvider.GetService<ILoggerFactory>() ?? NullLoggerFactory.Instance;
return chatClient
.AsChatClient(modelId)
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes)
.UseLogging(loggerFactory)
.Build(serviceProvider)
.AsChatCompletionService(serviceProvider);
});

return services;
}

/// <summary>
Expand All @@ -168,26 +144,17 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(this IService
{
Verify.NotNull(services);

services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
return services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
{
chatClient ??= serviceProvider.GetRequiredService<AzureAIInferenceChatClient>();

var chatClientBuilder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);

var logger = serviceProvider.GetService<ILoggerFactory>()?.CreateLogger<ChatCompletionsClient>();
if (logger is not null)
{
chatClientBuilder.UseLogging(logger);
}

return chatClientBuilder
.Use(chatClient)
.AsChatCompletionService();
var loggerFactory = serviceProvider.GetService<ILoggerFactory>() ?? NullLoggerFactory.Instance;
return chatClient
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes)
.UseLogging(loggerFactory)
.Build(serviceProvider)
.AsChatCompletionService(serviceProvider);
});

return services;
}

#region Private
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Azure.Core;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureAIInference.Core;

Expand Down Expand Up @@ -38,25 +39,16 @@ public AzureAIInferenceChatCompletionService(
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
{
var logger = loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService));
this._core = new(
modelId,
apiKey,
endpoint,
httpClient,
logger);

var builder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);

if (logger is not null)
{
builder = builder.UseLogging(logger);
}

this._chatService = builder
.Use(this._core.Client.AsChatClient(modelId))
loggerFactory ??= NullLoggerFactory.Instance;

this._core = new ChatClientCore(modelId, apiKey, endpoint, httpClient);

this._chatService = this._core.Client
.AsChatClient(modelId)
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes)
.UseLogging(loggerFactory)
.Build()
.AsChatCompletionService();
}

Expand All @@ -75,25 +67,16 @@ public AzureAIInferenceChatCompletionService(
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
{
var logger = loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService));
this._core = new(
modelId,
credential,
endpoint,
httpClient,
logger);

var builder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);

if (logger is not null)
{
builder = builder.UseLogging(logger);
}

this._chatService = builder
.Use(this._core.Client.AsChatClient(modelId))
loggerFactory ??= NullLoggerFactory.Instance;

this._core = new ChatClientCore(modelId, credential, endpoint, httpClient);

this._chatService = this._core.Client
.AsChatClient(modelId)
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes)
.UseLogging(loggerFactory)
.Build()
.AsChatCompletionService();
}

Expand All @@ -108,23 +91,18 @@ public AzureAIInferenceChatCompletionService(
ChatCompletionsClient chatClient,
ILoggerFactory? loggerFactory = null)
{
var logger = loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService));
this._core = new(
modelId,
chatClient,
logger);

var builder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);

if (logger is not null)
{
builder = builder.UseLogging(logger);
}

this._chatService = builder
.Use(this._core.Client.AsChatClient(modelId))
Verify.NotNull(chatClient);

loggerFactory ??= NullLoggerFactory.Instance;

this._core = new ChatClientCore(modelId, chatClient);

this._chatService = chatClient
.AsChatClient(modelId)
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes)
.UseLogging(loggerFactory)
.Build()
.AsChatCompletionService();
}

Expand Down
Loading

0 comments on commit e49b7e2

Please sign in to comment.