Count balanced binary strings matching any of a set of masks

Ruby, pretty fast, but it depends on the input

Now sped up by a factor of 2 to 2.5 by switching from strings to integers.

Usage:

cat <input> | ruby this.script.rb

Eg.

mad_gaksha@madlab ~/tmp $ ruby c50138.rb < c50138.inp2
number of matches: 298208861472
took 0.05726237 s

The number of matches for a single mask is readily calculated with a binomial coefficient. For example, 122020 has three 2s that must be filled with one 0 and two 1s so that the string is balanced. Thus there are nCr(3,2)=nCr(3,1)=3!/(2!*1!)=3 different balanced binary strings matching this mask.
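
The single-mask count can be sketched as follows. This is only an illustration, not the code of this answer: the names binomial and count_single_mask are made up here, and masks are assumed to be plain strings with 2 as the wildcard.

```ruby
# Binomial coefficient nCr(n, k) with exact integer arithmetic.
def binomial(n, k)
  return 0 if k < 0 || k > n
  (1..k).reduce(1) { |acc, i| acc * (n - k + i) / i }
end

# Number of balanced strings (as many 0s as 1s) matching a single mask.
def count_single_mask(mask)
  half = mask.size / 2
  zeros_needed = half - mask.count('0')
  ones_needed  = half - mask.count('1')
  # Too many fixed 0s or 1s: no balanced string can match.
  return 0 if zeros_needed < 0 || ones_needed < 0
  binomial(zeros_needed + ones_needed, ones_needed)
end

puts count_single_mask('122020')  # prints 3
```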

An intersection of n masks m_1, m_2, ... m_n is a mask q such that a binary string s matches q if and only if it matches all masks m_i.

If we take two masks m_1 and m_2, their intersection is easily computed: just set m_1[i]=m_2[i] wherever m_1[i]==2. If the two masks have different fixed digits at some position, no string matches both and the intersection is empty. The intersection of 122020 and 111222 is 111020:

122020 (matched by 3 strings, 111000 110010 101010)
111222 (matched by 1 string, 111000)
111020 (matched by 1 string, 111000)

The two individual masks are matched by 3+1=4 strings and the intersection mask is matched by one string, so there are 3+1-1=3 unique strings matching one or both masks.
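
A minimal sketch of the pairwise intersection, assuming masks are strings over 0, 1 and 2. The helper name intersect is hypothetical and is not the diff_mask function used in the code further down.

```ruby
# Intersect two masks given as strings over 0, 1, 2 (2 = wildcard).
# Returns nil when the masks fix different digits at the same position.
def intersect(a, b)
  merged = a.chars.zip(b.chars).map do |x, y|
    return nil if x != '2' && y != '2' && x != y
    x == '2' ? y : x
  end
  merged.join
end

puts intersect('122020', '111222')  # prints 111020
p intersect('10', '01')             # prints nil
```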

I'll call N(m_1,m_2,...) the number of strings matched by all m_i. Applying the same logic as above, the number of unique strings matched by at least one mask is given by the inclusion-exclusion principle (a proof follows after the code):

N(m_1) + N(m_2) + ... + N(m_n) - N(m_1,m_2) - ... - N(m_(n-1),m_n) + N(m_1,m_2,m_3) + N(m_1,m_2,m_4) + ... + N(m_(n-2),m_(n-1),m_n) - N(m_1,m_2,m_3,m_4) -+ ...
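
For tiny inputs the formula can be checked against plain enumeration. The sketch below assumes string masks with 2 as the wildcard; all helper names are made up for this illustration, and it deliberately visits every subset, which is only feasible for a handful of short masks.

```ruby
# Intersect two masks (strings over 0, 1, 2); nil if they conflict.
def intersect(a, b)
  merged = a.chars.zip(b.chars).map do |x, y|
    return nil if x != '2' && y != '2' && x != y
    x == '2' ? y : x
  end
  merged.join
end

# Does the binary string s match the mask?
def matches?(s, mask)
  s.chars.zip(mask.chars).all? { |c, m| m == '2' || c == m }
end

# All balanced binary strings of the given even length.
def balanced_strings(len)
  (0...2**len).map { |v| v.to_s(2).rjust(len, '0') }
              .select { |s| s.count('1') == len / 2 }
end

# N(m_1, ..., m_k): balanced strings matched by every mask in the subset.
def n_of(subset, universe)
  q = subset.reduce { |a, b| a && intersect(a, b) }
  q ? universe.count { |s| matches?(s, q) } : 0
end

# Alternating sum over all non-empty subsets of the masks.
def inclusion_exclusion(masks)
  universe = balanced_strings(masks.first.size)
  (1..masks.size).sum do |k|
    sign = k.odd? ? 1 : -1
    masks.combination(k).sum { |subset| sign * n_of(subset, universe) }
  end
end

# Direct count of strings matched by at least one mask.
def brute_force(masks)
  balanced_strings(masks.first.size).count do |s|
    masks.any? { |m| matches?(s, m) }
  end
end

masks = %w[122020 111222 220011]
puts inclusion_exclusion(masks)  # prints 5
puts brute_force(masks)          # prints 5
```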

There are many, many, many combinations of taking, say 30 masks out of 200.

So this solution assumes that few high-order intersections of the input masks exist, i.e. most n-tuples of n>2 masks will not have any common matches.

Use the code here; the code at ideone may be outdated.

  • Test case 1: number of matches: 6
  • Test case 2: number of matches: 184756
  • Test case 3: number of matches: 298208861472
  • Test case 4: number of matches: 5

I added a function remove_duplicates that can be used to pre-process the input and delete masks m_i such that all strings that match it also match another mask m_j. For the current input, this actually takes longer as there are no such masks (or not many), so the function isn't applied to the data yet in the code below.
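
The containment test behind such a pre-processing step can be sketched like this. It is an illustration on string masks with 2 as the wildcard, with hypothetical names covered_by? and drop_redundant. The test is sufficient but not exhaustive: it only looks at fixed digits, not at whether a balanced string exists at all.

```ruby
# True when every string matching `inner` also matches `outer`:
# wherever `outer` fixes a digit, `inner` must fix the same digit.
def covered_by?(inner, outer)
  inner.chars.zip(outer.chars).all? { |i, o| o == '2' || o == i }
end

# Drop exact duplicates first, then every mask covered by a different mask.
def drop_redundant(masks)
  unique = masks.uniq
  unique.reject do |y|
    unique.any? { |x| !x.equal?(y) && covered_by?(y, x) }
  end
end

p drop_redundant(%w[1220 1020 2211 1220])  # prints ["1220", "2211"]
```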

Code:

# factorial table
FAC = [1]
def gen_fac(n)
  n.times do |i|
    FAC << FAC[i]*(i+1)
  end
end

# generates the intersection mask, matched exactly by the strings that match both m and n
# returns nil if m and n fix different digits at some position
def diff_mask(m,n)
  (0..m.size-1).map do |i|
    c1 = m[i]
    c2 = n[i]
    c1^c2==1 ? break : c1&c2
  end
end

# counts the number of possible balanced strings matching the mask
def count_mask(m)
  n = m.size/2
  c0 = n-m.count(0)
  c1 = n-m.count(1)
  if c0<0 || c1<0
    0
  else
    FAC[c0+c1]/(FAC[c0]*FAC[c1])
  end
end

# removes masks contained in another (masks are arrays of 0, 1 and 3, with 3 as wildcard)
def remove_duplicates(m)
  m.each do |x|
    # every wildcard 3 in x matches any digit of y
    r = /\A#{x.join.gsub(?3,?.)}\z/
    m.delete_if do |y|
      !x.equal?(y) && y.join =~ r
    end
  end
end

# yields the intersection mask of every combination of n masks out of m.size masks, skipping empty intersections
def mask_diff_combinations(m,n=1,s=m.size,diff1=[3]*m[0].size,j=-1,&b)
  (j+1..s-1).each do |i|
    diff2 = diff_mask(diff1,m[i])
    if diff2
      mask_diff_combinations(m,n+1,s,diff2,i,&b) if n<s
      yield diff2,n
    end
  end
end

# counts the number of balanced strings matched by at least one mask
def count_n_masks(m)
  sum = 0
  mask_diff_combinations(m) do |mask,i|
    sum += i%2==1 ? count_mask(mask) : -count_mask(mask)
  end
  sum
end

time = Time.now

# parse input
d = STDIN.each_line.map do |line|
  line.chomp.strip.gsub('2','3')
end
d.delete_if(&:empty?)
d.shift
d.map!{|x|x.chars.map(&:to_i)}

# generate factorial table
gen_fac([d.size,d[0].size].max+1)

# count masks
puts "number of matches: #{count_n_masks(d)}"
puts "took #{Time.now-time} s"

This is called the inclusion-exclusion principle, but before somebody pointed me to it I had worked out my own proof, so here it goes. Doing something yourself feels great, after all.

Let us first consider the case of 2 masks; call them 0 and 1. We take every balanced binary string and classify it according to which mask(s) it matches. c0 is the number of those that match only mask 0, c1 the number of those that match only mask 1, and c01 the number of those that match both mask 0 and mask 1.

Let s0 be the sum of the number of matches for each mask (they may overlap). Let s1 be the sum of the number of matches for each pair (2-combination) of masks. In general, let s_i be the sum of the number of matches for each (i+1)-combination of masks, where the number of matches of a combination is the number of binary strings matching all masks in it.

If there are n masks, the desired output is the sum of all c's, i.e. c = c0+...+c(n-1)+c01+c02+...+c(n-2)(n-1)+c012+...+c(n-3)(n-2)(n-1)+...+c0123...(n-2)(n-1). What the program computes is the alternating sum of all s's, i.e. s = s_0-s_1+s_2-+...+-s_(n-1). We wish to prove that s==c.

n=1 is obvious. Consider n=2. Counting all matches of mask 0 gives c0+c01 (the number of strings matching only 0 plus those matching both 0 and 1), and counting all matches of mask 1 gives c1+c01. We can illustrate this as follows:

0: c0 c01
1: c1 c10

By definition, s0 is the sum of these two rows, and s1 is the sum of the number of matches of each 2-combination of [0,1], i.e. all unique c_ij's. Keep in mind that c01=c10.

s0 = c0 + c1 + 2 c01
s1 = c01
s = s0 - s1 = c0 + c1 + c01 = c

Thus s=c for n=2.
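
The n=2 bookkeeping can be made concrete by brute force on the two example masks from the top of this answer. This is a sketch with made-up helper names; it simply enumerates all balanced strings of length 6.

```ruby
# All balanced binary strings of the given even length.
def balanced_strings(len)
  (0...2**len).map { |v| v.to_s(2).rjust(len, '0') }
              .select { |s| s.count('1') == len / 2 }
end

# Does the binary string s match the mask (2 = wildcard)?
def matches?(s, mask)
  s.chars.zip(mask.chars).all? { |c, m| m == '2' || c == m }
end

# Returns the quantities c0, c1, c01, s0 and s1 for masks m0 and m1.
def two_mask_tally(m0, m1)
  strings = balanced_strings(m0.size)
  hits0 = strings.select { |s| matches?(s, m0) }
  hits1 = strings.select { |s| matches?(s, m1) }
  c01 = (hits0 & hits1).size
  { c0: hits0.size - c01, c1: hits1.size - c01, c01: c01,
    s0: hits0.size + hits1.size, s1: c01 }
end

t = two_mask_tally('122020', '111222')
puts t[:s0] - t[:s1]             # s, prints 3
puts t[:c0] + t[:c1] + t[:c01]   # c, prints 3
```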

Now consider n=3.

0  : c0 + c01 + c02 + c012
1  : c1 + c01 + c12 + c012
2  : c2 + c12 + c02 + c012
01 : c01 + c012
02 : c02 + c012
12 : c12 + c012
012: c012

s0 = c0 + c1 + c2 + 2 (c01+c02+c12) + 3 c012
s1 = c01 + c02 + c12 + 3 c012
s2 = c012

s0 = c__0 + 2 c__1 + 3 c__2
s1 =          c__1 + 3 c__2
s2 =                   c__2

s = s0 - s1 + s2 = ... = c0 + c1 + c2 + c01 + c02 + c12 + c012 = c__0 + c__1 + c__2 = c

Thus s=c for n=3. c__i represents the sum of all c's with (i+1) indices, e.g. c__1 = c01 for n=2 and c__1 = c01 + c02 + c12 for n=3.

For n=4, a pattern starts to emerge:

0:   c0 + c01 + c02 + c03 + c012 + c013 + c023 + c0123
1:   c1 + c01 + c12 + c13 + c012 + c013 + c123 + c0123
2:   c2 + c02 + c12 + c23 + c012 + c023 + c123 + c0123
3:   c3 + c03 + c13 + c23 + c013 + c023 + c123 + c0123

01:  c01 + c012 + c013 + c0123
02:  c02 + c012 + c023 + c0123
03:  c03 + c013 + c023 + c0123
12:  c12 + c012 + c123 + c0123
13:  c13 + c013 + c123 + c0123
23:  c23 + c023 + c123 + c0123

012:  c012 + c0123
013:  c013 + c0123
023:  c023 + c0123
123:  c123 + c0123

0123: c0123

s0 = c__0 + 2 c__1 + 3 c__2 + 4 c__3
s1 =          c__1 + 3 c__2 + 6 c__3
s2 =                   c__2 + 4 c__3
s3 =                            c__3

s = s0 - s1 + s2 - s3 = c__0 + c__1 + c__2 + c__3 = c

Thus s==c for n=4.

In general, we get binomial coefficients like this (↓ is i, → is j):

   0  1  2  3  4  5  6  .  .  .

0  1  2  3  4  5  6  7  .  .  .
1     1  3  6  10 15 21 .  .  .
2        1  4  10 20 35 .  .  .
3           1  5  15 35 .  .  .
4              1  6  21 .  .  .
5                 1  7  .  .  .
6                    1  .  .  . 
.                       .
.                          .
.                             .

To see this, consider that for some i and j, there are:

  • x = ncr(n,i+1): combinations C for the intersection of (i+1) masks out of n
  • y = ncr(n-i-1,j-i): for each combination C above, there are y different combinations for the intersection of (j+1) masks out of those containing C
  • z = ncr(n,j+1): different combinations for the intersection of (j+1) masks out of n

As that may sound confusing, here's the definition applied to an example. For i=1, j=2, n=4, it looks like this (cf. above):

01:  c01 + c012 + c013 + c0123
02:  c02 + c012 + c023 + c0123
03:  c03 + c013 + c023 + c0123
12:  c12 + c012 + c123 + c0123
13:  c13 + c013 + c123 + c0123
23:  c23 + c023 + c123 + c0123

So here x=6 (01, 02, 03, 12, 13, 23), y=2 (two c's with three indices for each combination), z=4 (c012, c013, c023, c123).

In total, there are x*y coefficients c with (j+1) indices, and there are z different ones, so each occurs x*y/z times, which we call the coefficient k_ij. By simple algebra, we get k_ij = ncr(n,i+1) ncr(n-i-1,j-i) / ncr(n,j+1) = ncr(j+1,i+1).
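
The "simple algebra" can be written out by expanding the three binomial coefficients into factorials, using only the definitions of x, y and z given above:

```latex
\[
k_{ij}
  = \frac{\binom{n}{i+1}\binom{n-i-1}{j-i}}{\binom{n}{j+1}}
  = \frac{n!}{(i+1)!\,(n-i-1)!}
    \cdot \frac{(n-i-1)!}{(j-i)!\,(n-j-1)!}
    \cdot \frac{(j+1)!\,(n-j-1)!}{n!}
  = \frac{(j+1)!}{(i+1)!\,(j-i)!}
  = \binom{j+1}{i+1}
\]
```

All factors involving n cancel, which is why the table of coefficients does not depend on the number of masks.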

So the coefficient is given by k_ij = nCr(j+1,i+1). Recalling the definitions, all we need to show is that the alternating sum of each column gives 1.

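The alternating sum s0 - s1 + s2 - s3 +- ... +- s(n-1) can be regrouped by columns. Let t_j be (-1)^j times the total contribution of column j. Note that nCr(j+1,i+1) = 0 for i > j, and that the full alternating sum ∑[(-1)^i nCr(j+1,i)] for i=0..j+1 is (1-1)^(j+1) = 0:

t_j = c__j * ∑[(-1)^(i+j) k_ij] for i=0..n-1
    = c__j * ∑[(-1)^(i+j) nCr(j+1,i+1)] for i=0..n-1
    = c__j * ( (-1)^j nCr(j+1,0) - ∑[(-1)^(i+j) nCr(j+1,i)] for i=0..j+1 )
    = c__j * ( (-1)^j - 0 )
    = (-1)^j c__j

s   = ∑[(-1)^j t_j] for j=0..n-1
    = ∑[(-1)^j (-1)^j c__j] for j=0..n-1
    = ∑[c__j] for j=0..n-1
    = c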

Thus s=c for all n=1,2,3,...
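
Both facts used in the proof can be checked numerically for small n. The following is a sanity check under the definitions above, not a replacement for the argument; the helper names are invented for this sketch.

```ruby
# Binomial coefficient nCr(n, k); zero outside 0..n.
def ncr(n, k)
  return 0 if k < 0 || k > n
  (1..k).reduce(1) { |acc, i| acc * (n - k + i) / i }
end

# k_ij computed from its definition x * y / z, kept exact with Rational.
def k_from_definition(n, i, j)
  Rational(ncr(n, i + 1) * ncr(n - i - 1, j - i), ncr(n, j + 1))
end

# Is x * y / z equal to nCr(j+1, i+1) for all 0 <= i <= j < n?
def identity_holds?(n)
  (0...n).all? do |j|
    (0..j).all? { |i| k_from_definition(n, i, j) == ncr(j + 1, i + 1) }
  end
end

# Is the alternating sum of every column of the table equal to 1?
def column_sums_one?(n)
  (0...n).all? do |j|
    (0...n).sum { |i| (-1)**i * ncr(j + 1, i + 1) } == 1
  end
end

puts identity_holds?(8)   # prints true
puts column_sums_one?(8)  # prints true
```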


C

If you're not on Linux, or are otherwise having trouble compiling, you should probably remove the timing code (clock_gettime).

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

long int binomial(int n, int m) {
  if(m > n/2) {
    m = n - m;
  }
  int i;
  long int result = 1;
  for(i = 0; i < m; i++) {
    result *= n - i;
    result /= i + 1;
  }
  return result;
}

typedef struct isct {
  char *mask;
  int p_len;
  int *p;
} isct;

long int mask_intersect(char *mask1, char *mask2, char *mask_dest, int n) {

  int zero_count = 0;
  int any_count = 0;
  int i;
  for(i = 0; i < n; i++) {
    if(mask1[i] == '2') {
      mask_dest[i] = mask2[i];
    } else if (mask2[i] == '2') {
      mask_dest[i] = mask1[i];
    } else if (mask1[i] == mask2[i]) {
      mask_dest[i] = mask1[i];
    } else {
      return 0;
    }
    if(mask_dest[i] == '2') {
      any_count++;
    } else if (mask_dest[i] == '0') {
      zero_count++;
    }
  }
  if(zero_count > n/2 || any_count + zero_count < n/2) {
    return 0;
  }
  return binomial(any_count, n/2 - zero_count);
}

int main() {
  
  struct timespec start, end;
  clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &start);
  
  int n;
  scanf("%d", &n);
  int nn = 2 * n;

  int m = 0;
  int m_max = 1024;

  char **masks = malloc(m_max * sizeof(char *));

  while(1) {
    masks[m] = malloc(nn + 1);
    if (scanf("%s", masks[m]) == EOF) {
      break;
    }
    m++;
    if (m == m_max) {
      m_max *= 2;
      masks = realloc(masks, m_max * sizeof(char *));
    }
  }

  int i = 1;
  int i_max = 1024 * 128;

  isct *iscts = malloc(i_max * sizeof(isct));

  iscts[0].mask = malloc(nn);
  iscts[0].p = malloc(m * sizeof(int));

  int j;
  for(j = 0; j < nn; j++) {
    iscts[0].mask[j] = '2';
  }
  for(j = 0; j < m; j++) {
    iscts[0].p[j] = j;
  }
  iscts[0].p_len = m;

  int i_start = 0;
  int i_end = 1;
  int sign = 1;

  long int total = 0;
  
  int mask_bin_len = 1024 * 1024;
  char* mask_bin = malloc(mask_bin_len);
  int mask_bin_count = 0;
  
  int p_bin_len = 1024 * 128;
  int* p_bin = malloc(p_bin_len * sizeof(int));
  int p_bin_count = 0;
  
  
  while (i_end > i_start) {
    for (j = i_start; j < i_end; j++) {
      if (i + iscts[j].p_len > i_max) {
        i_max *= 2;
        iscts = realloc(iscts, i_max * sizeof(isct));
      }
      isct *isct_orig = iscts + j;
      int x;
      int x_len = 0;
      int i0 = i;
      for (x = 0; x < isct_orig->p_len; x++) {
        if(mask_bin_count + nn > mask_bin_len) {
          mask_bin_len *= 2;
          mask_bin = malloc(mask_bin_len);
          mask_bin_count = 0;
        }
        iscts[i].mask = mask_bin + mask_bin_count;
        mask_bin_count += nn;
        long int count =
            mask_intersect(isct_orig->mask, masks[isct_orig->p[x]], iscts[i].mask, nn);
        if (count > 0) {
          isct_orig->p[x_len] = isct_orig->p[x];
          i++;
          x_len++;
          total += sign * count;
        }
      }
      for (x = 0; x < x_len; x++) {
        int p_len = x_len - x - 1;
        iscts[i0 + x].p_len = p_len;
        if(p_bin_count + p_len > p_bin_len) {
          p_bin_len *= 2;
          p_bin = malloc(p_bin_len * sizeof(int));
          p_bin_count = 0;
        }
        iscts[i0 + x].p = p_bin + p_bin_count;
        p_bin_count += p_len;
        int y;
        for (y = 0; y < p_len; y++) {
          iscts[i0 + x].p[y] = isct_orig->p[x + y + 1];
        }
      }
    }

    sign *= -1;
    i_start = i_end;
    i_end = i;

  }
  
  printf("%ld\n", total);
  
  clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &end);
  
  int seconds = end.tv_sec - start.tv_sec;
  long nanoseconds = end.tv_nsec - start.tv_nsec;
  if(nanoseconds < 0) {
    nanoseconds += 1000000000;
    seconds--;
  }
  
  printf("%d.%09lds\n", seconds, nanoseconds);
  return 0;
}

Example cases:

robert@unity:~/c/se-mask$ gcc -O3 se-mask.c -lrt -o se-mask
robert@unity:~/c/se-mask$ head testcase-long
30
210211202222222211222112102111220022202222210122222212220210
010222222120210221012002220212102220002222221122222220022212
111022212212022222222220111120022120122121022212211202022010
022121221020201212200211120100202222212222122222102220020212
112200102110212002122122011102201021222222120200211222002220
121102222220221210220212202012110201021201200010222200221002
022220200201222002020110122212211202112011102220212120221111
012220222200211200020022121202212222022012201201210222200212
210211221022122020011220202222010222011101220121102101200122
robert@unity:~/c/se-mask$ ./se-mask < testcase-long
298208861472
0.001615834s
robert@unity:~/c/se-mask$ head testcase-hard
8
0222222222222222
1222222222222222
2022222222222222
2122222222222222
2202222222222222
2212222222222222
2220222222222222
2221222222222222
2222022222222222
robert@unity:~/c/se-mask$ ./se-mask < testcase-hard
12870
3.041261458s
robert@unity:~/c/se-mask$ 

(Times are for an i7-4770K CPU at 4.1 GHz.) Be careful, testcase-hard uses around 3-4 GB of memory.

This is pretty much an implementation of the inclusion-exclusion method blutorange came up with, but done so that it will handle intersections of any depth. The code as written spends a lot of time on memory allocation, and will get even faster once I optimize the memory management.
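
The level-by-level idea can be sketched in a few lines of Ruby, as I read it from the C code above: every intersection carries a list of masks it may still be combined with, and a mask that yields a zero count is dropped from that list, since any deeper intersection containing it is empty as well. The names below are invented for this sketch and the memory layout of the C version is not reproduced.

```ruby
# Binomial coefficient nCr(n, k); zero outside 0..n.
def binomial(n, k)
  return 0 if k < 0 || k > n
  (1..k).reduce(1) { |acc, i| acc * (n - k + i) / i }
end

# Intersect two string masks (2 = wildcard). Returns [mask, count]; the
# count is 0 on a conflict or when no balanced string can match.
def intersect_count(a, b)
  merged = a.chars.zip(b.chars).map do |x, y|
    return [nil, 0] if x != '2' && y != '2' && x != y
    x == '2' ? y : x
  end.join
  half = merged.size / 2
  [merged, binomial(merged.count('2'), half - merged.count('0'))]
end

# Inclusion-exclusion, one intersection depth per loop iteration.
def count_union(masks)
  total = 0
  sign = 1
  # Each entry: [intersection mask, indices of masks it may still absorb].
  level = [['2' * masks.first.size, (0...masks.size).to_a]]
  until level.empty?
    next_level = []
    level.each do |mask, candidates|
      survivors = []
      candidates.each do |idx|
        q, count = intersect_count(mask, masks[idx])
        next if count.zero?            # prune: deeper levels would be empty
        total += sign * count
        survivors << [q, idx]
      end
      # A child is only extended by survivors listed after it, so every
      # subset of masks is visited at most once.
      survivors.each_with_index do |(q, _idx), pos|
        later = survivors.drop(pos + 1).map { |_, i| i }
        next_level << [q, later]
      end
    end
    level = next_level
    sign = -sign
  end
  total
end

puts count_union(%w[122020 111222 220011])  # prints 5
```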

I shaved off around 25% on testcase-hard, but the performance on the original (testcase-long) is pretty much unchanged, since not much memory allocation is going on there. I'm going to tune a bit more before I call it: I think I might be able to get a 25%-50% improvement on testcase-long too.

Mathematica

Once I noticed this was a #SAT problem, I realized I could use Mathematica's built-in SatisfiabilityCount:

AbsoluteTiming[
 (* download test case *)
 input = Map[FromDigits, 
   Characters[
    Rest[StringSplit[
      Import["http://pastebin.com/raw.php?i=2Dg7gbfV", 
       "Text"]]]], {2}]; n = Length[First[input]];
 (* create boolean function *)
 bool = BooleanCountingFunction[{n/2}, n] @@ Array[x, n] && 
   Or @@ Table[
     And @@ MapIndexed[# == 2 || Xor[# == 1, x[First[#2]]] &, i], {i, 
      input}];
 (* count instances *)
 SatisfiabilityCount[bool, Array[x, n]]
]

Output:

{1.296944, 298208861472}

That's 298,208,861,472 matches in 1.3 seconds (i7-3517U @ 1.9 GHz), including the time spent downloading the test case from pastebin.