Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Google Gemini - Adding response schema (Structured Outputs support) #10135

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ dotnet_diagnostic.IDE0005.severity = warning # Remove unnecessary using directiv
dotnet_diagnostic.IDE0009.severity = warning # Add this or Me qualification
dotnet_diagnostic.IDE0011.severity = warning # Add braces
dotnet_diagnostic.IDE0018.severity = warning # Inline variable declaration

dotnet_diagnostic.IDE0032.severity = warning # Use auto-implemented property
dotnet_diagnostic.IDE0034.severity = warning # Simplify 'default' expression
dotnet_diagnostic.IDE0035.severity = warning # Remove unreachable code
Expand Down Expand Up @@ -221,20 +222,29 @@ dotnet_diagnostic.RCS1241.severity = none # Implement IComparable when implement
dotnet_diagnostic.IDE0001.severity = none # Simplify name
dotnet_diagnostic.IDE0002.severity = none # Simplify member access
dotnet_diagnostic.IDE0004.severity = none # Remove unnecessary cast
dotnet_diagnostic.IDE0010.severity = none # Populate switch
dotnet_diagnostic.IDE0021.severity = none # Use block body for constructors
dotnet_diagnostic.IDE0022.severity = none # Use block body for methods
dotnet_diagnostic.IDE0024.severity = none # Use block body for operator
dotnet_diagnostic.IDE0035.severity = none # Remove unreachable code
dotnet_diagnostic.IDE0051.severity = none # Remove unused private member
dotnet_diagnostic.IDE0052.severity = none # Remove unread private member
dotnet_diagnostic.IDE0058.severity = none # Remove unused expression value
dotnet_diagnostic.IDE0059.severity = none # Unnecessary assignment of a value
dotnet_diagnostic.IDE0060.severity = none # Remove unused parameter
dotnet_diagnostic.IDE0061.severity = none # Use block body for local function
dotnet_diagnostic.IDE0079.severity = none # Remove unnecessary suppression.
dotnet_diagnostic.IDE0080.severity = none # Remove unnecessary suppression operator.
dotnet_diagnostic.IDE0100.severity = none # Remove unnecessary equality operator
dotnet_diagnostic.IDE0110.severity = none # Remove unnecessary discards
dotnet_diagnostic.IDE0130.severity = none # Namespace does not match folder structure
dotnet_diagnostic.IDE0290.severity = none # Use primary constructor
dotnet_diagnostic.IDE0032.severity = none # Use auto property
dotnet_diagnostic.IDE0160.severity = none # Use block-scoped namespace
dotnet_diagnostic.IDE1006.severity = warning # Naming rule violations
dotnet_diagnostic.IDE0046.severity = suggestion # If statement can be simplified
dotnet_diagnostic.IDE0056.severity = suggestion # Indexing can be simplified
dotnet_diagnostic.IDE0057.severity = suggestion # Substring can be simplified

###############################
# Naming Conventions #
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.ChatCompletion;
Expand Down Expand Up @@ -419,13 +420,34 @@ public async Task ItCreatesPostRequestWithSemanticKernelVersionHeaderAsync()
Assert.Equal(expectedVersion, header);
}

/// <summary>
/// Verifies that a <see cref="Type"/> assigned to ResponseSchema is emitted in the request
/// payload as a generated JSON schema, next to the configured response MIME type.
/// </summary>
[Fact]
public async Task ItCreatesPostRequestWithResponseSchemaPropertyAsync()
{
    // Arrange
    var client = this.CreateChatCompletionClient();
    var chatHistory = CreateSampleChatHistory();
    var settings = new GeminiPromptExecutionSettings { ResponseMimeType = "application/json", ResponseSchema = typeof(List<int>) };

    // Act
    await client.GenerateChatMessageAsync(chatHistory, settings);

    // Assert
    // The payload is what this test inspects, so guard on the captured content rather than on the headers.
    var requestContent = this._messageHandlerStub.RequestContent;
    Assert.NotNull(requestContent);

    // This is the body of the outgoing request, not a response.
    var requestBody = Encoding.UTF8.GetString(requestContent!);

    // The full-fragment check already implies that the "responseSchema" key is present.
    Assert.Contains("\"responseSchema\":{\"type\":\"array\",\"items\":{\"type\":\"integer\"}}", requestBody, StringComparison.Ordinal);
    Assert.Contains("\"responseMimeType\":\"application/json\"", requestBody, StringComparison.Ordinal);
}

