Skip to content

Commit

Permalink
[C#] refactor: refactor embeddings to sync with JS sdk (#1227)
Browse files Browse the repository at this point in the history
## Linked issues

closes: #1210 #907 

## Details

Provide a list of your changes here. If you are fixing a bug, please
provide steps to reproduce the bug.

#### Change details

- Refactor embeddings to align with the JS SDK implementation (fixes #1210)
- Add integration tests for embeddings (#907)

## Attestation Checklist

- [x] My code follows the style guidelines of this project

- I have checked for/fixed spelling, linting, and other errors
- I have commented my code for clarity
- I have made corresponding changes to the documentation (updating the
doc strings in the code is sufficient)
- My changes generate no new warnings
- I have added tests that validate my changes and provide sufficient
test coverage. I have tested with:
  - Local testing
  - E2E testing in Teams
- New and existing unit tests pass locally with my changes

### Additional information

> Feel free to add other relevant information below
  • Loading branch information
kuojianlu authored Jan 31, 2024
1 parent 76d6b29 commit 1883aa9
Show file tree
Hide file tree
Showing 9 changed files with 280 additions and 92 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
using Azure;
using Microsoft.Bot.Builder;
using Moq;
using System.Reflection;
using Microsoft.Teams.AI.AI.Embeddings;
using Azure.AI.OpenAI;
using Microsoft.Teams.AI.Exceptions;
using Microsoft.Teams.AI.State;

#pragma warning disable CS8604 // Possible null reference argument.
namespace Microsoft.Teams.AI.Tests.AITests
Expand All @@ -20,9 +18,7 @@ public async void Test_OpenAI_CreateEmbeddings_ReturnEmbeddings()
var model = "randomModelId";

var options = new OpenAIEmbeddingsOptions(apiKey, model);
var turnContextMock = new Mock<ITurnContext>();
var turnStateMock = new Mock<TurnState>();
var openAiEmbeddings = new OpenAIEmbeddings<TurnState, OpenAIEmbeddingsOptions>(options);
var openAiEmbeddings = new OpenAIEmbeddings(options);

IList<string> inputs = new List<string> { "test" };
var clientMock = new Mock<OpenAIClient>(It.IsAny<string>());
Expand All @@ -34,7 +30,7 @@ public async void Test_OpenAI_CreateEmbeddings_ReturnEmbeddings()
Embeddings embeddingsResult = AzureOpenAIModelFactory.Embeddings(data, usage);
Response? response = null;
clientMock.Setup(client => client.GetEmbeddingsAsync(It.IsAny<EmbeddingsOptions>(), It.IsAny<CancellationToken>())).ReturnsAsync(Response.FromValue(embeddingsResult, response));
openAiEmbeddings.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAiEmbeddings, clientMock.Object);
openAiEmbeddings.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAiEmbeddings, clientMock.Object);

// Act
var result = await openAiEmbeddings.CreateEmbeddingsAsync(inputs);
Expand All @@ -54,9 +50,7 @@ public async void Test_AzureOpenAI_CreateEmbeddings_ReturnEmbeddings()
var endpoint = "https://test.cognitiveservices.azure.com";
var options = new AzureOpenAIEmbeddingsOptions(apiKey, model, endpoint);

var turnContextMock = new Mock<ITurnContext>();
var turnStateMock = new Mock<TurnState>();
var openAiEmbeddings = new OpenAIEmbeddings<TurnState, AzureOpenAIEmbeddingsOptions>(options);
var openAiEmbeddings = new OpenAIEmbeddings(options);

IList<string> inputs = new List<string> { "test" };
IEnumerable<EmbeddingItem> data = new List<EmbeddingItem>()
Expand All @@ -68,7 +62,7 @@ public async void Test_AzureOpenAI_CreateEmbeddings_ReturnEmbeddings()
Response? response = null;
var clientMock = new Mock<OpenAIClient>(It.IsAny<string>());
clientMock.Setup(client => client.GetEmbeddingsAsync(It.IsAny<EmbeddingsOptions>(), It.IsAny<CancellationToken>())).ReturnsAsync(Response.FromValue(embeddingsResult, response));
openAiEmbeddings.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAiEmbeddings, clientMock.Object);
openAiEmbeddings.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAiEmbeddings, clientMock.Object);

