-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
.Net: Add shared integration tests for checking vector search scores (#10144)
### Motivation and Context Adds tests for verifying the scores as mentioned in #10103 Only adding tests for passing connectors right now. ### Description Adding a common base test class with integration tests for checking the returned vector scores. Adding subclasses for InMemory, Qdrant and Postgres, since these are passing end to end. ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄
- Loading branch information
Showing
5 changed files
with
246 additions
and
8 deletions.
There are no files selected for viewing
138 changes: 138 additions & 0 deletions
138
dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreRecordCollectionTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Threading.Tasks; | ||
using Microsoft.Extensions.VectorData; | ||
using Xunit; | ||
|
||
namespace SemanticKernel.IntegrationTests.Connectors.Memory; | ||
|
||
/// <summary>
/// Base class for common integration tests that should pass for any <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>.
/// </summary>
/// <typeparam name="TKey">The type of key to use with the record collection.</typeparam>
public abstract class BaseVectorStoreRecordCollectionTests<TKey>
    where TKey : notnull
{
    /// <summary>Gets the key of the record that stores the base vector (the vector that is also used as the search query).</summary>
    protected abstract TKey Key1 { get; }

    /// <summary>Gets the key of the record that stores the vector opposite to the base vector.</summary>
    protected abstract TKey Key2 { get; }

    /// <summary>Gets the key of the record that stores the vector orthogonal to the base vector.</summary>
    protected abstract TKey Key3 { get; }

    /// <summary>Gets a fourth key. It is not used by any test in this class.</summary>
    protected abstract TKey Key4 { get; }

    /// <summary>
    /// Gets the distance function names that the connector under test supports.
    /// Theory cases for any other distance function return early without asserting anything.
    /// </summary>
    /// <returns>The set of supported distance function names.</returns>
    protected abstract HashSet<string> GetSupportedDistanceFunctions();

    /// <summary>
    /// Creates the record collection instance that the tests run against.
    /// </summary>
    /// <typeparam name="TRecord">The record type stored in the collection.</typeparam>
    /// <param name="recordCollectionName">The name of the collection.</param>
    /// <param name="vectorStoreRecordDefinition">The record definition to use for the collection, if any.</param>
    /// <returns>The collection under test.</returns>
    protected abstract IVectorStoreRecordCollection<TKey, TRecord> GetTargetRecordCollection<TRecord>(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition);

    /// <summary>Gets the time to wait, in milliseconds, after the collection has been created. Defaults to 0.</summary>
    protected virtual int DelayAfterIndexCreateInMilliseconds { get; } = 0;

    /// <summary>Gets the time to wait, in milliseconds, after the records have been upserted. Defaults to 0.</summary>
    protected virtual int DelayAfterUploadInMilliseconds { get; } = 0;

    /// <summary>
    /// Upserts three records (base, opposite and orthogonal vectors), searches with the base vector and
    /// verifies both the order of the returned records and their scores, compared at two decimal places.
    /// </summary>
    /// <param name="distanceFunction">The distance function to configure on the vector property.</param>
    /// <param name="expectedExactMatchScore">The expected score of the record holding the base vector.</param>
    /// <param name="expectedOppositeScore">The expected score of the record holding the opposite vector.</param>
    /// <param name="expectedOrthogonalScore">The expected score of the record holding the orthogonal vector.</param>
    /// <param name="resultOrder">
    /// For each result position, the index of the record expected there:
    /// 0 = base record, 1 = opposite record, 2 = orthogonal record.
    /// </param>
    [Theory]
    [InlineData(DistanceFunction.CosineDistance, 0, 2, 1, new int[] { 0, 2, 1 })]
    [InlineData(DistanceFunction.CosineSimilarity, 1, -1, 0, new int[] { 0, 2, 1 })]
    [InlineData(DistanceFunction.DotProductSimilarity, 1, -1, 0, new int[] { 0, 2, 1 })]
    [InlineData(DistanceFunction.EuclideanDistance, 0, 2, 1.73, new int[] { 0, 2, 1 })]
    [InlineData(DistanceFunction.EuclideanSquaredDistance, 0, 4, 3, new int[] { 0, 2, 1 })]
    [InlineData(DistanceFunction.Hamming, 0, 1, 3, new int[] { 0, 1, 2 })]
    [InlineData(DistanceFunction.ManhattanDistance, 0, 2, 3, new int[] { 0, 1, 2 })]
    public async Task VectorSearchShouldReturnExpectedScoresAsync(string distanceFunction, double expectedExactMatchScore, double expectedOppositeScore, double expectedOrthogonalScore, int[] resultOrder)
    {
        // Number of decimal places at which expected and actual scores are compared.
        const int ScorePrecision = 2;

        // Don't test unsupported distance functions.
        if (!this.GetSupportedDistanceFunctions().Contains(distanceFunction))
        {
            return;
        }

        // Arrange
        // Both arrays are indexed by record index: 0 = base, 1 = opposite, 2 = orthogonal.
        TKey[] keysByRecordIndex = [this.Key1, this.Key2, this.Key3];
        double[] scoresByRecordIndex = [expectedExactMatchScore, expectedOppositeScore, expectedOrthogonalScore];

        var definition = CreateKeyWithVectorRecordDefinition(4, distanceFunction);
        var sut = this.GetTargetRecordCollection<KeyWithVectorRecord<TKey>>(
            $"scorebydistancefunction{distanceFunction}",
            definition);

        await sut.CreateCollectionIfNotExistsAsync();

        try
        {
            await Task.Delay(this.DelayAfterIndexCreateInMilliseconds);

            // Create two vectors that are opposite to each other and records that use these
            // plus a further vector that is orthogonal to the base vector.
            var baseVector = new ReadOnlyMemory<float>([1, 0, 0, 0]);
            var oppositeVector = new ReadOnlyMemory<float>([-1, 0, 0, 0]);
            var orthogonalVector = new ReadOnlyMemory<float>([0f, -1f, -1f, 0f]);

            var baseRecord = new KeyWithVectorRecord<TKey>
            {
                Key = this.Key1,
                Vector = baseVector,
            };

            var oppositeRecord = new KeyWithVectorRecord<TKey>
            {
                Key = this.Key2,
                Vector = oppositeVector,
            };

            var orthogonalRecord = new KeyWithVectorRecord<TKey>
            {
                Key = this.Key3,
                Vector = orthogonalVector,
            };

            await sut.UpsertBatchAsync([baseRecord, oppositeRecord, orthogonalRecord]).ToListAsync();
            await Task.Delay(this.DelayAfterUploadInMilliseconds);

            // Act
            var searchResult = await sut.VectorizedSearchAsync(baseVector);

            // Assert
            var results = await searchResult.Results.ToListAsync();
            Assert.Equal(3, results.Count);

            for (var position = 0; position < results.Count; position++)
            {
                var expectedRecordIndex = resultOrder[position];

                Assert.Equal(keysByRecordIndex[expectedRecordIndex], results[position].Record.Key);
                Assert.Equal(scoresByRecordIndex[expectedRecordIndex], results[position].Score!.Value, ScorePrecision);
            }
        }
        finally
        {
            // Cleanup: always delete the collection, even when an assertion fails, so that a failed
            // run does not leave a stale collection behind for CreateCollectionIfNotExistsAsync to reuse.
            await sut.DeleteCollectionAsync();
        }
    }

    /// <summary>
    /// Builds a record definition consisting of a key property named "Key" and a vector property named "Vector".
    /// </summary>
    /// <param name="vectorDimensions">The number of dimensions to configure on the vector property.</param>
    /// <param name="distanceFunction">The distance function to configure on the vector property.</param>
    /// <returns>The record definition.</returns>
    private static VectorStoreRecordDefinition CreateKeyWithVectorRecordDefinition(int vectorDimensions, string distanceFunction)
    {
        var definition = new VectorStoreRecordDefinition
        {
            Properties =
            [
                new VectorStoreRecordKeyProperty("Key", typeof(TKey)),
                new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory<float>)) { Dimensions = vectorDimensions, DistanceFunction = distanceFunction },
            ],
        };

        return definition;
    }

    /// <summary>
    /// Minimal record model that holds only a key and a vector.
    /// </summary>
    /// <typeparam name="TRecordKey">The type of the record key.</typeparam>
    private sealed class KeyWithVectorRecord<TRecordKey>
    {
        /// <summary>Gets or sets the record key.</summary>
        public required TRecordKey Key { get; set; }

        /// <summary>Gets or sets the record vector.</summary>
        public ReadOnlyMemory<float> Vector { get; set; }
    }
}
31 changes: 31 additions & 0 deletions
31
...grationTests/Connectors/Memory/InMemory/CommonInMemoryVectorStoreRecordCollectionTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
using System.Collections.Generic; | ||
using Microsoft.Extensions.VectorData; | ||
using Microsoft.SemanticKernel.Connectors.InMemory; | ||
|
||
namespace SemanticKernel.IntegrationTests.Connectors.Memory.InMemory; | ||
|
||
/// <summary>
/// Runs the common integration tests inherited from <see cref="BaseVectorStoreRecordCollectionTests{TKey}"/>
/// against <see cref="InMemoryVectorStoreRecordCollection{TKey, TRecord}"/> using string keys.
/// </summary>
public class CommonInMemoryVectorStoreRecordCollectionTests : BaseVectorStoreRecordCollectionTests<string>
{
    /// <inheritdoc/>
    protected override string Key1 => "1";

