How to get COUNT DISTINCT in translated SQL with EF Core

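I wanted to share an idea I had for solving my count distinct problem.

Another way of doing count distinct inside a group by is to nest two group by operations: group first by the outer key together with the value you want counted distinctly, then group that result by the outer key alone. The row count of each outer group is then the distinct count. This assumes your other aggregates can be carried through both levels.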

Here is an example of what I used; it seems to work.

Apologies for the cryptic acronyms; I am using them to keep my JSON as small as possible.

// Level 1: one row per (date, tn) pair, with the per-day totals.
var myData = _context.ActivityItems
    .GroupBy(a => new { ndt = EF.Property<DateTime>(a, "dt").Date, ntn = a.tn })
    .Select(g => new
    {
        g.Key.ndt,
        g.Key.ntn,
        dpv = g.Sum(o => o.pv),
        dlv = g.Sum(o => o.lv),
        cnt = g.Count(),
    })
    // Level 2: regroup by tn alone; each group now holds one row per distinct date.
    .GroupBy(a => new { ntn = a.ntn })
    .Select(g => new
    {
        g.Key.ntn,
        sd = g.Min(o => o.ndt),
        ld = g.Max(o => o.ndt),
        pSum = g.Sum(o => o.dpv),
        pMin = g.Min(o => o.dpv),
        pMax = g.Max(o => o.dpv),
        pAvg = g.Average(o => o.dpv),
        lSum = g.Sum(o => o.dlv),
        lMin = g.Min(o => o.dlv),
        lMax = g.Max(o => o.dlv),
        lAvg = g.Average(o => o.dlv),
        n10s = g.Sum(o => o.cnt),   // total number of source rows
        ndays = g.Count()           // number of distinct dates, i.e. the count distinct
    });
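
To check that the nested grouping really gives a distinct count, here is a minimal in-memory sketch using LINQ to Objects. The Visit type and its Name and Day properties are hypothetical stand-ins, not the ActivityItems entity above, and nothing here touches a database; it only compares the nested group by result against a plain Select/Distinct/Count.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical row type for illustration only.
public sealed class Visit
{
    public string Name { get; set; }
    public DateTime Day { get; set; }
}

public static class NestedGroupByDemo
{
    // Distinct days per name via two nested GroupBy calls.
    public static Dictionary<string, int> DistinctDaysNested(IEnumerable<Visit> rows) =>
        rows.GroupBy(r => new { r.Name, r.Day })          // one group per (name, day) pair
            .Select(g => new { g.Key.Name, g.Key.Day })   // collapse duplicates of the pair
            .GroupBy(x => x.Name)                         // each group holds distinct days only
            .ToDictionary(g => g.Key, g => g.Count());

    // Reference result computed with Select/Distinct/Count.
    public static Dictionary<string, int> DistinctDaysDirect(IEnumerable<Visit> rows) =>
        rows.GroupBy(r => r.Name)
            .ToDictionary(g => g.Key, g => g.Select(r => r.Day).Distinct().Count());

    public static void Main()
    {
        var rows = new[]
        {
            new Visit { Name = "a", Day = new DateTime(2020, 1, 1) },
            new Visit { Name = "a", Day = new DateTime(2020, 1, 1) },
            new Visit { Name = "a", Day = new DateTime(2020, 1, 2) },
            new Visit { Name = "b", Day = new DateTime(2020, 1, 3) },
        };
        var nested = DistinctDaysNested(rows);
        var direct = DistinctDaysDirect(rows);
        // Prints "a: 2 / 2, b: 1 / 1"
        Console.WriteLine($"a: {nested["a"]} / {direct["a"]}, b: {nested["b"]} / {direct["b"]}");
    }
}
```

Both methods agree because the first grouping leaves exactly one row per (name, day) pair, so counting rows in the second grouping counts distinct days.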

Update (EF Core 5.x):

Starting with version 5.0, the expression Select(expr).Distinct().Count() is recognized by EF Core and translated to the corresponding SQL COUNT(DISTINCT expr), so the original LINQ query can be used without modification.
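
For reference, the query shape in question looks roughly like the sketch below. VisitRow and PersonDayCount are hypothetical types introduced for illustration, and the sample runs against an in-memory array through AsQueryable() so that it is self-contained. The assumption is that the same expression, written against a DbSet in EF Core 5.0 or later, is what gets translated to COUNT(DISTINCT ...); the in-memory run only demonstrates the result, not the generated SQL.

```csharp
using System;
using System.Linq;

// Hypothetical stand-in for an entity; not taken from a real model.
public sealed class VisitRow
{
    public int PersonID { get; set; }
    public DateTime VisitStart { get; set; }
}

// Hypothetical result type, used because anonymous types cannot be returned.
public sealed class PersonDayCount
{
    public int PersonID { get; set; }
    public int Count { get; set; }
}

public static class DistinctCountShape
{
    // Select(expr).Distinct().Count() applied to each group.
    public static IQueryable<PersonDayCount> DistinctDaysPerPerson(IQueryable<VisitRow> visits) =>
        visits
            .GroupBy(v => v.PersonID)
            .Select(g => new PersonDayCount
            {
                PersonID = g.Key,
                Count = g.Select(v => v.VisitStart.Date).Distinct().Count()
            });

    public static void Main()
    {
        var visits = new[]
        {
            new VisitRow { PersonID = 1, VisitStart = new DateTime(2020, 1, 1, 9, 0, 0) },
            new VisitRow { PersonID = 1, VisitStart = new DateTime(2020, 1, 1, 15, 0, 0) },
            new VisitRow { PersonID = 1, VisitStart = new DateTime(2020, 1, 2, 9, 0, 0) },
            new VisitRow { PersonID = 2, VisitStart = new DateTime(2020, 1, 3, 9, 0, 0) },
        }.AsQueryable();

        foreach (var row in DistinctDaysPerPerson(visits).OrderBy(r => r.PersonID))
            Console.WriteLine($"{row.PersonID}: {row.Count}");   // prints "1: 2" then "2: 1"
    }
}
```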


Original (EF Core 2.x). This solution DOES NOT work with EF Core 3.x because the query pipeline was rewritten:

EF (6 and Core) historically has not supported this standard SQL construct, most likely because there is no standard LINQ method for it and because mapping Select(expr).Distinct().Count() to it is technically difficult.

The good thing is that EF Core is extensible: many of its internal services can be replaced with custom derived implementations that override the required behaviors. It's not easy and requires a lot of plumbing code, but it is doable.

So the idea is to add and use simple custom CountDistinct methods like this:

public static int CountDistinct<T, TKey>(this IQueryable<T> source, Expression<Func<T, TKey>> keySelector)
    => source.Select(keySelector).Distinct().Count();

public static int CountDistinct<T, TKey>(this IEnumerable<T> source, Func<T, TKey> keySelector)
    => source.Select(keySelector).Distinct().Count();

and let EF Core somehow translate them to SQL. In fact, EF Core provides a simple way of defining (and even custom translating) database scalar functions, but unfortunately this cannot be used for aggregate functions, which have a separate processing pipeline. So we need to dig deep into the EF Core infrastructure.

The full code for the EF Core 2.x pipeline is provided at the end. I'm not sure it's worth the effort, because EF Core 3.0 will use a completely rewritten query processing pipeline. But it was interesting, and I'm pretty sure it can be updated for the new (hopefully simpler) pipeline.

Anyway, all you need to do is copy/paste the code into a new code file in the project and add the following to the context's OnConfiguring override

optionsBuilder.UseCustomExtensions();

which will plug the functionality into EF Core infrastructure, and then query like this

var result = db.MyTable
    .GroupBy(x => x.PersonID, x => new { VisitStartDate = x.VisitStart.Date })
    .Select(g => new
    {
        Count = g.CountDistinct(x => x.VisitStartDate)
    }).ToList();

which will luckily be translated to the desired

SELECT COUNT(DISTINCT(CONVERT(date, [x].[VisitStart]))) AS [Count]
FROM [MyTable] AS [x]
GROUP BY [x].[PersonID]

Note the preselection of the expression needed by the aggregate method (the element selector passed to GroupBy). This is a current EF Core limitation/requirement for all aggregate methods, not just ours.

Finally, the full code that does the magic:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.Expressions;
using Microsoft.EntityFrameworkCore.Query.ExpressionVisitors;
using Microsoft.EntityFrameworkCore.Query.ExpressionVisitors.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Remotion.Linq;
using Remotion.Linq.Clauses;
using Remotion.Linq.Clauses.ResultOperators;
using Remotion.Linq.Clauses.StreamedData;
using Remotion.Linq.Parsing.Structure.IntermediateModel;

namespace Microsoft.EntityFrameworkCore
{
    public static partial class CustomExtensions
    {
        public static int CountDistinct<T, TKey>(this IQueryable<T> source, Expression<Func<T, TKey>> keySelector)
            => source.Select(keySelector).Distinct().Count();

        public static int CountDistinct<T, TKey>(this IEnumerable<T> source, Func<T, TKey> keySelector)
            => source.Select(keySelector).Distinct().Count();

        public static DbContextOptionsBuilder UseCustomExtensions(this DbContextOptionsBuilder optionsBuilder)
            => optionsBuilder
                .ReplaceService<INodeTypeProviderFactory, CustomNodeTypeProviderFactory>()
                .ReplaceService<IRelationalResultOperatorHandler, CustomRelationalResultOperatorHandler>();
    }
}

namespace Remotion.Linq.Parsing.Structure.IntermediateModel
{
    public sealed class CountDistinctExpressionNode : ResultOperatorExpressionNodeBase
    {
        public CountDistinctExpressionNode(MethodCallExpressionParseInfo parseInfo, LambdaExpression optionalSelector)
            : base(parseInfo, null, optionalSelector) { }
        public static IEnumerable<MethodInfo> GetSupportedMethods()
            => typeof(CustomExtensions).GetTypeInfo().GetDeclaredMethods("CountDistinct");
        public override Expression Resolve(ParameterExpression inputParameter, Expression expressionToBeResolved, ClauseGenerationContext clauseGenerationContext)
            => throw CreateResolveNotSupportedException();
        protected override ResultOperatorBase CreateResultOperator(ClauseGenerationContext clauseGenerationContext)
            => new CountDistinctResultOperator();
    }
}

namespace Remotion.Linq.Clauses.ResultOperators
{
    public sealed class CountDistinctResultOperator : ValueFromSequenceResultOperatorBase
    {
        public override ResultOperatorBase Clone(CloneContext cloneContext) => new CountDistinctResultOperator();
        public override StreamedValue ExecuteInMemory<T>(StreamedSequence input) => throw new NotSupportedException();
        public override IStreamedDataInfo GetOutputDataInfo(IStreamedDataInfo inputInfo) => new StreamedScalarValueInfo(typeof(int));
        public override string ToString() => "CountDistinct()";
        public override void TransformExpressions(Func<Expression, Expression> transformation) { }
    }
}

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
    public class CustomNodeTypeProviderFactory : DefaultMethodInfoBasedNodeTypeRegistryFactory
    {
        public CustomNodeTypeProviderFactory()
            => RegisterMethods(CountDistinctExpressionNode.GetSupportedMethods(), typeof(CountDistinctExpressionNode));
    }

    public class CustomRelationalResultOperatorHandler : RelationalResultOperatorHandler
    {
        private static readonly ISet<Type> AggregateResultOperators = (ISet<Type>)
            typeof(RequiresMaterializationExpressionVisitor).GetField("_aggregateResultOperators", BindingFlags.NonPublic | BindingFlags.Static)
            .GetValue(null);

        static CustomRelationalResultOperatorHandler()
            => AggregateResultOperators.Add(typeof(CountDistinctResultOperator));

        public CustomRelationalResultOperatorHandler(IModel model, ISqlTranslatingExpressionVisitorFactory sqlTranslatingExpressionVisitorFactory, ISelectExpressionFactory selectExpressionFactory, IResultOperatorHandler resultOperatorHandler)
            : base(model, sqlTranslatingExpressionVisitorFactory, selectExpressionFactory, resultOperatorHandler)
        { }

        public override Expression HandleResultOperator(EntityQueryModelVisitor entityQueryModelVisitor, ResultOperatorBase resultOperator, QueryModel queryModel)
            => resultOperator is CountDistinctResultOperator ?
                HandleCountDistinct(entityQueryModelVisitor, resultOperator, queryModel) :
                base.HandleResultOperator(entityQueryModelVisitor, resultOperator, queryModel);

        private Expression HandleCountDistinct(EntityQueryModelVisitor entityQueryModelVisitor, ResultOperatorBase resultOperator, QueryModel queryModel)
        {
            var queryModelVisitor = (RelationalQueryModelVisitor)entityQueryModelVisitor;
            var selectExpression = queryModelVisitor.TryGetQuery(queryModel.MainFromClause);
            var inputType = queryModel.SelectClause.Selector.Type;
            if (CanEvalOnServer(queryModelVisitor)
                && selectExpression != null
                && selectExpression.Projection.Count == 1)
            {
                PrepareSelectExpressionForAggregate(selectExpression, queryModel);
                var expression = selectExpression.Projection[0];
                var subExpression = new SqlFunctionExpression(
                    "DISTINCT", inputType, new[] { expression.UnwrapAliasExpression() });
                selectExpression.SetProjectionExpression(new SqlFunctionExpression(
                    "COUNT", typeof(int), new[] { subExpression }));
                return new ResultTransformingExpressionVisitor<int>(
                    queryModelVisitor.QueryCompilationContext, false)
                    .Visit(queryModelVisitor.Expression);
            }
            else
            {
                queryModelVisitor.RequiresClientResultOperator = true;
                var typeArgs = new[] { inputType };
                var distinctCall = Expression.Call(
                    typeof(Enumerable), "Distinct", typeArgs,
                    queryModelVisitor.Expression);
                return Expression.Call(
                    typeof(Enumerable), "Count", typeArgs,
                    distinctCall);
            }
        }

        private static bool CanEvalOnServer(RelationalQueryModelVisitor queryModelVisitor) =>
            !queryModelVisitor.RequiresClientEval && !queryModelVisitor.RequiresClientSelectMany &&
            !queryModelVisitor.RequiresClientJoin && !queryModelVisitor.RequiresClientFilter &&
            !queryModelVisitor.RequiresClientOrderBy && !queryModelVisitor.RequiresClientResultOperator &&
            !queryModelVisitor.RequiresStreamingGroupResultOperator;
    }
}

Tags: C#, EF Core 2.1