Skip to content

Commit

Permalink
.Net: Add shared integration tests for checking vector search scores (#…
Browse files Browse the repository at this point in the history
…10144)

### Motivation and Context

Adds tests for verifying the scores as mentioned in #10103
Only adding tests for passing connectors right now.

### Description

Adding a common base test class with integration tests for checking the
returned vector scores.
Adding subclasses for InMemory, Qdrant and Postgres, since these are
passing end to end.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
westey-m authored Jan 10, 2025
1 parent 53995c8 commit e8b31a2
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.VectorData;
using Xunit;

namespace SemanticKernel.IntegrationTests.Connectors.Memory;

/// <summary>
/// Base class for common integration tests that should pass for any <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>.
/// </summary>
/// <typeparam name="TKey">The type of key to use with the record collection.</typeparam>
public abstract class BaseVectorStoreRecordCollectionTests<TKey>
where TKey : notnull
{
protected abstract TKey Key1 { get; }
protected abstract TKey Key2 { get; }
protected abstract TKey Key3 { get; }
protected abstract TKey Key4 { get; }

protected abstract HashSet<string> GetSupportedDistanceFunctions();

protected abstract IVectorStoreRecordCollection<TKey, TRecord> GetTargetRecordCollection<TRecord>(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition);

protected virtual int DelayAfterIndexCreateInMilliseconds { get; } = 0;

protected virtual int DelayAfterUploadInMilliseconds { get; } = 0;

[Theory]
[InlineData(DistanceFunction.CosineDistance, 0, 2, 1, new int[] { 0, 2, 1 })]
[InlineData(DistanceFunction.CosineSimilarity, 1, -1, 0, new int[] { 0, 2, 1 })]
[InlineData(DistanceFunction.DotProductSimilarity, 1, -1, 0, new int[] { 0, 2, 1 })]
[InlineData(DistanceFunction.EuclideanDistance, 0, 2, 1.73, new int[] { 0, 2, 1 })]
[InlineData(DistanceFunction.EuclideanSquaredDistance, 0, 4, 3, new int[] { 0, 2, 1 })]
[InlineData(DistanceFunction.Hamming, 0, 1, 3, new int[] { 0, 1, 2 })]
[InlineData(DistanceFunction.ManhattanDistance, 0, 2, 3, new int[] { 0, 1, 2 })]
public async Task VectorSearchShouldReturnExpectedScoresAsync(string distanceFunction, double expectedExactMatchScore, double expectedOppositeScore, double expectedOrthogonalScore, int[] resultOrder)
{
var keyDictionary = new Dictionary<int, TKey>
{
{ 0, this.Key1 },
{ 1, this.Key2 },
{ 2, this.Key3 },
};
var scoreDictionary = new Dictionary<int, double>
{
{ 0, expectedExactMatchScore },
{ 1, expectedOppositeScore },
{ 2, expectedOrthogonalScore },
};

// Don't test unsupported distance functions.
var supportedDistanceFunctions = this.GetSupportedDistanceFunctions();
if (!supportedDistanceFunctions.Contains(distanceFunction))
{
return;
}

// Arrange
var definition = CreateKeyWithVectorRecordDefinition(4, distanceFunction);
var sut = this.GetTargetRecordCollection<KeyWithVectorRecord<TKey>>(
$"scorebydistancefunction{distanceFunction}",
definition);

await sut.CreateCollectionIfNotExistsAsync();
await Task.Delay(this.DelayAfterIndexCreateInMilliseconds);

// Create two vectors that are opposite to each other and records that use these
// plus a further vector that is orthogonal to the base vector.
var baseVector = new ReadOnlyMemory<float>([1, 0, 0, 0]);
var oppositeVector = new ReadOnlyMemory<float>([-1, 0, 0, 0]);
var orthogonalVector = new ReadOnlyMemory<float>([0f, -1f, -1f, 0f]);

var baseRecord = new KeyWithVectorRecord<TKey>
{
Key = this.Key1,
Vector = baseVector,
};

var oppositeRecord = new KeyWithVectorRecord<TKey>
{
Key = this.Key2,
Vector = oppositeVector,
};

var orthogonalRecord = new KeyWithVectorRecord<TKey>
{
Key = this.Key3,
Vector = orthogonalVector,
};

await sut.UpsertBatchAsync([baseRecord, oppositeRecord, orthogonalRecord]).ToListAsync();
await Task.Delay(this.DelayAfterUploadInMilliseconds);

// Act
var searchResult = await sut.VectorizedSearchAsync(baseVector);

// Assert
var results = await searchResult.Results.ToListAsync();
Assert.Equal(3, results.Count);

Assert.Equal(keyDictionary[resultOrder[0]], results[0].Record.Key);
Assert.Equal(Math.Round(scoreDictionary[resultOrder[0]], 2), Math.Round(results[0].Score!.Value, 2));

Assert.Equal(keyDictionary[resultOrder[1]], results[1].Record.Key);
Assert.Equal(Math.Round(scoreDictionary[resultOrder[1]], 2), Math.Round(results[1].Score!.Value, 2));

Assert.Equal(keyDictionary[resultOrder[2]], results[2].Record.Key);
Assert.Equal(Math.Round(scoreDictionary[resultOrder[2]], 2), Math.Round(results[2].Score!.Value, 2));

// Cleanup
await sut.DeleteCollectionAsync();
}

private static VectorStoreRecordDefinition CreateKeyWithVectorRecordDefinition(int vectorDimensions, string distanceFunction)
{
var definition = new VectorStoreRecordDefinition
{
Properties =
[
new VectorStoreRecordKeyProperty("Key", typeof(TKey)),
new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory<float>)) { Dimensions = vectorDimensions, DistanceFunction = distanceFunction },
],
};

