diff --git a/src/JsonApiDotNetCore/Repositories/DbContextExtensions.cs b/src/JsonApiDotNetCore/Repositories/DbContextExtensions.cs
index cadbd658a8..f4cfd1ddbd 100644
--- a/src/JsonApiDotNetCore/Repositories/DbContextExtensions.cs
+++ b/src/JsonApiDotNetCore/Repositories/DbContextExtensions.cs
@@ -1,3 +1,4 @@
+using System.Reflection;
using JetBrains.Annotations;
using JsonApiDotNetCore.Resources;
using Microsoft.EntityFrameworkCore;
@@ -8,6 +9,8 @@ namespace JsonApiDotNetCore.Repositories;
[PublicAPI]
public static class DbContextExtensions
{
+ private static readonly MethodInfo DbContextSetMethod = typeof(DbContext).GetMethod(nameof(DbContext.Set), Type.EmptyTypes)!;
+
///
/// If not already tracked, attaches the specified resource to the change tracker in state.
///
@@ -57,4 +60,10 @@ public static void ResetChangeTracker(this DbContext dbContext)
dbContext.ChangeTracker.Clear();
}
+
+ public static IQueryable Set(this DbContext context, Type entityType)
+ {
+ MethodInfo setMethod = DbContextSetMethod.MakeGenericMethod(entityType);
+ return (IQueryable)setMethod.Invoke(context, null)!;
+ }
}
diff --git a/src/JsonApiDotNetCore/Repositories/EntityFrameworkCoreRepository.cs b/src/JsonApiDotNetCore/Repositories/EntityFrameworkCoreRepository.cs
index 1b807fd24f..b2cd3d86ff 100644
--- a/src/JsonApiDotNetCore/Repositories/EntityFrameworkCoreRepository.cs
+++ b/src/JsonApiDotNetCore/Repositories/EntityFrameworkCoreRepository.cs
@@ -325,6 +325,9 @@ public virtual async Task DeleteAsync(TResource? resourceFromDatabase, TId id, C
var resourceTracked = (TResource)_dbContext.GetTrackedOrAttach(placeholderResource);
+ EnsureIncomingNavigationsAreTracked(resourceTracked);
+
+ /*
foreach (RelationshipAttribute relationship in _resourceGraph.GetResourceType().Relationships)
{
// Loads the data of the relationship, if in Entity Framework Core it is configured in such a way that loading
@@ -335,6 +338,7 @@ public virtual async Task DeleteAsync(TResource? resourceFromDatabase, TId id, C
await navigation.LoadAsync(cancellationToken);
}
}
+ */
_dbContext.Remove(resourceTracked);
@@ -343,6 +347,114 @@ public virtual async Task DeleteAsync(TResource? resourceFromDatabase, TId id, C
await _resourceDefinitionAccessor.OnWriteSucceededAsync(resourceTracked, WriteOperationKind.DeleteResource, cancellationToken);
}
+ private void EnsureIncomingNavigationsAreTracked(TResource resourceTracked)
+ {
+ IEntityType[] entityTypes = _dbContext.Model.GetEntityTypes().ToArray();
+ IEntityType thisEntityType = entityTypes.Single(entityType => entityType.ClrType == typeof(TResource));
+
+ HashSet navigationsToLoad = new();
+
+ foreach (INavigation navigation in entityTypes.SelectMany(entityType => entityType.GetNavigations()))
+ {
+ bool requiresLoad = navigation.IsOnDependent ? navigation.TargetEntityType == thisEntityType : navigation.DeclaringEntityType == thisEntityType;
+
+ if (requiresLoad && navigation.ForeignKey.DeleteBehavior == DeleteBehavior.ClientSetNull)
+ {
+ navigationsToLoad.Add(navigation);
+ }
+ }
+
+ // {Navigation: Customer.FirstOrder (Order) ToPrincipal Order}
+ // var query = from _dbContext.Set().Where(customer => customer.FirstOrder == resourceTracked) // .Select(customer => customer.Id)
+
+ // {Navigation: Customer.LastOrder (Order) ToPrincipal Order}
+ // var query = from _dbContext.Set().Where(customer => customer.LastOrder == resourceTracked) // .Select(customer => customer.Id)
+
+ // {Navigation: Order.Parent (Order) ToPrincipal Order}
+ // var query = from _dbContext.Set().Where(order => order.Parent == resourceTracked) // .Select(order => order.Id)
+
+ // {Navigation: ShoppingBasket.CurrentOrder (Order) ToPrincipal Order}
+ // var query = from _dbContext.Set().Where(shoppingBasket => shoppingBasket.CurrentOrder == resourceTracked) // .Select(shoppingBasket => shoppingBasket.Id)
+
+ var nameFactory = new LambdaParameterNameFactory();
+ var scopeFactory = new LambdaScopeFactory(nameFactory);
+
+ foreach (INavigation navigation in navigationsToLoad)
+ {
+ if (!navigation.IsOnDependent && navigation.Inverse != null)
+ {
+ // TODO: Handle the case where there is no inverse.
+ continue;
+ }
+
+ IQueryable source = _dbContext.Set(navigation.DeclaringEntityType.ClrType);
+
+ using LambdaScope scope = scopeFactory.CreateScope(source.ElementType);
+
+ Expression expression;
+
+ if (navigation.IsCollection)
+ {
+ /*
+ {Navigation: WorkItem.Subscribers (ISet) Collection ToDependent UserAccount}
+
+ var subscribers = dbContext.WorkItems
+ .Where(workItem => workItem == existingWorkItem)
+ .Include(workItem => workItem.Subscribers)
+ .Select(workItem => workItem.Subscribers);
+ */
+
+ Expression left = scope.Accessor;
+ Expression right = Expression.Constant(resourceTracked, typeof(TResource));
+
+ Expression whereBody = Expression.Equal(left, right);
+ LambdaExpression wherePredicate = Expression.Lambda(whereBody, scope.Parameter);
+ Expression whereExpression = WhereExtensionMethodCall(source.Expression, scope, wherePredicate);
+
+ // TODO: Use typed overload
+ Expression includeExpression = IncludeExtensionMethodCall(whereExpression, scope, navigation.Name);
+
+ MemberExpression selectorBody = Expression.MakeMemberAccess(scope.Accessor, navigation.PropertyInfo);
+ LambdaExpression selectorLambda = Expression.Lambda(selectorBody, scope.Parameter);
+
+ expression = SelectExtensionMethodCall(includeExpression, source.ElementType, navigation.PropertyInfo.PropertyType, selectorLambda);
+ }
+ else
+ {
+ MemberExpression left = Expression.MakeMemberAccess(scope.Parameter, navigation.PropertyInfo);
+ ConstantExpression right = Expression.Constant(resourceTracked, typeof(TResource));
+
+ Expression body = Expression.Equal(left, right);
+ LambdaExpression selectorLambda = Expression.Lambda(body, scope.Parameter);
+ expression = WhereExtensionMethodCall(source.Expression, scope, selectorLambda);
+ }
+
+ IQueryable queryable = source.Provider.CreateQuery(expression);
+
+ // Executes the query and loads the returned entities in the change tracker.
+ // We can likely optimize this by only fetching IDs and creating placeholder resources for them.
+ object[] results = queryable.Cast