Most efficient way to test equality of lambda expressions

UPDATE: Due to interest in my solution, I have updated the code so that it supports arrays, `new` operators and other constructs, and compares the ASTs in a more elegant way.

Here is an improved version of Marc's code; it is now also available as a NuGet package:

/// <summary>
/// Compares lambda expressions for equivalence.
/// </summary>
/// <remarks>
/// This is not a pure AST comparison: any sub-expression that can be evaluated without the
/// lambda's parameters (constants, captured locals and field/property chains on them, array
/// initializers, conditionals whose test is itself evaluable) is evaluated and compared by value.
/// Parameters are matched by the position at which their declaring lambda lists them, not by name.
/// Node types that are not handled below throw <see cref="NotImplementedException"/>.
/// </remarks>
public static class LambdaCompare
{
    /// <summary>Returns true if the two single-argument lambdas are equivalent.</summary>
    public static bool Eq<TSource, TValue>(
        Expression<Func<TSource, TValue>> x,
        Expression<Func<TSource, TValue>> y)
    {
        return ExpressionsEqual(x, y, null);
    }

    /// <summary>Returns true if the two two-argument lambdas are equivalent.</summary>
    public static bool Eq<TSource1, TSource2, TValue>(
        Expression<Func<TSource1, TSource2, TValue>> x,
        Expression<Func<TSource1, TSource2, TValue>> y)
    {
        return ExpressionsEqual(x, y, null);
    }

    /// <summary>
    /// Returns a predicate expression that is true for any lambda equivalent to <paramref name="y"/>,
    /// for APIs that take a matcher as an expression tree.
    /// </summary>
    public static Expression<Func<Expression<Func<TSource, TValue>>, bool>> Eq<TSource, TValue>(Expression<Func<TSource, TValue>> y)
    {
        return x => ExpressionsEqual(x, y, null);
    }

