Skip to content

Commit

Permalink
globally redesigned query logic, added protection against SQL injections
Browse files Browse the repository at this point in the history
  • Loading branch information
theandrunique committed Jan 25, 2024
1 parent f7f35ab commit a2e1ceb
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 130 deletions.
10 changes: 0 additions & 10 deletions SQLModel.Tests/SqliteProviderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,6 @@ await Assert.ThrowsAsync<SqliteException>(async () =>
await session.Delete(listLogins[0]);
await session.Delete(listProfiles[0]);
}

//using (var session = await core.CreateAsyncSession())
//{
// // SELECT SCOPE_IDENTITY()
// using (var reader = await session.Execute("INSERT INTO profiles (Name, Description) VALUES ('Name', 'Description'); SELECT last_insert_rowid();"))
// {
// if (reader.Read())
// log.Debug(reader.GetValue(0));
// }
//}
}
}
}
183 changes: 63 additions & 120 deletions SQLModel/CRUD.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using System;
using Dapper;
using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;
Expand All @@ -12,53 +12,33 @@ internal static class Crud
{
public static T GetById<T>(int id, Session session)
{
string query = BuildSelectQueryById(Metadata.TableClasses[typeof(T)], id);
Table table = Metadata.TableClasses[typeof(T)];

using (IDataReader reader = session.Execute(query))
{
if (reader.Read())
{
return MapToObject<T>(reader);
}
else { return default(T); }
}
string query = BuildSelectQueryById(table);

Logging.Info(query);

return session.Connection.Query<T>(query, new { Id = id }, session.Transaction).FirstOrDefault();
}
async public static Task<T> GetByIdAsync<T>(int id, AsyncSession session)
{
string query = BuildSelectQueryById(Metadata.TableClasses[typeof(T)], id);
Table table = Metadata.TableClasses[typeof(T)];

using (IDataReader reader = await session.Execute(query))
{
if (await session.DbCore.ReadReaderAsync(reader))
{
return MapToObject<T>(reader);
}
else { return default(T); }
}
string query = BuildSelectQueryById(table);

Logging.Info(query);

return (await session.Connection.QueryAsync<T>(query, new { Id = id }, session.Transaction)).FirstOrDefault();
}
private static string BuildSelectQueryById(Table table, int id)
private static string MapFields(Table table)
{
//string idClause = string.Join(" AND ", table.PrimaryKeys.Select(key => $"{key.Name} = {key.Property.GetValue(existedObject)}"));

return $"SELECT * FROM {table.Name} WHERE {table.PrimaryKeys[0].Name} = {id};";
return string.Join(", ", table.FieldsRelation.Values.Select(field => $"{field.Name} AS {field.Property.Name}"));
}
private static T MapToObject<T>(IDataReader reader)
private static string BuildSelectQueryById(Table table)
{
Type type = typeof(T);

Table table = Metadata.TableClasses[type];

PropertyInfo[] properties = table.FieldsRelation.Keys.ToArray();

T obj = Activator.CreateInstance<T>();
foreach (PropertyInfo item in properties)
{
Field field = table.FieldsRelation[item];

item.SetValue(obj, Convert.ChangeType(reader[field.Name], item.PropertyType));
}
//string idClause = string.Join(" AND ", table.PrimaryKeys.Select(key => $"{key.Name} = {key.Property.GetValue(existedObject)}"));

return obj;
return $"SELECT {MapFields(table)} FROM {table.Name} WHERE {table.PrimaryKeys[0].Name} = @Id;";
}
private static string GetFields(Dictionary<PropertyInfo, Field> keyValuePairs)
{
Expand All @@ -74,49 +54,33 @@ private static string GetFields(Dictionary<PropertyInfo, Field> keyValuePairs)

}).Where(fieldName => fieldName != null));
}
private static string GetValues(Dictionary<PropertyInfo, Field> keyValuePairs, object currentObject, bool withPrimaryKey)
{
List<PropertyInfo> properties = keyValuePairs.Keys.ToList();

return string.Join(", ", properties.Select(property =>
{
if (!withPrimaryKey && keyValuePairs[property].PrimaryKey)
{
return null;
}
var value = property.GetValue(currentObject);

return (property.PropertyType == typeof(string) || property.PropertyType == typeof(DateTime)) ? $"'{value}'" : value.ToString();

}).Where(fieldName => fieldName != null));
}
public static void Create(object newObject, Session session)
{
string query = BuildCreateQuery(newObject, session.DbCore.GetLastInsertRowId());

using (var reader = session.Execute(query))
{
if (reader.Read())
{
PrimaryKey key = Metadata.TableClasses[newObject.GetType()].PrimaryKeys[0];
Logging.Info(query);

key.Property.SetValue(newObject, Convert.ChangeType(reader.GetValue(0), key.Property.PropertyType));
}
}
int newObjectId = session.Connection.Query<int>(query, newObject, session.Transaction).FirstOrDefault();

PrimaryKey key = Metadata.TableClasses[newObject.GetType()].PrimaryKeys[0];

key.Property.SetValue(newObject, Convert.ChangeType(newObjectId, key.Property.PropertyType));
}
async public static Task CreateAsync(object newObject, AsyncSession session)
{
string query = BuildCreateQuery(newObject, session.DbCore.GetLastInsertRowId());

using (var reader = await session.Execute(query))
{
if (await session.ReadAsync(reader))
{
PrimaryKey key = Metadata.TableClasses[newObject.GetType()].PrimaryKeys[0];
Logging.Info(query);

key.Property.SetValue(newObject, Convert.ChangeType(reader.GetValue(0), key.Property.PropertyType));
}
}
int newObjectId = (await session.Connection.QueryAsync<int>(query, newObject, session.Transaction)).FirstOrDefault();

PrimaryKey key = Metadata.TableClasses[newObject.GetType()].PrimaryKeys[0];

key.Property.SetValue(newObject, Convert.ChangeType(newObjectId, key.Property.PropertyType));
}
public static string GetParams(Dictionary<PropertyInfo, Field> keyValuePairs)
{
return string.Join(", ", keyValuePairs.Values.Where(field => !field.PrimaryKey).Select(field => $"@{field.Name}"));
}
private static string BuildCreateQuery(object newObject, string lastInsertRowId)
{
Expand All @@ -126,9 +90,9 @@ private static string BuildCreateQuery(object newObject, string lastInsertRowId)

string fieldList = GetFields(table.FieldsRelation);

string valueList = GetValues(table.FieldsRelation, newObject, false);
string paramsList = GetParams(table.FieldsRelation);

return $"INSERT INTO {table.Name} ({fieldList}) VALUES ({valueList}); {lastInsertRowId};";
return $"INSERT INTO {table.Name} ({fieldList}) VALUES ({paramsList}); {lastInsertRowId};";
}
private static string BuildUpdateQuery(object existedObject)
{
Expand All @@ -148,30 +112,29 @@ private static string BuildUpdateQuery(object existedObject)
}
var value = property.GetValue(existedObject);

