Skip to content

Commit 3dfc61b

Browse files
committed
feat: Parameterized sql queries
1 parent dddfc67 commit 3dfc61b

File tree

6 files changed

+216
-52
lines changed

6 files changed

+216
-52
lines changed

Extensions/SqlServer/Cosmos.DataTransfer.SqlServerExtension.UnitTests/Cosmos.DataTransfer.SqlServerExtension.UnitTests.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.2.0" />
1717
<PackageReference Include="MSTest.TestAdapter" Version="2.2.10" />
1818
<PackageReference Include="MSTest.TestFramework" Version="2.2.10" />
19+
<PackageReference Include="Moq" Version="4.18.4" />
1920
<PackageReference Include="coverlet.collector" Version="3.1.2" />
2021
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
2122
<PackageReference Include="coverlet.msbuild" Version="2.8.0">
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
using Microsoft.Data.Sqlite;
22
using Microsoft.Extensions.Logging.Abstractions;
3+
using System.Data.Common;
34
using Cosmos.DataTransfer.Interfaces;
45
using Cosmos.DataTransfer.Common;
56
using Cosmos.DataTransfer.Common.UnitTests;
7+
using Moq;
8+
using Microsoft.Extensions.Configuration;
69

710
namespace Cosmos.DataTransfer.SqlServerExtension.UnitTests;
811