    /// <inheritdoc/>
    protected override string Key2 => "2";

    /// <inheritdoc/>
    protected override string Key3 => "3";

    /// <inheritdoc/>
    protected override string Key4 => "4";

    /// <inheritdoc/>
    protected override IVectorStoreRecordCollection<string, TRecord> GetTargetRecordCollection<TRecord>(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition)
        => new InMemoryVectorStoreRecordCollection<string, TRecord>(
            recordCollectionName,
            new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition });

    /// <inheritdoc/>
    protected override HashSet<string> GetSupportedDistanceFunctions()
        =>
        [
            DistanceFunction.CosineDistance,
            DistanceFunction.CosineSimilarity,
            DistanceFunction.DotProductSimilarity,
            DistanceFunction.EuclideanDistance,
        ];
}
34 changes: 34 additions & 0 deletions
34
...grationTests/Connectors/Memory/Postgres/CommonPostgresVectorStoreRecordCollectionTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
using System.Collections.Generic; | ||
using Microsoft.Extensions.VectorData; | ||
using Microsoft.SemanticKernel.Connectors.Postgres; | ||
using Xunit; | ||
|
||
namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; | ||
|
||
/// <summary>
/// Runs the common integration tests inherited from <see cref="BaseVectorStoreRecordCollectionTests{TKey}"/>
/// against <see cref="PostgresVectorStoreRecordCollection{TKey, TRecord}"/> using string keys.
/// </summary>
/// <param name="fixture">Postgres setup and teardown.</param>
[Collection("PostgresVectorStoreCollection")]
public class CommonPostgresVectorStoreRecordCollectionTests(PostgresVectorStoreFixture fixture) : BaseVectorStoreRecordCollectionTests<string>
{
    /// <inheritdoc/>
    protected override string Key1 => "1";

