Skip to content

Commit 0fcef08

Browse files
committed
feat: Parameterized sql queries
1 parent 7c4d339 commit 0fcef08

File tree

6 files changed

+233
-66
lines changed

6 files changed

+233
-66
lines changed

Extensions/SqlServer/Cosmos.DataTransfer.SqlServerExtension.UnitTests/Cosmos.DataTransfer.SqlServerExtension.UnitTests.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
<PackageReference Include="MSTest.TestAdapter" />
1919
<PackageReference Include="MSTest.TestFramework" />
2020
<PackageReference Include="System.Linq.Async" />
21+
<PackageReference Include="Moq" />
2122
<PackageReference Include="coverlet.collector" />
2223
<PackageReference Include="coverlet.msbuild">
2324
<PrivateAssets>all</PrivateAssets>
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
using Microsoft.Data.Sqlite;
22
using Microsoft.Extensions.Logging.Abstractions;
3+
using System.Data.Common;
34
using Cosmos.DataTransfer.Interfaces;
45
using Cosmos.DataTransfer.Common;
56
using Cosmos.DataTransfer.Common.UnitTests;
7+
using Moq;
8+
using Microsoft.Extensions.Configuration;
69

710
namespace Cosmos.DataTransfer.SqlServerExtension.UnitTests;
811