912
[TestClass]
1013
public class SqlServerDataSourceExtensionTests
1114
{
1215

13-
private static async Task<Func<string,ValueTask<System.Data.Common.DbConnection>>> connectionFactory(CancellationToken cancellationToken = default(CancellationToken)) {
14-
var connection = new SqliteConnection("");
16+
private static async Task<Tuple<SqliteFactory,DbConnection>> connectionFactory(CancellationToken cancellationToken = default(CancellationToken)) {
17+
var provider = SqliteFactory.Instance;
18+
var connection = provider.CreateConnection();
1519
await connection.OpenAsync(cancellationToken);
1620

1721
var cmd = connection.CreateCommand();
@@ -27,25 +31,20 @@ name TEXT
2731
VALUES (2, NULL);";
2832
await cmd.ExecuteNonQueryAsync(cancellationToken);
2933

30-
var func = (string connectionString) => {
31-
return new ValueTask<System.Data.Common.DbConnection>(connection);
32-
};
33-
34-
return func;
34+
return Tuple.Create(provider, connection);
3535
}
3636

3737
[TestMethod]
38-
public async Task TestReadAsync_QueryText() {
38+
public async Task TestReadAsync() {
39+
var config = new Mock<IConfiguration>();
40+
var cancellationToken = new CancellationTokenSource(500);
41+
var (providerFactory, connection) = await connectionFactory(cancellationToken.Token);
42+
3943
var extension = new SqlServerDataSourceExtension();
40-
var config = TestHelpers.CreateConfig(new Dictionary<string, string> {
41-
{ "ConnectionString", "Sqlite" },
42-
{ "QueryText", "SELECT * FROM foobar" }
43-
});
4444
Assert.AreEqual("SqlServer", extension.DisplayName);
45-
46-
var cancellationToken = new CancellationTokenSource(500);
4745

48-
var result = await extension.ReadAsync(config, NullLogger.Instance, await connectionFactory(cancellationToken.Token), cancellationToken.Token).ToListAsync();
46+
var result = await extension.ReadAsync(config.Object, NullLogger.Instance,
47+
"SELECT * FROM foobar", Array.Empty<DbParameter>(), connection, providerFactory, cancellationToken.Token).ToListAsync();
4948
var expected = new List<DictionaryDataItem> {
5049
new DictionaryDataItem(new Dictionary<string, object?> { { "id", (long)1 }, { "name", "zoo" } }),
5150
new DictionaryDataItem(new Dictionary<string, object?> { { "id", (long)2 }, { "name", null } })
@@ -54,24 +53,25 @@ public async Task TestReadAsync_QueryText() {
5453
}
5554

5655
[TestMethod]
57-
public async Task TestReadAsync_FromFile() {
58-
var outputFile = Path.GetTempFileName();
59-
await File.WriteAllTextAsync(outputFile, "SELECT * FROM foobar;");
56+
public async Task TestReadAsyncWithParameters() {
57+
var config = new Mock<IConfiguration>();
58+
var cancellationToken = new CancellationTokenSource();
59+
var (providerFactory, connection) = await connectionFactory(cancellationToken.Token);
60+
6061
var extension = new SqlServerDataSourceExtension();
61-
var config = TestHelpers.CreateConfig(new Dictionary<string, string> {
62-
{ "ConnectionString", "Sqlite" },
63-
{ "FilePath", outputFile }
64-
});
62+
Assert.AreEqual("SqlServer", extension.DisplayName);
6563

66-
67-
var cancellationToken = new CancellationTokenSource(500);
64+
var parameter = providerFactory.CreateParameter();
65+
parameter.ParameterName = "@x";
66+
parameter.DbType = System.Data.DbType.Int32;
67+
parameter.Value = 2;
6868

69-
var result = await extension.ReadAsync(config, NullLogger.Instance, await connectionFactory(cancellationToken.Token), cancellationToken.Token).ToListAsync();
70-
var expected = new List<DictionaryDataItem> {
71-
new DictionaryDataItem(new Dictionary<string, object?> { { "id", (long)1 }, { "name", "zoo" } }),
72-
new DictionaryDataItem(new Dictionary<string, object?> { { "id", (long)2 }, { "name", null } })
73-
};
74-
CollectionAssert.That.AreEqual(expected, result, new DataItemComparer());
69+
var result = await extension.ReadAsync(config.Object, NullLogger.Instance,
70+
"SELECT * FROM foobar WHERE id = @x",
71+
new DbParameter[] { parameter }, connection, providerFactory, cancellationToken.Token).FirstAsync();
72+
Assert.That.AreEqual(result,
73+
new DictionaryDataItem(new Dictionary<string, object?> { { "id", (long)2 }, { "name", null } }),
74+
new DataItemComparer());
7575
}
7676

7777
// Allows for testing against an actual SQL Server by specifying a
@@ -80,10 +80,11 @@ public async Task TestReadAsync_FromFile() {
8080
// <?xml version="1.0" encoding="utf-8"?>
8181
// <RunSettings>
8282
// <TestRunParameters>
83-
// <Parameter name="TestReadAsync_LiveSqlServer_ConnectionString" value="<Your connection string>" />
83+
// <Parameter name="TestReadAsync_LiveSqlServer_ConnectionString" value="Your connection string" />
8484
// </TestRunParameters>
8585
// </RunSettings>
86-
// run test with dotnet test --settings sql.runsettings
86+
// run test with
87+
// dotnet test --settings sql.runsettings
8788
#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
8889
public TestContext TestContext { get; set; }
8990
#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
@@ -111,4 +112,4 @@ public async Task TestReadAsync_LiveSqlServer() {
111112
{ "zoo", null }
112113
})));
113114
}
114-
}
115+
}

Extensions/SqlServer/Cosmos.DataTransfer.SqlServerExtension.UnitTests/SqlServerSourceSettingsTests.cs

+78-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
using Cosmos.DataTransfer.Interfaces;
22
using System.ComponentModel.DataAnnotations;
3+
using Cosmos.DataTransfer.Common.UnitTests;
4+
using Microsoft.Data.Sqlite;
5+
using System.Data;
36

47
namespace Cosmos.DataTransfer.SqlServerExtension.UnitTests;
58

@@ -31,9 +34,10 @@ public void TestSourceSettings_ValidationFails1()
3134
[TestMethod]
3235
public void TestSourceSettings_ValidationFails2()
3336
{
37+
var fn = Path.GetTempFileName();
3438
var settings = new SqlServerSourceSettings {
3539
QueryText = "SELECT 1;",
36-
FilePath = "dmt-query.sql"
40+
FilePath = fn
3741
};
3842

3943
var validationResults = settings.Validate(new ValidationContext(settings)).ToList();
@@ -51,6 +55,22 @@ public void TestSourceSettings_ValidationFails2()
5155
Assert.ThrowsException<AggregateException>(() => settings.Validate());
5256
}
5357

58+
[TestMethod]
59+
public void TestSourceSettings_Validation_FileNotFound()
60+
{
61+
var fn = Path.GetTempFileName();
62+
var settings = new SqlServerSourceSettings {
63+
ConnectionString = "Connection, please",
64+
FilePath = "dmt.sql"
65+
};
66+
67+
var validationResults = settings.Validate(new ValidationContext(settings)).ToList();
68+
Assert.AreEqual(1, validationResults.Count);
69+
Assert.IsTrue(validationResults[0].ErrorMessage!.StartsWith("Could not read `FilePath`. Reason:"));
70+
CollectionAssert.AreEqual(new string[] { "FilePath" }, validationResults[0].MemberNames.ToArray());
71+
}
72+
73+
5474
[TestMethod]
5575
[DataRow("SELECT 1", null)]
5676
[DataRow("SELECT 1", " ")]
@@ -60,12 +80,67 @@ public void TestSourceSettings_ValidationSuccess(string queryText, string filePa
6080
var settings = new SqlServerSourceSettings {
6181
ConnectionString = "Server",
6282
QueryText = queryText,
63-
FilePath = filePath
83+
FilePath = filePath == "filename" ? Path.GetTempFileName() : filePath
6484
};
6585

6686
var validationResults = settings.Validate(new ValidationContext(settings));
6787
Assert.AreEqual(0, validationResults.Count());
68-
6988
settings.Validate();
7089
}
90+
91+
[TestMethod]
92+
public void TestSourceSettings_GetQueryText1() {
93+
var settings = new SqlServerSourceSettings() {
94+
QueryText = "SELECT 1"
95+
};
96+
Assert.AreEqual("SELECT 1", settings.GetQueryText());
97+
98+
var fn = Path.GetTempFileName();
99+
settings.FilePath = fn; // But this shouldn't occur, as the settings are invalid.
100+
File.WriteAllText(fn, "More SQL");
101+
Assert.AreEqual("More SQL", settings.GetQueryText());
102+
103+
settings.QueryText = "";
104+
Assert.AreEqual("More SQL", settings.GetQueryText());
105+
}
106+
107+
[TestMethod]
108+
public void TestSourceSettings_Parameters() {
109+
var settings = new SqlServerSourceSettings();
110+
111+
Assert.AreEqual(0, settings.GetDbParameters(SqliteFactory.Instance).Count());
112+
settings.Parameters = new Dictionary<string, object> {
113+
{ "str", "str" },
114+
{ "bool", true },
115+
{ "int", 100 },
116+
{ "long", 100L },
117+
{ "double", 3.14d },
118+
{ "float", 2.718f },
119+
{ "datetime", DateTime.UtcNow }
120+
};
121+
122+
var parameters = settings.GetDbParameters(SqliteFactory.Instance);
123+
int i = -1;
124+
Assert.AreEqual("str", parameters[++i].ParameterName);
125+
Assert.AreEqual("str", parameters[i].Value);
126+
Assert.AreEqual(DbType.String, parameters[i].DbType);
127+
Assert.AreEqual("bool", parameters[++i].ParameterName);
128+
Assert.AreEqual(true, parameters[i].Value);
129+
Assert.AreEqual(DbType.Boolean, parameters[i].DbType);
130+
Assert.AreEqual("int", parameters[++i].ParameterName);
131+
Assert.AreEqual(100, parameters[i].Value);
132+
Assert.AreEqual(DbType.Int32, parameters[i].DbType);
133+
Assert.AreEqual("long", parameters[++i].ParameterName);
134+
Assert.AreEqual(100L, parameters[i].Value);
135+
Assert.AreEqual(DbType.Int64, parameters[i].DbType);
136+
Assert.AreEqual("double", parameters[++i].ParameterName);
137+
Assert.AreEqual(3.14, parameters[i].Value);
138+
Assert.AreEqual(DbType.Double, parameters[i].DbType);
139+
Assert.AreEqual("float", parameters[++i].ParameterName);
140+
Assert.AreEqual(2.718f, parameters[i].Value);
141+
Assert.AreEqual(DbType.Single, parameters[i].DbType);
142+
Assert.AreEqual("datetime", parameters[++i].ParameterName);
143+
//Assert.AreEqual(2.718f, parameters[i].Value);
144+
Assert.AreEqual(DbType.String, parameters[i].DbType);
145+
}
71146
}