[Fact]
public async Task ItCanUseValueTasksSequentiallyForBearerTokenAsync()
{
// Arrange
var bearerTokenGenerator = new BearerTokenGenerator()
{
BearerKeys = new List<string> { "key1", "key2", "key3" }
BearerKeys = ["key1", "key2", "key3"]
};

var responseContent = File.ReadAllText(ChatTestDataFilePath);
Expand All @@ -442,7 +464,7 @@ public async Task ItCanUseValueTasksSequentiallyForBearerTokenAsync()
httpClient: httpClient,
modelId: "fake-model",
apiVersion: VertexAIVersion.V1,
bearerTokenProvider: () => bearerTokenGenerator.GetBearerToken(),
bearerTokenProvider: bearerTokenGenerator.GetBearerToken,
location: "fake-location",
projectId: "fake-project-id");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
Expand All @@ -25,7 +26,8 @@ public void FromPromptItReturnsWithConfiguration()
MaxTokens = 10,
TopP = 0.9,
AudioTimestamp = true,
ResponseMimeType = "application/json"
ResponseMimeType = "application/json",
ResponseSchema = JsonSerializer.Deserialize<JsonElement>(@"{""schema"":""schema""}")
};

// Act
Expand All @@ -37,9 +39,58 @@ public void FromPromptItReturnsWithConfiguration()
Assert.Equal(executionSettings.MaxTokens, request.Configuration.MaxOutputTokens);
Assert.Equal(executionSettings.AudioTimestamp, request.Configuration.AudioTimestamp);
Assert.Equal(executionSettings.ResponseMimeType, request.Configuration.ResponseMimeType);
Assert.Equal(executionSettings.ResponseSchema, request.Configuration.ResponseSchema);
Assert.Equal(executionSettings.TopP, request.Configuration.TopP);
}

/// <summary>
/// Checks that a response schema supplied as a <see cref="JsonElement"/> is carried over
/// to the request configuration together with the response MIME type.
/// </summary>
[Fact]
public void JsonElementResponseSchemaFromPromptReturnsAsExpected()
{
    // Arrange
    var schema = JsonSerializer.Deserialize<JsonElement>(@"{""schema"":""schema""}");
    GeminiPromptExecutionSettings settings = new()
    {
        ResponseSchema = schema,
        ResponseMimeType = "application/json",
    };

    // Act
    var result = GeminiRequest.FromPromptAndExecutionSettings("prompt-example", settings);

    // Assert
    Assert.NotNull(result.Configuration);
    Assert.Equal(settings.ResponseMimeType, result.Configuration.ResponseMimeType);
    Assert.Equal(settings.ResponseSchema, result.Configuration.ResponseSchema);
}

/// <summary>
/// Checks that a response schema supplied as a <see cref="Type"/> produces a schema whose
/// top-level "type" property matches the expected JSON type name.
/// </summary>
[Theory]
[InlineData(typeof(int), "integer")]
[InlineData(typeof(bool), "boolean")]
[InlineData(typeof(string), "string")]
[InlineData(typeof(double), "number")]
[InlineData(typeof(GeminiRequest), "object")]
[InlineData(typeof(List<int>), "array")]
public void TypeResponseSchemaFromPromptReturnsAsExpected(Type type, string expectedSchemaType)
{
    // Arrange
    GeminiPromptExecutionSettings settings = new()
    {
        ResponseSchema = type,
        ResponseMimeType = "application/json",
    };

    // Act
    var result = GeminiRequest.FromPromptAndExecutionSettings("prompt-example", settings);

    // Assert
    Assert.NotNull(result.Configuration);

    JsonElement? generatedSchema = result.Configuration.ResponseSchema;
    var actualSchemaType = generatedSchema?.GetProperty("type").GetString();

    Assert.Equal(expectedSchemaType, actualSchemaType);
    Assert.Equal(settings.ResponseMimeType, result.Configuration.ResponseMimeType);
}