912
[TestClass]
1013
public class SqlServerDataSourceExtensionTests
1114
{
1215

13-
private static async Task<Func<string,ValueTask<System.Data.Common.DbConnection>>> connectionFactory(CancellationToken cancellationToken = default(CancellationToken)) {
14-
var connection = new SqliteConnection("");
16+
private static async Task<Tuple<SqliteFactory,DbConnection>> connectionFactory(CancellationToken cancellationToken = default(CancellationToken)) {
17+
var provider = SqliteFactory.Instance;
18+
var connection = provider.CreateConnection();
1519
await connection.OpenAsync(cancellationToken);
1620

1721
var cmd = connection.CreateCommand();
@@ -27,25 +31,20 @@ name TEXT
2731
VALUES (2, NULL);";
2832
await cmd.ExecuteNonQueryAsync(cancellationToken);
2933

30-
var func = (string connectionString) => {
31-
return new ValueTask<System.Data.Common.DbConnection>(connection);
32-
};
33-
34-
return func;
34+
return Tuple.Create(provider, connection);
3535
}
3636

3737
[TestMethod]
38-
public async Task TestReadAsync_QueryText() {
38+
public async Task TestReadAsync() {
39+
var config = new Mock<IConfiguration>();
40+
var cancellationToken = new CancellationTokenSource(500);
41+
var (providerFactory, connection) = await connectionFactory(cancellationToken.Token);
42+
3943
var extension = new SqlServerDataSourceExtension();
40-
var config = TestHelpers.CreateConfig(new Dictionary<string, string> {
41-
{ "ConnectionString", "Sqlite" },
42-
{ "QueryText", "SELECT * FROM foobar" }
43-
});
4444
Assert.AreEqual("SqlServer", extension.DisplayName);
45-
46-
var cancellationToken = new CancellationTokenSource(500);
4745

48-
var result = await extension.ReadAsync(config, NullLogger.Instance, await connectionFactory(cancellationToken.Token), cancellationToken.Token).ToListAsync();
46+
var result = await extension.ReadAsync(config.Object, NullLogger.Instance,
47+
"SELECT * FROM foobar", Array.Empty<DbParameter>(), connection, providerFactory, cancellationToken.Token).ToListAsync();
4948
var expected = new List<DictionaryDataItem> {
5049
new DictionaryDataItem(new Dictionary<string, object?> { { "id", (long)1 }, { "name", "zoo" } }),
5150
new DictionaryDataItem(new Dictionary<string, object?> { { "id", (long)2 }, { "name", null } })
@@ -54,24 +53,25 @@ public async Task TestReadAsync_QueryText() {
5453
}
5554

5655
[TestMethod]
57-
public async Task TestReadAsync_FromFile() {
58-
var outputFile = Path.GetTempFileName();
59-
await File.WriteAllTextAsync(outputFile, "SELECT * FROM foobar;");
56+
public async Task TestReadAsyncWithParameters() {
57+
var config = new Mock<IConfiguration>();
58+
var cancellationToken = new CancellationTokenSource();
59+
var (providerFactory, connection) = await connectionFactory(cancellationToken.Token);
60+
6061
var extension = new SqlServerDataSourceExtension();
61-
var config = TestHelpers.CreateConfig(new Dictionary<string, string> {
62-
{ "ConnectionString", "Sqlite" },
63-
{ "FilePath", outputFile }
64-
});
62+
Assert.AreEqual("SqlServer", extension.DisplayName);
6563

66-
67-
var cancellationToken = new CancellationTokenSource(500);
64+
var parameter = providerFactory.CreateParameter();
65+
parameter.ParameterName = "@x";
66+
parameter.DbType = System.Data.DbType.Int32;
67+
parameter.Value = 2;
6868

69-
var result = await extension.ReadAsync(config, NullLogger.Instance, await connectionFactory(cancellationToken.Token), cancellationToken.Token).ToListAsync();
70-
var expected = new List<DictionaryDataItem> {
71-
new DictionaryDataItem(new Dictionary<string, object?> { { "id", (long)1 }, { "name", "zoo" } }),
72-
new DictionaryDataItem(new Dictionary<string, object?> { { "id", (long)2 }, { "name", null } })
73-
};
74-
CollectionAssert.That.AreEqual(expected, result, new DataItemComparer());
69+
var result = await extension.ReadAsync(config.Object, NullLogger.Instance,
70+
"SELECT * FROM foobar WHERE id = @x",
71+
new DbParameter[] { parameter }, connection, providerFactory, cancellationToken.Token).FirstAsync();
72+
Assert.That.AreEqual(result,
73+
new DictionaryDataItem(new Dictionary<string, object?> { { "id", (long)2 }, { "name", null } }),
74+
new DataItemComparer());
7575
}
7676

7777
// Allows for testing against an actual SQL Server by specifying a
@@ -80,10 +80,11 @@ public async Task TestReadAsync_FromFile() {
8080
// <?xml version="1.0" encoding="utf-8"?>
8181
// <RunSettings>
8282
// <TestRunParameters>
83-
// <Parameter name="TestReadAsync_LiveSqlServer_ConnectionString" value="<Your connection string>" />
83+
// <Parameter name="TestReadAsync_LiveSqlServer_ConnectionString" value="Your connection string" />
8484
// </TestRunParameters>
8585
// </RunSettings>
86-
// run test with dotnet test --settings sql.runsettings
86+
// run test with
87+
// dotnet test --settings sql.runsettings
8788
#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
8889
public TestContext TestContext { get; set; }
8990
#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
@@ -111,4 +112,4 @@ public async Task TestReadAsync_LiveSqlServer() {
111112
{ "zoo", null }
112113
})));
113114
}
114-
}
115+
}

Extensions/SqlServer/Cosmos.DataTransfer.SqlServerExtension.UnitTests/SqlServerSourceSettingsTests.cs

+78-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
using Cosmos.DataTransfer.Interfaces;
22
using System.ComponentModel.DataAnnotations;
3+
using Cosmos.DataTransfer.Common.UnitTests;
4+
using Microsoft.Data.Sqlite;
5+
using System.Data;
36

47
namespace Cosmos.DataTransfer.SqlServerExtension.UnitTests;
58

@@ -31,9 +34,10 @@ public void TestSourceSettings_ValidationFails1()
3134
[TestMethod]
3235
public void TestSourceSettings_ValidationFails2()
3336
{
37+
var fn = Path.GetTempFileName();
3438
var settings = new SqlServerSourceSettings {
3539
QueryText = "SELECT 1;",
36-
FilePath = "dmt-query.sql"
40+
FilePath = fn
3741
};
3842

3943
var validationResults = settings.Validate(new ValidationContext(settings)).ToList();
@@ -51,6 +55,22 @@ public void TestSourceSettings_ValidationFails2()
5155
Assert.ThrowsException<AggregateException>(() => settings.Validate());
5256
}
5357

58+
[TestMethod]
59+
public void TestSourceSettings_Validation_FileNotFound()
60+
{
61+
var fn = Path.GetTempFileName();
62+
var settings = new SqlServerSourceSettings {
63+
ConnectionString = "Connection, please",
64+
FilePath = "dmt.sql"
65+
};
66+
67+
var validationResults = settings.Validate(new ValidationContext(settings)).ToList();
68+
Assert.AreEqual(1, validationResults.Count);
69+
Assert.IsTrue(validationResults[0].ErrorMessage!.StartsWith("Could not read `FilePath`. Reason:"));
70+
CollectionAssert.AreEqual(new string[] { "FilePath" }, validationResults[0].MemberNames.ToArray());
71+
}
72+
73+
5474
[TestMethod]
5575
[DataRow("SELECT 1", null)]
5676
[DataRow("SELECT 1", " ")]
@@ -60,12 +80,67 @@ public void TestSourceSettings_ValidationSuccess(string queryText, string filePa
6080
var settings = new SqlServerSourceSettings {
6181
ConnectionString = "Server",
6282
QueryText = queryText,
63-
FilePath = filePath
83+
FilePath = filePath == "filename" ? Path.GetTempFileName() : filePath
6484
};
6585

6686
var validationResults = settings.Validate(new ValidationContext(settings));
6787
Assert.AreEqual(0, validationResults.Count());
68-
6988
settings.Validate();
7089
}
90+
91+
[TestMethod]
92+
public void TestSourceSettings_GetQueryText1() {
93+
var settings = new SqlServerSourceSettings() {
94+
QueryText = "SELECT 1"
95+
};
96+
Assert.AreEqual("SELECT 1", settings.GetQueryText());
97+
98+
var fn = Path.GetTempFileName();
99+
settings.FilePath = fn; // But this shouldn't occur, as the settings are invalid.
100+
File.WriteAllText(fn, "More SQL");
101+
Assert.AreEqual("More SQL", settings.GetQueryText());
102+
103+
settings.QueryText = "";
104+
Assert.AreEqual("More SQL", settings.GetQueryText());
105+
}
106+
107+
[TestMethod]
108+
public void TestSourceSettings_Parameters() {
109+
var settings = new SqlServerSourceSettings();
110+
111+
Assert.AreEqual(0, settings.GetDbParameters(SqliteFactory.Instance).Count());
112+
settings.Parameters = new Dictionary<string, object> {
113+
{ "str", "str" },
114+
{ "bool", true },
115+
{ "int", 100 },
116+
{ "long", 100L },
117+
{ "double", 3.14d },
118+
{ "float", 2.718f },
119+
{ "datetime", DateTime.UtcNow }
120+
};
121+
122+
var parameters = settings.GetDbParameters(SqliteFactory.Instance);
123+
int i = -1;
124+
Assert.AreEqual("str", parameters[++i].ParameterName);
125+
Assert.AreEqual("str", parameters[i].Value);
126+
Assert.AreEqual(DbType.String, parameters[i].DbType);
127+
Assert.AreEqual("bool", parameters[++i].ParameterName);
128+
Assert.AreEqual(true, parameters[i].Value);
129+
Assert.AreEqual(DbType.Boolean, parameters[i].DbType);
130+
Assert.AreEqual("int", parameters[++i].ParameterName);
131+
Assert.AreEqual(100, parameters[i].Value);
132+
Assert.AreEqual(DbType.Int32, parameters[i].DbType);
133+
Assert.AreEqual("long", parameters[++i].ParameterName);
134+
Assert.AreEqual(100L, parameters[i].Value);
135+
Assert.AreEqual(DbType.Int64, parameters[i].DbType);
136+
Assert.AreEqual("double", parameters[++i].ParameterName);
137+
Assert.AreEqual(3.14, parameters[i].Value);
138+
Assert.AreEqual(DbType.Double, parameters[i].DbType);
139+
Assert.AreEqual("float", parameters[++i].ParameterName);
140+
Assert.AreEqual(2.718f, parameters[i].Value);
141+
Assert.AreEqual(DbType.Single, parameters[i].DbType);
142+
Assert.AreEqual("datetime", parameters[++i].ParameterName);
143+
//Assert.AreEqual(2.718f, parameters[i].Value);
144+
Assert.AreEqual(DbType.String, parameters[i].DbType);
145+
}
71146
}