if (property.PropertyType == typeof(string) || property.PropertyType == typeof(DateTime))
{
return $"{field.Name} = '{value}'";

} else
{
return $"{field.Name} = {value}";
}
return $"{field.Name} = @{field.Property.Name}";

}).Where(fieldValue => fieldValue != null));

string idClause = string.Join(" AND ", table.PrimaryKeys.Select(key => $"{key.Name} = {key.Property.GetValue(existedObject)}"));
string idClause = string.Join(" AND ", table.PrimaryKeys.Select(key => $"{key.Name} = @{key.Property.Name}"));

return $"UPDATE {table.Name} SET {setClause} WHERE {idClause};";
}
async public static Task UpdateAsync(object existedObject, AsyncSession session)
{
string query = BuildUpdateQuery(existedObject);
await session.ExecuteNonQuery(query);

Logging.Info(query);

await session.Connection.ExecuteAsync(query, existedObject, session.Transaction);
}
public static void Update(object existedObject, Session session)
{
string query = BuildUpdateQuery(existedObject);
session.ExecuteNonQuery(query);

Logging.Info(query);

session.Connection.Execute(query, existedObject, session.Transaction);
}
private static string BuildDeleteQuery(object existedObject)
{
Expand All @@ -181,72 +144,52 @@ private static string BuildDeleteQuery(object existedObject)

PrimaryKey primaryKey = table.PrimaryKeys[0];

string idClause = $"{primaryKey.Name} = {primaryKey.Property.GetValue(existedObject)}";
string idClause = $"{primaryKey.Name} = @{primaryKey.Property.Name}";

return $"DELETE FROM {table.Name} WHERE {idClause};";
}
public static void Delete(object existedObject, Session session)
{
string query = BuildDeleteQuery(existedObject);

session.ExecuteNonQuery(query);
Logging.Info(query);

session.Connection.Execute(query, existedObject, session.Transaction);
}
async public static Task DeleteAsync(object existedObject, AsyncSession session)
{
string query = BuildDeleteQuery(existedObject);

await session.ExecuteNonQuery(query);
Logging.Info(query);

await session.Connection.ExecuteAsync(query, existedObject, session.Transaction);
}
private static string BuildSelectAllQuery<T>()
{
Type type = typeof(T);

Table table = Metadata.TableClasses[type];

return $"SELECT * FROM {table.Name};";
}
private static T CreateInstance<T>(IDataReader reader)
{
T obj = Activator.CreateInstance<T>();

Table table = Metadata.TableClasses[typeof(T)];

List<PropertyInfo> properties = table.FieldsRelation.Keys.ToList();

foreach (PropertyInfo item in properties)
{
item.SetValue(obj, Convert.ChangeType(reader[table.FieldsRelation[item].Name], item.PropertyType));
}
return obj;
return $"SELECT {MapFields(table)} FROM {table.Name};";
}
public static List<T> GetAll<T>(Session session)
{
string query = BuildSelectAllQuery<T>();
List<T> list = new List<T>();

using (IDataReader reader = session.Execute(query))
{
while (reader.Read())
{
T obj = CreateInstance<T>(reader);
list.Add(obj);
}
}
Logging.Info(query);

List<T> list = session.Connection.Query<T>(query, null, session.Transaction).ToList();

return list;
}
async public static Task<List<T>> GetAllAsync<T>(AsyncSession session)
{
string query = BuildSelectAllQuery<T>();
List<T> list = new List<T>();

using (IDataReader reader = await session.Execute(query))
{
while (await session.ReadAsync(reader))
{
T obj = CreateInstance<T>(reader);
list.Add(obj);
}
}
Logging.Info(query);

List<T> list = (await session.Connection.QueryAsync<T>(query, null, session.Transaction)).ToList();

return list;
}
}
Expand Down
1 change: 1 addition & 0 deletions SQLModel/SQLModel.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Dapper" Version="2.1.28" />
<PackageReference Include="Microsoft.Data.Sqlite" Version="2.2.0" />
<PackageReference Include="NLog" Version="5.2.8" />
<PackageReference Include="System.Data.SqlClient" Version="4.8.5" />
Expand Down

0 comments on commit a2e1ceb

Please sign in to comment.