    /// <inheritdoc/>
    protected override string Key2 => "2";

    /// <inheritdoc/>
    protected override string Key3 => "3";

    /// <inheritdoc/>
    protected override string Key4 => "4";

    /// <inheritdoc/>
    protected override IVectorStoreRecordCollection<string, TRecord> GetTargetRecordCollection<TRecord>(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition)
        => new PostgresVectorStoreRecordCollection<string, TRecord>(
            fixture.DataSource!,
            recordCollectionName,
            new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition });

    /// <inheritdoc/>
    protected override HashSet<string> GetSupportedDistanceFunctions()
        =>
        [
            DistanceFunction.CosineDistance,
            DistanceFunction.CosineSimilarity,
            DistanceFunction.DotProductSimilarity,
            DistanceFunction.EuclideanDistance,
            DistanceFunction.ManhattanDistance,
        ];
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 35 additions & 0 deletions
35
...IntegrationTests/Connectors/Memory/Qdrant/CommonQdrantVectorStoreRecordCollectionTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
using System.Collections.Generic; | ||
using Microsoft.Extensions.VectorData; | ||
using Microsoft.SemanticKernel.Connectors.Qdrant; | ||
using Xunit; | ||
|
||
namespace SemanticKernel.IntegrationTests.Connectors.Memory.Qdrant; | ||
|
||
/// <summary>
/// Runs the common integration tests inherited from <see cref="BaseVectorStoreRecordCollectionTests{TKey}"/>
/// against <see cref="QdrantVectorStoreRecordCollection{TRecord}"/> using <see cref="ulong"/> keys.
/// </summary>
/// <param name="fixture">Qdrant setup and teardown.</param>
[Collection("QdrantVectorStoreCollection")]
public class CommonQdrantVectorStoreRecordCollectionTests(QdrantVectorStoreFixture fixture) : BaseVectorStoreRecordCollectionTests<ulong>
{
    /// <inheritdoc/>
    protected override ulong Key1 => 1;

    /// <inheritdoc/>
    protected override ulong Key2 => 2;

    /// <inheritdoc/>
    protected override ulong Key3 => 3;

    /// <inheritdoc/>
    protected override ulong Key4 => 4;

    /// <inheritdoc/>
    protected override IVectorStoreRecordCollection<ulong, TRecord> GetTargetRecordCollection<TRecord>(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition)
        => new QdrantVectorStoreRecordCollection<TRecord>(
            fixture.QdrantClient,
            recordCollectionName,
            new()
            {
                // The collection is created with named vectors enabled.
                HasNamedVectors = true,
                VectorStoreRecordDefinition = vectorStoreRecordDefinition
            });

    /// <inheritdoc/>
    protected override HashSet<string> GetSupportedDistanceFunctions()
        =>
        [
            DistanceFunction.CosineSimilarity,
            DistanceFunction.EuclideanDistance,
            DistanceFunction.ManhattanDistance,
        ];
}