Extensions/SqlServer/Cosmos.DataTransfer.SqlServerExtension/SqlServerDataSourceExtension.cs

+36-26
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using System.ComponentModel.Composition;
2-
using System.IO;
32
using System.Runtime.CompilerServices;
43
using Cosmos.DataTransfer.Interfaces;
4+
using System.Data.Common;
55
using Microsoft.Data.SqlClient;
66
using Microsoft.Extensions.Configuration;
77
using Microsoft.Extensions.Logging;
@@ -15,45 +15,55 @@ public class SqlServerDataSourceExtension : IDataSourceExtensionWithSettings
1515

1616
public async IAsyncEnumerable<IDataItem> ReadAsync(IConfiguration config, ILogger logger, [EnumeratorCancellation] CancellationToken cancellationToken = default)
1717
{
18-
await foreach (var item in this.ReadAsync(config, logger, (string connectionString) => new ValueTask<System.Data.Common.DbConnection>(new SqlConnection(connectionString)), cancellationToken)) {
18+
var settings = config.Get<SqlServerSourceSettings>();
19+
settings.Validate();
20+
21+
var providerFactory = SqlClientFactory.Instance;
22+
var connection = providerFactory.CreateConnection()!;
23+
connection.ConnectionString = settings!.ConnectionString;
24+
25+
var iterable = this.ReadAsync(config, logger, settings.GetQueryText(),
26+
settings.GetDbParameters(providerFactory), connection,
27+
providerFactory, cancellationToken);
28+
29+
await foreach (var item in iterable) {
1930
yield return item;
2031
}
2132
}
2233

2334
public async IAsyncEnumerable<IDataItem> ReadAsync(
2435
IConfiguration config,
2536
ILogger logger,
26-
Func<string,ValueTask<System.Data.Common.DbConnection>> connectionFactory,
37+
string queryText,
38+
DbParameter[] parameters,
39+
DbConnection connection,
40+
DbProviderFactory dbProviderFactory,
2741
[EnumeratorCancellation] CancellationToken cancellationToken = default)
2842
{
29-
var settings = config.Get<SqlServerSourceSettings>();
30-
settings.Validate();
43+
try {
44+
await connection.OpenAsync(cancellationToken);
45+
var command = connection.CreateCommand();
46+
command.CommandText = queryText;
47+
command.Parameters.AddRange(parameters);
3148

32-
string queryText = settings!.QueryText!;
33-
if (settings.FilePath != null) {
34-
queryText = File.ReadAllText(settings.FilePath);
35-
}
36-
37-
await using var connection = connectionFactory(settings.ConnectionString!).Result;
38-
await connection.OpenAsync(cancellationToken);
39-
var command = connection.CreateCommand();
40-
command.CommandText = queryText;
41-
//await using SqlCommand command = new SqlCommand(queryText, connection);
42-
await using var reader = await command.ExecuteReaderAsync(cancellationToken);
43-
while (await reader.ReadAsync(cancellationToken))
44-
{
45-
var columns = await reader.GetColumnSchemaAsync(cancellationToken);
46-
Dictionary<string, object?> fields = new Dictionary<string, object?>();
47-
foreach (var column in columns)
49+
await using var reader = await command.ExecuteReaderAsync(cancellationToken);
50+
while (await reader.ReadAsync(cancellationToken))
4851
{
49-
var value = column.ColumnOrdinal.HasValue ? reader[column.ColumnOrdinal.Value] : reader[column.ColumnName];
50-
if (value == DBNull.Value)
52+
var columns = await reader.GetColumnSchemaAsync(cancellationToken);
53+
Dictionary<string, object?> fields = new Dictionary<string, object?>();
54+
foreach (var column in columns)
5155
{
52-
value = null;
56+
var value = column.ColumnOrdinal.HasValue ? reader[column.ColumnOrdinal.Value] : reader[column.ColumnName];
57+
if (value == DBNull.Value)
58+
{
59+
value = null;
60+
}
61+
fields[column.ColumnName] = value;
5362
}
54-
fields[column.ColumnName] = value;
63+
yield return new DictionaryDataItem(fields);
5564
}
56-
yield return new DictionaryDataItem(fields);
65+
} finally {
66+
await connection.CloseAsync();
5767
}
5868
}
5969

0 commit comments

Comments
 (0)