// Act
var result = await openAiEmbeddings.CreateEmbeddingsAsync(inputs);
Expand All @@ -89,15 +83,13 @@ public async void Test_CreateEmbeddings_ThrowRequestFailedException(int statusCo
var model = "randomModelId";

var options = new OpenAIEmbeddingsOptions(apiKey, model);
var turnContextMock = new Mock<ITurnContext>();
var turnStateMock = new Mock<TurnState>();
var openAiEmbeddings = new OpenAIEmbeddings<TurnState, OpenAIEmbeddingsOptions>(options);
var openAiEmbeddings = new OpenAIEmbeddings(options);

IList<string> inputs = new List<string> { "test" };
var exception = new RequestFailedException(statusCode, errorMsg);
var clientMock = new Mock<OpenAIClient>(It.IsAny<string>());
clientMock.Setup(client => client.GetEmbeddingsAsync(It.IsAny<EmbeddingsOptions>(), It.IsAny<CancellationToken>())).ThrowsAsync(exception);
openAiEmbeddings.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAiEmbeddings, clientMock.Object);
openAiEmbeddings.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAiEmbeddings, clientMock.Object);

// Act
var result = await openAiEmbeddings.CreateEmbeddingsAsync(inputs);
Expand All @@ -116,15 +108,13 @@ public async void Test_CreateEmbeddings_ThrowException()
var model = "randomModelId";

var options = new OpenAIEmbeddingsOptions(apiKey, model);
var turnContextMock = new Mock<ITurnContext>();
var turnStateMock = new Mock<TurnState>();
var openAiEmbeddings = new OpenAIEmbeddings<TurnState, OpenAIEmbeddingsOptions>(options);
var openAiEmbeddings = new OpenAIEmbeddings(options);

IList<string> inputs = new List<string> { "test" };
var exception = new InvalidOperationException("other exception");
var clientMock = new Mock<OpenAIClient>(It.IsAny<string>());
clientMock.Setup(client => client.GetEmbeddingsAsync(It.IsAny<EmbeddingsOptions>(), It.IsAny<CancellationToken>())).ThrowsAsync(exception);
openAiEmbeddings.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAiEmbeddings, clientMock.Object);
openAiEmbeddings.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAiEmbeddings, clientMock.Object);

// Act
var result = await Assert.ThrowsAsync<TeamsAIException>(async () => await openAiEmbeddings.CreateEmbeddingsAsync(inputs));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ internal sealed class AzureOpenAIConfiguration
public string ModelId { get; set; } = string.Empty;
public string ApiKey { get; set; } = string.Empty;
public string? ChatModelId { get; set; }
public string? EmbeddingModelId { get; set; }
public string Endpoint { get; set; } = string.Empty;