[Fact]
public void FromPromptItReturnsWithSafetySettings()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public void ItCreatesGeminiExecutionSettingsWithCorrectDefaults()
Assert.Null(executionSettings.SafetySettings);
Assert.Null(executionSettings.AudioTimestamp);
Assert.Null(executionSettings.ResponseMimeType);
Assert.Null(executionSettings.ResponseSchema);
Assert.Equal(GeminiPromptExecutionSettings.DefaultTextMaxTokens, executionSettings.MaxTokens);
}

Expand Down Expand Up @@ -70,7 +71,8 @@ public void ItCreatesGeminiExecutionSettingsFromExtensionDataSnakeCase()
{ "max_tokens", 1000 },
{ "temperature", 0 },
{ "audio_timestamp", true },
{ "response_mimetype", "application/json" }
{ "response_mimetype", "application/json" },
{ "response_schema", JsonSerializer.Serialize(new { }) }
}
};

Expand All @@ -81,6 +83,9 @@ public void ItCreatesGeminiExecutionSettingsFromExtensionDataSnakeCase()
Assert.NotNull(executionSettings);
Assert.Equal(1000, executionSettings.MaxTokens);
Assert.Equal(0, executionSettings.Temperature);
Assert.Equal("application/json", executionSettings.ResponseMimeType);
Assert.NotNull(executionSettings.ResponseSchema);
Assert.Equal(typeof(JsonElement), executionSettings.ResponseSchema.GetType());
Assert.True(executionSettings.AudioTimestamp);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,27 @@

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.SemanticKernel.Connectors.Google.Core;

internal sealed class GeminiRequest
{
// Serializer options passed to schema generation; null until GetDefaultOptions() creates them on first use.
private static JsonSerializerOptions? s_options;
// Schema-generation options that CreateSchema falls back to when the caller passes no configuration.
private static readonly AIJsonSchemaCreateOptions s_schemaOptions = new()
{
IncludeSchemaKeyword = false,
IncludeTypeInEnumSchemas = true,
RequireAllProperties = false,
DisallowAdditionalProperties = false,
};

[JsonPropertyName("contents")]
public IList<GeminiContent> Contents { get; set; } = null!;

Expand Down Expand Up @@ -249,10 +261,56 @@ private static void AddConfiguration(GeminiPromptExecutionSettings executionSett
StopSequences = executionSettings.StopSequences,
CandidateCount = executionSettings.CandidateCount,
AudioTimestamp = executionSettings.AudioTimestamp,
ResponseMimeType = executionSettings.ResponseMimeType
ResponseMimeType = executionSettings.ResponseMimeType,
ResponseSchema = GetResponseSchemaConfig(executionSettings.ResponseSchema)
};
}

/// <summary>
/// Converts the user-supplied response schema setting into the <see cref="JsonElement"/> sent to the service.
/// </summary>
/// <param name="responseSchemaSettings">
/// Null, a ready-made <see cref="JsonElement"/>, a <see cref="Type"/>, or any other object.
/// </param>
/// <returns>
/// Null for a null setting; the element itself for a <see cref="JsonElement"/>; otherwise a schema
/// created for the given <see cref="Type"/> or, for any other object, for its runtime type.
/// </returns>
private static JsonElement? GetResponseSchemaConfig(object? responseSchemaSettings)
{
    // Arm order mirrors the precedence of the checks: null, pre-built element, explicit type, arbitrary instance.
    return responseSchemaSettings switch
    {
        null => null,
        JsonElement readySchema => readySchema,
        Type schemaType => CreateSchema(schemaType, GetDefaultOptions()),
        object instance => CreateSchema(instance.GetType(), GetDefaultOptions()),
    };
}