    /// <summary>Recursively compares two expression nodes.</summary>
    /// <param name="x">Node from the first tree (may be null).</param>
    /// <param name="y">Node from the second tree (may be null).</param>
    /// <param name="scope">The lambdas enclosing the current nodes, innermost first; null at the root.</param>
    private static bool ExpressionsEqual(Expression x, Expression y, ParameterScope scope)
    {
        if (ReferenceEquals(x, y)) return true;
        if (x == null || y == null) return false;

        // Parameter-free sub-trees are compared by value, so a captured local equals a literal
        // holding the same value.
        var valueX = TryCalculateConstant(x);
        var valueY = TryCalculateConstant(y);
        if (valueX.IsDefined && valueY.IsDefined)
            return ValuesEqual(valueX.Value, valueY.Value);

        if (x.NodeType != y.NodeType || x.Type != y.Type)
        {
            // Two different anonymous types are reported as unsupported rather than unequal.
            if (IsAnonymousType(x.Type) && IsAnonymousType(y.Type))
                throw new NotImplementedException("Comparison of Anonymous Types is not supported");
            return false;
        }

        if (x is LambdaExpression)
        {
            var lx = (LambdaExpression)x;
            var ly = (LambdaExpression)y;
            // Push a new scope so parameters of enclosing lambdas remain resolvable in the body.
            var inner = new ParameterScope(lx, ly, scope);
            return CollectionsEqual(lx.Parameters, ly.Parameters, inner)
                   && ExpressionsEqual(lx.Body, ly.Body, inner);
        }
        if (x is MemberExpression)
        {
            var mex = (MemberExpression)x;
            var mey = (MemberExpression)y;
            return Equals(mex.Member, mey.Member) && ExpressionsEqual(mex.Expression, mey.Expression, scope);
        }
        if (x is BinaryExpression)
        {
            var bx = (BinaryExpression)x;
            var by = (BinaryExpression)y;
            return bx.Method == by.Method
                   && ExpressionsEqual(bx.Left, by.Left, scope)
                   && ExpressionsEqual(bx.Right, by.Right, scope)
                   // Coalesce nodes may carry a conversion lambda; null on both sides otherwise.
                   && ExpressionsEqual(bx.Conversion, by.Conversion, scope);
        }
        if (x is UnaryExpression)
        {
            var ux = (UnaryExpression)x;
            var uy = (UnaryExpression)y;
            return ux.Method == uy.Method && ExpressionsEqual(ux.Operand, uy.Operand, scope);
        }
        if (x is ParameterExpression)
        {
            return ParametersEqual((ParameterExpression)x, (ParameterExpression)y, scope);
        }
        if (x is MethodCallExpression)
        {
            var cx = (MethodCallExpression)x;
            var cy = (MethodCallExpression)y;
            return cx.Method == cy.Method
                   && ExpressionsEqual(cx.Object, cy.Object, scope)
                   && CollectionsEqual(cx.Arguments, cy.Arguments, scope);
        }
        if (x is MemberInitExpression)
        {
            var mix = (MemberInitExpression)x;
            var miy = (MemberInitExpression)y;
            return ExpressionsEqual(mix.NewExpression, miy.NewExpression, scope)
                   && MemberInitsEqual(mix.Bindings, miy.Bindings, scope);
        }
        if (x is NewArrayExpression)
        {
            // NodeType (NewArrayInit vs NewArrayBounds) was already checked above.
            var nx = (NewArrayExpression)x;
            var ny = (NewArrayExpression)y;
            return CollectionsEqual(nx.Expressions, ny.Expressions, scope);
        }
        if (x is NewExpression)
        {
            var nx = (NewExpression)x;
            var ny = (NewExpression)y;
            return Equals(nx.Constructor, ny.Constructor)
                   && CollectionsEqual(nx.Arguments, ny.Arguments, scope)
                   && (nx.Members == null && ny.Members == null
                       || nx.Members != null && ny.Members != null && CollectionsEqual(nx.Members, ny.Members));
        }
        if (x is ConditionalExpression)
        {
            var cx = (ConditionalExpression)x;
            var cy = (ConditionalExpression)y;
            return ExpressionsEqual(cx.Test, cy.Test, scope)
                   && ExpressionsEqual(cx.IfTrue, cy.IfTrue, scope)
                   && ExpressionsEqual(cx.IfFalse, cy.IfFalse, scope);
        }
        if (x is TypeBinaryExpression)
        {
            var tx = (TypeBinaryExpression)x;
            var ty = (TypeBinaryExpression)y;
            return tx.TypeOperand == ty.TypeOperand && ExpressionsEqual(tx.Expression, ty.Expression, scope);
        }
        if (x is InvocationExpression)
        {
            var ix = (InvocationExpression)x;
            var iy = (InvocationExpression)y;
            return ExpressionsEqual(ix.Expression, iy.Expression, scope)
                   && CollectionsEqual(ix.Arguments, iy.Arguments, scope);
        }

        throw new NotImplementedException(x.ToString());
    }

    /// <summary>
    /// Two parameters are equal when the nearest enclosing lambda that declares either of them
    /// declares both, at the same position.
    /// </summary>
    private static bool ParametersEqual(ParameterExpression px, ParameterExpression py, ParameterScope scope)
    {
        for (var level = scope; level != null; level = level.Parent)
        {
            var indexX = level.X.Parameters.IndexOf(px);
            var indexY = level.Y.Parameters.IndexOf(py);
            if (indexX >= 0 || indexY >= 0)
                return indexX == indexY;
        }

        // Declared by no enclosing lambda (a free parameter); identical references were
        // already accepted by the ReferenceEquals check in ExpressionsEqual.
        return false;
    }

    /// <summary>Heuristic: compiler-generated type whose name contains "AnonymousType".</summary>
    private static bool IsAnonymousType(Type type)
    {
        var hasCompilerGeneratedAttribute = type.GetCustomAttributes(typeof(CompilerGeneratedAttribute), false).Any();
        var nameContainsAnonymousType = type.FullName != null && type.FullName.Contains("AnonymousType");
        return hasCompilerGeneratedAttribute && nameContainsAnonymousType;
    }

    /// <summary>
    /// Compares member-initializer bindings. Bindings are matched by member name, so initializer
    /// order does not affect the result. Only assignment bindings are supported.
    /// </summary>
    private static bool MemberInitsEqual(ICollection<MemberBinding> bindingsX, ICollection<MemberBinding> bindingsY, ParameterScope scope)
    {
        if (bindingsX.Count != bindingsY.Count)
            return false;

        if (bindingsX.Concat(bindingsY).Any(b => b.BindingType != MemberBindingType.Assignment))
            throw new NotImplementedException("Only MemberBindingType.Assignment is supported");

        var sortedX = bindingsX.Cast<MemberAssignment>().OrderBy(b => b.Member.Name);
        var sortedY = bindingsY.Cast<MemberAssignment>().OrderBy(b => b.Member.Name);
        return sortedX
            .Zip(sortedY, (a, b) => Equals(a.Member, b.Member) && ExpressionsEqual(a.Expression, b.Expression, scope))
            .All(equal => equal);
    }