public AzureOpenAIConfiguration(string modelId, string? chatModelId, string apiKey, string endpoint)
public AzureOpenAIConfiguration(string modelId, string? chatModelId, string? embeddingModelId, string apiKey, string endpoint)
{
ModelId = modelId;
ChatModelId = chatModelId;
EmbeddingModelId = embeddingModelId;
ApiKey = apiKey;
Endpoint = endpoint;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ internal sealed class OpenAIConfiguration
public string ModelId { get; set; } = string.Empty;
public string ApiKey { get; set; } = string.Empty;
public string? ChatModelId { get; set; }
public string? EmbeddingModelId { get; set; }

public OpenAIConfiguration(string modelId, string? chatModelId, string apiKey)
public OpenAIConfiguration(string modelId, string? chatModelId, string? embeddingModelId, string apiKey)
{
ModelId = modelId;
ChatModelId = chatModelId;
EmbeddingModelId = embeddingModelId;
ApiKey = apiKey;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using Microsoft.Extensions.Configuration;
using Microsoft.Teams.AI.AI.Embeddings;
using Microsoft.Teams.AI.Tests.TestUtils;
using System.Reflection;
using Xunit.Abstractions;
using Microsoft.Extensions.Logging;

namespace Microsoft.Teams.AI.Tests.IntegrationTests
{
/// <summary>
/// Integration tests for <see cref="OpenAIEmbeddings"/> against the OpenAI and Azure OpenAI services.
/// </summary>
/// <remarks>
/// Settings are read from <c>IntegrationTests/testsettings.json</c>, environment variables and user secrets
/// (later sources override earlier ones). The tests are skipped by default and are meant to be run manually.
/// </remarks>
public sealed class OpenAIEmbeddingsTests
{
    private readonly IConfigurationRoot _configuration;
    private readonly RedirectOutput _output;
    private readonly ILoggerFactory _loggerFactory;

    /// <summary>
    /// Wires test output to the logger factory and loads the integration test configuration.
    /// </summary>
    /// <param name="output">xUnit output helper that receives redirected log output.</param>
    /// <exception cref="InvalidOperationException">The directory of the executing assembly cannot be determined.</exception>
    public OpenAIEmbeddingsTests(ITestOutputHelper output)
    {
        _output = new RedirectOutput(output);
        _loggerFactory = new TestLoggerFactory(_output);

        var currentAssemblyDirectory = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location);

        if (string.IsNullOrWhiteSpace(currentAssemblyDirectory))
        {
            throw new InvalidOperationException("Unable to determine current assembly directory.");
        }

        // The settings file lives in the project's IntegrationTests folder, three levels above the build output.
        var directoryPath = Path.GetFullPath(Path.Combine(currentAssemblyDirectory, "../../../IntegrationTests/"));
        var settingsPath = Path.Combine(directoryPath, "testsettings.json");

        // reloadOnChange is off: the configuration is read once per test instance, and a file
        // watcher created here would never be disposed.
        _configuration = new ConfigurationBuilder()
            .AddJsonFile(path: settingsPath, optional: false, reloadOnChange: false)
            .AddEnvironmentVariables()
            .AddUserSecrets<OpenAIEmbeddingsTests>()
            .Build();
    }

    /// <summary>
    /// Creates embeddings for two inputs with the OpenAI service and checks the status, count and vector size.
    /// </summary>
    // [Fact] rather than [Theory]: the test has no parameters and no data source.
    [Fact(Skip = "This test should only be run manually.")]
    public async Task Test_CreateEmbeddingsAsync_OpenAI()
    {
        // Arrange
        var config = _configuration.GetSection("OpenAI").Get<OpenAIConfiguration>()
            ?? throw new InvalidOperationException("The 'OpenAI' configuration section is missing.");
        var embeddingModelId = config.EmbeddingModelId
            ?? throw new InvalidOperationException("'OpenAI:EmbeddingModelId' is not configured.");
        var options = new OpenAIEmbeddingsOptions(config.ApiKey, embeddingModelId);
        var embeddings = new OpenAIEmbeddings(options, _loggerFactory);
        var inputs = new List<string>()
        {
            "test-input1",
            "test-input2"
        };
        var dimension = GetExpectedDimension(embeddingModelId);

        // Act
        var result = await embeddings.CreateEmbeddingsAsync(inputs);

        // Assert
        Assert.Equal(EmbeddingsResponseStatus.Success, result.Status);
        Assert.NotNull(result.Output);
        Assert.Equal(2, result.Output.Count);
        Assert.Equal(dimension, result.Output[0].Length);
        Assert.Equal(dimension, result.Output[1].Length);
    }

    /// <summary>
    /// Creates embeddings for two inputs with the Azure OpenAI service and checks the status, count and vector size.
    /// </summary>
    // [Fact] rather than [Theory]: the test has no parameters and no data source.
    [Fact(Skip = "This test should only be run manually.")]
    public async Task Test_CreateEmbeddingsAsync_AzureOpenAI()
    {
        // Arrange
        var config = _configuration.GetSection("AzureOpenAI").Get<AzureOpenAIConfiguration>()
            ?? throw new InvalidOperationException("The 'AzureOpenAI' configuration section is missing.");
        var embeddingModelId = config.EmbeddingModelId
            ?? throw new InvalidOperationException("'AzureOpenAI:EmbeddingModelId' is not configured.");
        var options = new AzureOpenAIEmbeddingsOptions(config.ApiKey, embeddingModelId, config.Endpoint);
        var embeddings = new OpenAIEmbeddings(options, _loggerFactory);
        var inputs = new List<string>()
        {
            "test-input1",
            "test-input2"
        };
        var dimension = GetExpectedDimension(embeddingModelId);

        // Act
        var result = await embeddings.CreateEmbeddingsAsync(inputs);

        // Assert
        Assert.Equal(EmbeddingsResponseStatus.Success, result.Status);
        Assert.NotNull(result.Output);
        Assert.Equal(2, result.Output.Count);
        Assert.Equal(dimension, result.Output[0].Length);
        Assert.Equal(dimension, result.Output[1].Length);
    }

    /// <summary>
    /// Returns the embedding vector length these tests expect for the given model id:
    /// 3072 for <c>text-embedding-3-large</c>, 1536 for any other id.
    /// </summary>
    /// <param name="embeddingModelId">Configured embedding model (or deployment) id.</param>
    /// <returns>The expected number of elements in each embedding vector.</returns>
    private static int GetExpectedDimension(string embeddingModelId)
    {
        return string.Equals(embeddingModelId, "text-embedding-3-large", StringComparison.Ordinal) ? 3072 : 1536;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,54 @@
namespace Microsoft.Teams.AI.AI.Embeddings
{
/// <summary>
/// Additional options needed to use the Azure OpenAI service.
/// Options for configuring an `OpenAIEmbeddings` to generate embeddings using an Azure OpenAI hosted model.
/// </summary>
/// <remarks>
/// The Azure OpenAI API version is set to latest by default.
/// </remarks>
public class AzureOpenAIEmbeddingsOptions : OpenAIEmbeddingsOptions
public class AzureOpenAIEmbeddingsOptions : BaseOpenAIEmbeddingsOptions
{
/// <summary>
/// Endpoint for your Azure OpenAI deployment.
/// API key to use when making requests to Azure OpenAI.
/// </summary>
public new string Endpoint { get; set; }
public string AzureApiKey { get; set; }

/// <summary>
/// Create an instance of the AzureOpenAIEmbeddingsOptions class.
/// Name of the Azure OpenAI deployment (model) to use.
/// </summary>
/// <param name="apiKey">OpenAI API key.</param>
/// <param name="model">The model to use for embeddings. This should be the model deployment name, not the model</param>
/// <param name="endpoint">Endpoint for your Azure OpenAI deployment.</param>
public AzureOpenAIEmbeddingsOptions(string apiKey, string model, string endpoint) : base(apiKey, model)
public string AzureDeployment { get; set; }

/// <summary>
/// Deployment endpoint to use.
/// </summary>
public string AzureEndpoint { get; set; }

/// <summary>
/// Optional. Version of the API being called.
/// </summary>
public string? AzureApiVersion { get; set; }

/// <summary>
/// Initializes a new instance of the <see cref="AzureOpenAIEmbeddingsOptions"/> class.
/// </summary>
/// <param name="azureApiKey">API key to use when making requests to Azure OpenAI.</param>
/// <param name="azureDeployment">Name of the Azure OpenAI deployment (model) to use.</param>
/// <param name="azureEndpoint">Deployment endpoint to use.</param>
public AzureOpenAIEmbeddingsOptions(
string azureApiKey,
string azureDeployment,
string azureEndpoint) : base()
{
Verify.ParamNotNull(endpoint);
Verify.ParamNotNull(azureApiKey);
Verify.ParamNotNull(azureDeployment);
Verify.ParamNotNull(azureEndpoint);

azureEndpoint = azureEndpoint.Trim();
if (!azureEndpoint.StartsWith("https://"))
{
throw new ArgumentException($"Model created with an invalid endpoint of `{azureEndpoint}`. The endpoint must be a valid HTTPS url.");
}

Endpoint = endpoint;
this.AzureApiKey = azureApiKey;
this.AzureDeployment = azureDeployment;
this.AzureEndpoint = azureEndpoint;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
namespace Microsoft.Teams.AI.AI.Embeddings
{
/// <summary>
/// Settings shared by the OpenAI and Azure OpenAI embeddings option types.
/// </summary>
public class BaseOpenAIEmbeddingsOptions
{
    /// <summary>
    /// Optional. When enabled, requests are logged to the console.
    /// </summary>
    /// <remarks>
    /// Handy when debugging prompts. The property itself is <c>null</c> until assigned;
    /// an unset value is documented to behave as `false`.
    /// </remarks>
    public bool? LogRequests
    {
        get;
        set;
    }

    /// <summary>
    /// Optional. Delays to wait between retries when calling the OpenAI API.
    /// </summary>
    /// <remarks>
    /// The property itself is <c>null</c> until assigned. The documented default policy is
    /// `{ TimeSpan.FromMilliseconds(2000), TimeSpan.FromMilliseconds(5000) }`, i.e. a first retry
    /// after 2 seconds and a second retry after 5 seconds.
    /// </remarks>
    public List<TimeSpan>? RetryPolicy
    {
        get;
        set;
    }
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
using Microsoft.Teams.AI.State;

namespace Microsoft.Teams.AI.AI.Embeddings
namespace Microsoft.Teams.AI.AI.Embeddings
{
/// <summary>
/// Interface for Embeddings.
/// An AI model that can be used to create embeddings.
/// </summary>
public interface IEmbeddings<TState> where TState : TurnState
public interface IEmbeddingsModel
{
/// <summary>
/// Creates embeddings for the given inputs using the OpenAI API.
/// Creates embeddings for the given inputs.
/// </summary>
/// <param name="inputs">Text inputs to create embeddings for.</param>
/// <param name="cancellationToken">A cancellation token that can be used by other objects
Expand Down
Loading

0 comments on commit 1883aa9

Please sign in to comment.