return definition;
}

private sealed class KeyWithVectorRecord<TRecordKey>
{
public required TRecordKey Key { get; set; }

public ReadOnlyMemory<float> Vector { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel.Connectors.InMemory;

namespace SemanticKernel.IntegrationTests.Connectors.Memory.InMemory;

/// <summary>
/// Inherits common integration tests that should pass for any <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>.
/// </summary>
public class CommonInMemoryVectorStoreRecordCollectionTests() : BaseVectorStoreRecordCollectionTests<string>
{
protected override string Key1 => "1";
protected override string Key2 => "2";
protected override string Key3 => "3";
protected override string Key4 => "4";

protected override IVectorStoreRecordCollection<string, TRecord> GetTargetRecordCollection<TRecord>(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition)
{
return new InMemoryVectorStoreRecordCollection<string, TRecord>(recordCollectionName, new()
{
VectorStoreRecordDefinition = vectorStoreRecordDefinition
});
}

protected override HashSet<string> GetSupportedDistanceFunctions()
{
return [DistanceFunction.CosineDistance, DistanceFunction.CosineSimilarity, DistanceFunction.DotProductSimilarity, DistanceFunction.EuclideanDistance];
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel.Connectors.Postgres;
using Xunit;

namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres;

/// <summary>
/// Inherits common integration tests that should pass for any <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>.
/// </summary>
/// <param name="fixture">Postres setup and teardown.</param>
[Collection("PostgresVectorStoreCollection")]
public class CommonPostgresVectorStoreRecordCollectionTests(PostgresVectorStoreFixture fixture) : BaseVectorStoreRecordCollectionTests<string>
{
protected override string Key1 => "1";
protected override string Key2 => "2";
protected override string Key3 => "3";
protected override string Key4 => "4";

protected override IVectorStoreRecordCollection<string, TRecord> GetTargetRecordCollection<TRecord>(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition)
{
return new PostgresVectorStoreRecordCollection<string, TRecord>(fixture.DataSource!, recordCollectionName, new()
{
VectorStoreRecordDefinition = vectorStoreRecordDefinition
});
}

protected override HashSet<string> GetSupportedDistanceFunctions()
{
return [DistanceFunction.CosineDistance, DistanceFunction.CosineSimilarity, DistanceFunction.DotProductSimilarity, DistanceFunction.EuclideanDistance, DistanceFunction.ManhattanDistance];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,22 @@ public PostgresVectorStoreFixture()
/// <summary>
/// Holds the Npgsql data source to use for tests.
/// </summary>
private NpgsqlDataSource? _dataSource;
public NpgsqlDataSource? DataSource { get; private set; }

private string _connectionString = null!;
private string _databaseName = null!;

/// <summary>
/// Gets a vector store to use for tests.
/// </summary>
public IVectorStore VectorStore => new PostgresVectorStore(this._dataSource!);
public IVectorStore VectorStore => new PostgresVectorStore(this.DataSource!);

/// <summary>
/// Get a database connection
/// </summary>
public NpgsqlConnection GetConnection()
{
return this._dataSource!.OpenConnection();
return this.DataSource!.OpenConnection();
}

public IVectorStoreRecordCollection<TKey, TRecord> GetCollection<TKey, TRecord>(
Expand Down Expand Up @@ -81,7 +81,7 @@ public async Task InitializeAsync()
NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionStringBuilder.ToString());
dataSourceBuilder.UseVector();

this._dataSource = dataSourceBuilder.Build();
this.DataSource = dataSourceBuilder.Build();

// Wait for the postgres container to be ready and create the test database using the initial data source.
var initialDataSource = NpgsqlDataSource.Create(this._connectionString);
Expand Down Expand Up @@ -124,7 +124,7 @@ public async Task InitializeAsync()

private async Task CreateTableAsync()
{
NpgsqlConnection connection = await this._dataSource!.OpenConnectionAsync().ConfigureAwait(false);
NpgsqlConnection connection = await this.DataSource!.OpenConnectionAsync().ConfigureAwait(false);

await using (connection)
{
Expand All @@ -150,9 +150,9 @@ DescriptionEmbedding VECTOR(4) NOT NULL,
/// <returns>An async task.</returns>
public async Task DisposeAsync()
{
if (this._dataSource != null)
if (this.DataSource != null)
{
this._dataSource.Dispose();
this.DataSource.Dispose();
}

await this.DropDatabaseAsync();
Expand Down Expand Up @@ -218,7 +218,7 @@ private async Task CreateDatabaseAsync(NpgsqlDataSource initialDataSource)
await command.ExecuteNonQueryAsync();
}

await using (NpgsqlConnection conn = await this._dataSource!.OpenConnectionAsync())
await using (NpgsqlConnection conn = await this.DataSource!.OpenConnectionAsync())
{
await using (NpgsqlCommand command = new("CREATE EXTENSION vector", conn))
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel.Connectors.Qdrant;
using Xunit;

namespace SemanticKernel.IntegrationTests.Connectors.Memory.Qdrant;

/// <summary>
/// Inherits common integration tests that should pass for any <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>.
/// </summary>
/// <param name="fixture">Qdrant setup and teardown.</param>
[Collection("QdrantVectorStoreCollection")]
public class CommonQdrantVectorStoreRecordCollectionTests(QdrantVectorStoreFixture fixture) : BaseVectorStoreRecordCollectionTests<ulong>
{
protected override ulong Key1 => 1;
protected override ulong Key2 => 2;
protected override ulong Key3 => 3;
protected override ulong Key4 => 4;

protected override IVectorStoreRecordCollection<ulong, TRecord> GetTargetRecordCollection<TRecord>(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition)
{
return new QdrantVectorStoreRecordCollection<TRecord>(fixture.QdrantClient, recordCollectionName, new()
{
HasNamedVectors = true,
VectorStoreRecordDefinition = vectorStoreRecordDefinition
});
}

protected override HashSet<string> GetSupportedDistanceFunctions()
{
return [DistanceFunction.CosineSimilarity, DistanceFunction.EuclideanDistance, DistanceFunction.ManhattanDistance];
}
}

0 comments on commit e8b31a2

Please sign in to comment.