/// <summary>
/// Creates a JSON schema for <paramref name="type"/> by delegating to AIJsonUtilities.CreateJsonSchema.
/// When <paramref name="configuration"/> is null, the shared s_schemaOptions instance is used instead.
/// </summary>
private static JsonElement CreateSchema(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code in this method and the one below is very similar or identical to that from KernelJsonSchemaBuilder. I wonder whether the builder can be used instead?

Type type,
JsonSerializerOptions options,
string? description = null,
AIJsonSchemaCreateOptions? configuration = null)
{
configuration ??= s_schemaOptions;
return AIJsonUtilities.CreateJsonSchema(type, description, serializerOptions: options, inferenceOptions: configuration);
}

/// <summary>
/// Returns the serializer options used for schema generation, creating and caching them on first use.
/// </summary>
/// <returns>
/// A read-only <see cref="JsonSerializerOptions"/> configured with a <see cref="DefaultJsonTypeInfoResolver"/>
/// and a <see cref="JsonStringEnumConverter"/>.
/// </returns>
[RequiresUnreferencedCode("Uses JsonStringEnumConverter and DefaultJsonTypeInfoResolver classes, making it incompatible with AOT scenarios.")]
[RequiresDynamicCode("Uses JsonStringEnumConverter and DefaultJsonTypeInfoResolver classes, making it incompatible with AOT scenarios.")]
private static JsonSerializerOptions GetDefaultOptions()
{
    if (s_options is { } cached)
    {
        return cached;
    }

    var created = new JsonSerializerOptions
    {
        TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
    };
    created.Converters.Add(new JsonStringEnumConverter());

    // Lock the instance before it is stored in the static field.
    created.MakeReadOnly();

    s_options = created;
    return created;
}

private static void AddSafetySettings(GeminiPromptExecutionSettings executionSettings, GeminiRequest request)
{
request.SafetySettings = executionSettings.SafetySettings?.Select(s
Expand Down Expand Up @@ -292,5 +350,9 @@ internal sealed class ConfigurationElement
/// <summary>
/// Response MIME type, written to the request JSON as <c>responseMimeType</c> and omitted when null.
/// </summary>
[JsonPropertyName("responseMimeType")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? ResponseMimeType { get; set; }

/// <summary>
/// Response schema, written to the request JSON as <c>responseSchema</c> and omitted when null.
/// </summary>
[JsonPropertyName("responseSchema")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public JsonElement? ResponseSchema { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public sealed class GeminiPromptExecutionSettings : PromptExecutionSettings
private IList<string>? _stopSequences;
private bool? _audioTimestamp;
private string? _responseMimeType;
private object? _responseSchema;
private IList<GeminiSafetySetting>? _safetySettings;
private GeminiToolCallBehavior? _toolCallBehavior;

Expand Down Expand Up @@ -206,6 +207,29 @@ public string? ResponseMimeType
}
}

/// <summary>
/// Optional. Output schema of the generated candidate text. Schemas must be a subset of the OpenAPI schema and can be objects, primitives or arrays.
/// If set, a compatible <see cref="ResponseMimeType"/> must also be set. Compatible MIME types: application/json: Schema for JSON response.
/// Refer to https://ai.google.dev/gemini-api/docs/json-mode for more information.
/// </summary>
/// <remarks>
/// Possible values are:
/// <para>- <see cref="System.Text.Json.JsonElement"/> object, which is used as the schema as-is;</para>
/// <para>- <see cref="Type"/> object, which will be used to automatically create a JSON schema;</para>
/// <para>- any other <see cref="object"/>, whose runtime type will be used to automatically create a JSON schema.</para>
/// </remarks>
[JsonPropertyName("response_schema")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public object? ResponseSchema
{
get => this._responseSchema;

set
{
this.ThrowIfFrozen();
this._responseSchema = value;
}
}

/// <inheritdoc />
public override void Freeze()
{
Expand Down Expand Up @@ -243,7 +267,8 @@ public override PromptExecutionSettings Clone()
SafetySettings = this.SafetySettings?.Select(setting => new GeminiSafetySetting(setting)).ToList(),
ToolCallBehavior = this.ToolCallBehavior?.Clone(),
AudioTimestamp = this.AudioTimestamp,
ResponseMimeType = this.ResponseMimeType
ResponseMimeType = this.ResponseMimeType,
ResponseSchema = this.ResponseSchema,
};
}

Expand Down
Loading