    /// <summary>Compares evaluated values; collections are compared element-wise, recursively.</summary>
    private static bool ValuesEqual(object x, object y)
    {
        if (ReferenceEquals(x, y))
            return true;
        if (x is ICollection && y is ICollection)
            return CollectionsEqual((ICollection)x, (ICollection)y);

        return Equals(x, y);
    }

    /// <summary>
    /// Tries to evaluate <paramref name="e"/> without any lambda parameters. Returns an undefined
    /// value (IsDefined == false) when the node depends on parameters or cannot be evaluated.
    /// </summary>
    private static ConstantValue TryCalculateConstant(Expression e)
    {
        if (e == null)
            return default(ConstantValue);

        if (e is ConstantExpression)
            return new ConstantValue(true, ((ConstantExpression)e).Value);

        if (e is MemberExpression)
        {
            var me = (MemberExpression)e;
            var parentValue = TryCalculateConstant(me.Expression);
            // An instance member on a null target cannot be read; leave it to structural comparison.
            if (parentValue.IsDefined && parentValue.Value != null)
            {
                try
                {
                    var result =
                        me.Member is FieldInfo
                            ? ((FieldInfo)me.Member).GetValue(parentValue.Value)
                            : ((PropertyInfo)me.Member).GetValue(parentValue.Value);
                    return new ConstantValue(true, result);
                }
                catch (TargetInvocationException)
                {
                    // The property getter threw; deliberately fall back to structural comparison.
                }
            }
        }

        // Only array initializers are folded: the operands of NewArrayBounds are sizes, not
        // elements, and must not be confused with an initializer holding the same numbers.
        if (e.NodeType == ExpressionType.NewArrayInit)
        {
            var items = ((NewArrayExpression)e).Expressions.Select(TryCalculateConstant).ToList();
            if (items.All(i => i.IsDefined))
                return new ConstantValue(true, items.Select(i => i.Value).ToArray());
        }

        if (e is ConditionalExpression)
        {
            var ce = (ConditionalExpression)e;
            var evaluatedTest = TryCalculateConstant(ce.Test);
            if (evaluatedTest.IsDefined)
            {
                return TryCalculateConstant(Equals(evaluatedTest.Value, true) ? ce.IfTrue : ce.IfFalse);
            }
        }

        return default(ConstantValue);
    }

    /// <summary>Positional comparison of two expression sequences.</summary>
    private static bool CollectionsEqual(IEnumerable<Expression> x, IEnumerable<Expression> y, ParameterScope scope)
    {
        var listX = x.ToList();
        var listY = y.ToList();
        return listX.Count == listY.Count
               && listX.Zip(listY, (a, b) => ExpressionsEqual(a, b, scope)).All(equal => equal);
    }

    /// <summary>Positional comparison of two untyped collections, using <see cref="ValuesEqual"/>.</summary>
    private static bool CollectionsEqual(ICollection x, ICollection y)
    {
        return x.Count == y.Count
               && x.Cast<object>().Zip(y.Cast<object>(), (a, b) => ValuesEqual(a, b)).All(equal => equal);
    }

    /// <summary>Result of <see cref="TryCalculateConstant"/>: whether a value was computed, and the value.</summary>
    private struct ConstantValue
    {
        public ConstantValue(bool isDefined, object value)
            : this()
        {
            IsDefined = isDefined;
            Value = value;
        }

        public bool IsDefined { get; private set; }

        public object Value { get; private set; }
    }

    /// <summary>One level of lambda nesting on both sides, linked to the enclosing level.</summary>
    private sealed class ParameterScope
    {
        public ParameterScope(LambdaExpression x, LambdaExpression y, ParameterScope parent)
        {
            X = x;
            Y = y;
            Parent = parent;
        }

        public LambdaExpression X { get; private set; }

        public LambdaExpression Y { get; private set; }