Extensions/SqlServer/Cosmos.DataTransfer.SqlServerExtension/SqlServerDataSourceExtension.cs

+19-12
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.IO;
33
using System.Runtime.CompilerServices;
44
using Cosmos.DataTransfer.Interfaces;
5+
using System.Data.Common;
56
using Microsoft.Data.SqlClient;
67
using Microsoft.Extensions.Configuration;
78
using Microsoft.Extensions.Logging;
@@ -15,30 +16,36 @@ public class SqlServerDataSourceExtension : IDataSourceExtensionWithSettings
1516

1617
public async IAsyncEnumerable<IDataItem> ReadAsync(IConfiguration config, ILogger logger, [EnumeratorCancellation] CancellationToken cancellationToken = default)
1718
{
18-
await foreach (var item in this.ReadAsync(config, logger, (string connectionString) => new ValueTask<System.Data.Common.DbConnection>(new SqlConnection(connectionString)), cancellationToken)) {
19+
var settings = config.Get<SqlServerSourceSettings>();
20+
settings.Validate();
21+
22+
var providerFactory = SqlClientFactory.Instance;
23+
var connection = providerFactory.CreateConnection()!;
24+
connection.ConnectionString = settings.ConnectionString;
25+
26+
var iterable = this.ReadAsync(config, logger, settings.GetQueryText(),
27+
settings.GetDbParameters(providerFactory), connection,
28+
providerFactory, cancellationToken);
29+
30+
await foreach (var item in iterable) {
1931
yield return item;
2032
}
2133
}
2234

2335
public async IAsyncEnumerable<IDataItem> ReadAsync(
2436
IConfiguration config,
2537
ILogger logger,
26-
Func<string,ValueTask<System.Data.Common.DbConnection>> connectionFactory,
38+
string queryText,
39+
DbParameter[] parameters,
40+
DbConnection connection,
41+
DbProviderFactory dbProviderFactory,
2742
[EnumeratorCancellation] CancellationToken cancellationToken = default)
2843
{
29-
var settings = config.Get<SqlServerSourceSettings>();
30-
settings.Validate();
31-
32-
string queryText = settings.QueryText!;
33-
if (settings.FilePath != null) {
34-
queryText = File.ReadAllText(settings.FilePath);
35-
}
36-
37-
await using var connection = connectionFactory(settings.ConnectionString!).Result;
3844
await connection.OpenAsync(cancellationToken);
3945
var command = connection.CreateCommand();
4046
command.CommandText = queryText;
41-
//await using SqlCommand command = new SqlCommand(queryText, connection);
47+
command.Parameters.AddRange(parameters);
48+
4249
await using var reader = await command.ExecuteReaderAsync(cancellationToken);
4350
while (await reader.ReadAsync(cancellationToken))
4451
{

Extensions/SqlServer/Cosmos.DataTransfer.SqlServerExtension/SqlServerSourceSettings.cs

+64-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using System.ComponentModel.DataAnnotations;
22
using Cosmos.DataTransfer.Interfaces;
33
using Cosmos.DataTransfer.Interfaces.Manifest;
4+
using System.Data;
5+
using System.Data.Common;
46

57
namespace Cosmos.DataTransfer.SqlServerExtension
68
{
@@ -13,6 +15,8 @@ public class SqlServerSourceSettings : IDataExtensionSettings, IValidatableObjec
1315

1416
public string? FilePath { get; set; }
1517

18+
public IDictionary<string, object>? Parameters { get; set; }
19+
1620
public IEnumerable<ValidationResult> Validate(ValidationContext validationContext)
1721
{
1822
if (String.IsNullOrWhiteSpace(this.ConnectionString)) {
@@ -30,6 +34,65 @@ public IEnumerable<ValidationResult> Validate(ValidationContext validationContex
3034
"Both `QueryText` and `FilePath` are not allowed.",
3135
new string[] { "QueryText", "FilePath"});
3236
}
37+
if (!String.IsNullOrWhiteSpace(this.FilePath)) {
38+
ValidationResult? res = null;
39+
try {
40+
_ = File.ReadAllText(this.FilePath);
41+
} catch (Exception e) {
42+
res = new ValidationResult("Could not read `FilePath`. Reason: \n" + e.Message,
43+
new string[] { "FilePath" });
44+
}
45+
if (res is not null) yield return res;
46+
}
47+
}
48+
49+
/// <summary>
50+
///
51+
/// </summary>
52+
/// <param name="dbProviderFactory">
53+
/// Use e.g., <code>Microsoft.Data.SqlClient.SqlClientFactory.Instance</code>
54+
/// or <code>Microsoft.Data.Sqlite.SqliteFactory.Instance</code>.
55+
/// </param>
56+
/// <returns></returns>
57+
public DbParameter[] GetDbParameters(DbProviderFactory dbProviderFactory) {
58+
var result = new List<DbParameter>();
59+
60+
if (this.Parameters is null || this.Parameters.Count == 0) {
61+
return Array.Empty<DbParameter>();
62+
}
63+
64+
foreach (var param in this.Parameters) {
65+
var dbparam = dbProviderFactory.CreateParameter()!;
66+
dbparam.ParameterName = param.Key;
67+
if (param.Value is bool b) {
68+
dbparam.DbType = DbType.Boolean;
69+
dbparam.Value = b;
70+
} else if (param.Value is long l) {
71+
dbparam.DbType = DbType.Int64;
72+
dbparam.Value = l;
73+
} else if (param.Value is int i) {
74+
dbparam.DbType = DbType.Int32;
75+
dbparam.Value = i;
76+
} else if (param.Value is float f) {
77+
dbparam.DbType = DbType.Single;
78+
dbparam.Value = f;
79+
} else if (param.Value is double d) {
80+
dbparam.DbType = DbType.Double;
81+
dbparam.Value = d;
82+
} else {
83+
dbparam.DbType = DbType.String;
84+
dbparam.Value = param.Value;
85+
}
86+
result.Add(dbparam);
87+
}
88+
return result.ToArray();
89+
}
90+
91+
public string GetQueryText() {
92+
if (!String.IsNullOrWhiteSpace(this.FilePath)) {
93+
return File.ReadAllText(this.FilePath);
94+
}
95+
return this.QueryText!;
3396
}
3497
}
35-
}
98+
}

0 commit comments

Comments
 (0)