        public ParameterScope Parent { get; private set; }
    }
}

Note that it does not compare the full AST. Instead, it collapses constant expressions and compares their values rather than their ASTs. This is useful for validating mocks when the lambda references a local variable: in that case the variable is compared by its value.

Unit tests:

[TestClass]
public class Tests
{
    // Same format string, captured constant and string literal, written differently on each
    // side and with different parameter names.
    [TestMethod]
    public void BasicConst()
    {
        var expected = GetBasicExpr1();
        var actual = GetBasicExpr2();

        Assert.IsTrue(LambdaCompare.Eq(expected, actual));
    }

    // Static method call over an instance method call, with only the parameter name differing.
    [TestMethod]
    public void PropAndMethodCall()
    {
        var expected = GetPropAndMethodExpr1();
        var actual = GetPropAndMethodExpr2();

        Assert.IsTrue(LambdaCompare.Eq(expected, actual));
    }

    // Member initializers listed in a different order; one side uses a conditional whose test
    // is a captured local.
    [TestMethod]
    public void MemberInitWithConditional()
    {
        var expected = GetMemberInitExpr1();
        var actual = GetMemberInitExpr2();

        Assert.IsTrue(LambdaCompare.Eq(expected, actual));
    }

    // The two anonymous types declare their members in a different order. The expressions are
    // built but not compared; the test is reported as inconclusive.
    [TestMethod]
    public void AnonymousType()
    {
        var expected = GetAnonymousExpr1();
        var actual = GetAnonymousExpr2();

        Assert.Inconclusive("Anonymous Types are not supported");
    }

    private static Expression<Func<int, string, string>> GetBasicExpr1()
    {
        var offset = 25;
        return (first, second) =>
            string.Format("{0}{1}{2}{3}", (first + offset).ToString(CultureInfo.InvariantCulture), first + second,
                "some const value".ToUpper(), offset);
    }

    private static Expression<Func<int, string, string>> GetBasicExpr2()
    {
        var upperSource = "some const value";
        var format = "{0}{1}{2}{3}";
        return (i, s) =>
            string.Format(format, (i + 25).ToString(CultureInfo.InvariantCulture), i + s, upperSource.ToUpper(), 25);
    }

    private static Expression<Func<Uri, bool>> GetPropAndMethodExpr1()
    {
        return arg1 => Uri.IsWellFormedUriString(arg1.ToString(), UriKind.Absolute);
    }

    private static Expression<Func<Uri, bool>> GetPropAndMethodExpr2()
    {
        return u => Uri.IsWellFormedUriString(u.ToString(), UriKind.Absolute);
    }

    private static Expression<Func<Uri, UriBuilder>> GetMemberInitExpr1()
    {
        var httpsPort = 443;
        return x => new UriBuilder(x) { Port = httpsPort, Host = string.IsNullOrEmpty(x.Host) ? "abc" : "def" };
    }

    private static Expression<Func<Uri, UriBuilder>> GetMemberInitExpr2()
    {
        var useHttps = true;
        return u => new UriBuilder(u) { Host = string.IsNullOrEmpty(u.Host) ? "abc" : "def", Port = useHttps ? 443 : 80 };
    }

    private static Expression<Func<Uri, object>> GetAnonymousExpr1()
    {
        return x => new { Port = 443, x.Host, Addr = x.AbsolutePath };
    }

    private static Expression<Func<Uri, object>> GetAnonymousExpr2()
    {
        return u => new { u.Host, Port = 443, Addr = u.AbsolutePath };
    }
}

Hmmm... I guess you'd have to walk the tree, checking the node type and member of each node. I'll knock up an example:

using System;
using System.Linq.Expressions;
class Test {
    public string Foo { get; set; }
    public string Bar { get; set; }

    /// <summary>
    /// Demo entry point: compares x => x.Bar with y => y.Bar, and x => x.Foo with y => y.Bar,
    /// then prints both results.
    /// </summary>
    static void Main()
    {
        bool test1 = FuncTest<Test>.FuncEqual(x => x.Bar, y => y.Bar),
            test2 = FuncTest<Test>.FuncEqual(x => x.Foo, y => y.Bar);

        // Without output the demo computes the comparisons and shows nothing.
        Console.WriteLine("x => x.Bar vs y => y.Bar: {0}", test1);
        Console.WriteLine("x => x.Foo vs y => y.Bar: {0}", test2);
    }

}
/// <summary>
/// Call-site convenience for <c>FuncTest.FuncEqual</c>. Lambdas with untyped parameters cannot
/// have their source type inferred, so calling the doubly-generic method directly requires both
/// type arguments. Fixing <typeparamref name="TSource"/> on this class leaves only
/// <c>TValue</c>, which is inferred from the lambda bodies.
/// </summary>
static class FuncTest<TSource>
{
    public static bool FuncEqual<TValue>(
        Expression<Func<TSource, TValue>> x,
        Expression<Func<TSource, TValue>> y)
    {
        // Both type arguments are inferable here because x and y are already strongly typed.
        var result = FuncTest.FuncEqual(x, y);
        return result;
    }
}
static class FuncTest {
    /// <summary>
    /// Returns true when two lambdas have the same shape. Supports lambdas, member access
    /// (including the target expression), parameters (matched by position) and constants; any
    /// other node type throws <see cref="NotImplementedException"/>.
    /// </summary>
    public static bool FuncEqual<TSource, TValue>(
        Expression<Func<TSource,TValue>> x,
        Expression<Func<TSource,TValue>> y)
    {
        return ExpressionEqual(x, y, null);
    }

    /// <param name="scope">Enclosing lambdas, innermost first; null at the root.</param>
    private static bool ExpressionEqual(Expression x, Expression y, Scope scope)
    {
        // deal with the simple cases first...
        if (ReferenceEquals(x, y)) return true;
        if (x == null || y == null) return false;
        if (   x.NodeType != y.NodeType
            || x.Type != y.Type ) return false;

        switch (x.NodeType)
        {
            case ExpressionType.Lambda:
            {
                var lx = (LambdaExpression)x;
                var ly = (LambdaExpression)y;
                return lx.Parameters.Count == ly.Parameters.Count
                    && ExpressionEqual(lx.Body, ly.Body, new Scope(lx, ly, scope));
            }
            case ExpressionType.MemberAccess:
            {
                MemberExpression mex = (MemberExpression)x, mey = (MemberExpression)y;
                // Compare the target as well, so u.Host.Length and u.AbsolutePath.Length differ.
                return mex.Member == mey.Member && ExpressionEqual(mex.Expression, mey.Expression, scope);
            }
            case ExpressionType.Parameter:
                return ParameterEqual((ParameterExpression)x, (ParameterExpression)y, scope);
            case ExpressionType.Constant:
                // Reached e.g. for the closure object behind a captured local.
                return Equals(((ConstantExpression)x).Value, ((ConstantExpression)y).Value);
            default:
                throw new NotImplementedException(x.NodeType.ToString());
        }
    }

    // Resolves both parameters against the nearest enclosing lambda that declares either one;
    // they are equal only if that lambda declares both at the same position.
    private static bool ParameterEqual(ParameterExpression px, ParameterExpression py, Scope scope)
    {
        for (var level = scope; level != null; level = level.Parent)
        {
            int ix = level.X.Parameters.IndexOf(px), iy = level.Y.Parameters.IndexOf(py);
            if (ix >= 0 || iy >= 0)
                return ix == iy;
        }
        return false; // declared by no enclosing lambda
    }

    // One level of lambda nesting on both sides, linked to the enclosing level.
    private sealed class Scope
    {
        public readonly LambdaExpression X;
        public readonly LambdaExpression Y;
        public readonly Scope Parent;

        public Scope(LambdaExpression x, LambdaExpression y, Scope parent)
        {
            X = x;
            Y = y;
            Parent = parent;
        }
    }
}

A canonical solution would be great. In the meantime, I created an IEqualityComparer<Expression> version. It is a rather verbose implementation, so I put it in a gist.

It is intended to be a comprehensive abstract syntax tree comparer. To that end, it compares every expression type, including those that C# lambdas cannot yet produce, such as Try, Switch and Block. The only types it does not compare are Goto, Label, Loop and DebugInfo, due to my limited knowledge of them.

You can specify whether and how names of parameters and lambdas should be compared, as well as how to handle ConstantExpression.

It tracks parameters positionally by context. Lambdas inside lambdas and catch block variable parameters are supported.

Tags:

C#

Lambda