Random Python dictionary key, weighted by values

Do you always know the total number of values in the dictionary? If so, this is easy to do with the following algorithm, which works whenever you want to make a probabilistic selection of one item from an ordered list:

  1. Iterate over your list of keys.
  2. For each key, generate a uniformly distributed random value between 0 and 1 (aka "roll the dice").
  3. Assuming that this key has N_VALS values associated with it, accept this key with probability N_VALS / N_REMAINING, where N_REMAINING is the number of values belonging to this key and to all keys not yet visited. For the first key, N_REMAINING is TOTAL_VALS, the total number of values in the entire dictionary; for the last key with any values, the ratio is 1, so it is always accepted.
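The three steps above can be sketched per key like this (the full code further down rolls the dice once per value instead, which gives the same distribution). It's a minimal sketch assuming, as in the question, that every value is a list and a key's weight is the length of that list; the name `select_weighted_by_key` is hypothetical.

```python
import random


def select_weighted_by_key(d):
    """Return a key of d, chosen with probability len(d[key]) / TOTAL_VALS.

    Sketch assuming every value is a list (or other sized container).
    """
    # Start with N_REMAINING = TOTAL_VALS, the number of values in the dict
    n_remaining = sum(len(vals) for vals in d.values())
    for key, vals in d.items():
        n_vals = len(vals)
        if n_vals == 0:
            continue  # a key with an empty list can never be selected
        # Accept this key with probability N_VALS / N_REMAINING.
        # random.random() is in [0, 1), so "<" accepts with exactly that
        # probability, and a ratio of 1.0 (last non-empty key) always accepts.
        if random.random() < n_vals / n_remaining:
            return key
        # Rejected: this key's values are no longer "remaining"
        n_remaining -= n_vals
    return None  # only reached when d contains no values at all


if __name__ == "__main__":
    d = {'a': [1, 3, 2], 'b': [6], 'c': [0, 0]}
    print(select_weighted_by_key(d))
```

Skipping empty lists up front also avoids a division by zero when the whole dictionary has no values.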

This algorithm has the advantage of not having to generate any new lists, which matters if your dictionary is large. Your program only pays for one loop over the K keys to calculate the total, another loop over the keys that on average stops partway through, and one random number between 0 and 1 for each key (or, in the code below, each value) it visits. Generating such a random number is very common in programming, so most languages have a fast implementation of it. Python's random module uses a C implementation of the Mersenne Twister algorithm, which should be very fast, and its documentation describes that implementation as thread-safe.
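To make "on average" concrete, here is a short derivation for the per-key form of the steps above. The notation is my own shorthand: keys are numbered 1..K in iteration order, n_i is key i's N_VALS, T is TOTAL_VALS, and R_i = n_i + ... + n_K is N_REMAINING when key i is reached (so R_1 = T).

```latex
\begin{align*}
P(\text{reach key } i)  &= \prod_{j<i}\Bigl(1-\frac{n_j}{R_j}\Bigr)
                         = \prod_{j<i}\frac{R_{j+1}}{R_j}
                         = \frac{R_i}{R_1} = \frac{R_i}{T},\\
P(\text{select key } i) &= \frac{R_i}{T}\cdot\frac{n_i}{R_i} = \frac{n_i}{T},\\
E[\text{keys visited}]  &= \sum_{i=1}^{K} i\,\frac{n_i}{T}.
\end{align*}
```

With equal weights the last sum is (K+1)/2, i.e. about halfway; if the heavy keys happen to come early in the iteration order, the loop ends sooner on average.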

Here's the code. I'm sure that you can clean it up if you'd like to use more Pythonic features:

#!/usr/bin/env python3

import random

def select_weighted(d):
    # calculate the total number of values across all keys
    total = 0
    for key in d:
        total += len(d[key])

    # Walk over every value and accept the current one with probability
    # 1 / (number of values not yet seen). Each value is then equally
    # likely, so each key is picked in proportion to len(d[key]).
    n_seen = 0
    for key in d:
        for _ in d[key]:
            dice_roll = random.random()
            accept_prob = 1.0 / (total - n_seen)
            n_seen += 1
            if dice_roll <= accept_prob:
                return key

d = {
    'a': [1, 3, 2],
    'b': [6],
    'c': [0, 0]
}

counts = {}
for key in d:
    counts[key] = 0

for s in range(1, 100000):
    k = select_weighted(d)
    counts[k] = counts[k] + 1

print(counts)

After running this loop 99,999 times (range(1, 100000) yields 99,999 iterations), each key was selected this many times:

{'a': 49801, 'c': 33548, 'b': 16650}

Those are fairly close to your expected probabilities of:

{'a': 0.5, 'c': 0.33333333333333331, 'b': 0.16666666666666666}
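The expected probabilities are just each key's share of the total number of values. Here is a small sketch (the helper name `expected_probabilities` is hypothetical) that computes them and compares them with the observed frequencies from the run reported above:

```python
def expected_probabilities(d):
    """Map each key to len(d[key]) / total number of values in d."""
    total = sum(len(vals) for vals in d.values())
    return {key: len(vals) / total for key, vals in d.items()}


d = {'a': [1, 3, 2], 'b': [6], 'c': [0, 0]}
expected = expected_probabilities(d)

# Counts reported in the run above; they sum to the 99,999 iterations
observed = {'a': 49801, 'c': 33548, 'b': 16650}
n_trials = sum(observed.values())

for key in sorted(expected):
    # key, expected probability, observed frequency
    print(key, round(expected[key], 4), round(observed[key] / n_trials, 4))
```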

Edit: Miles pointed out a serious error in my original implementation, which has since been corrected. Sorry about that!


This would work:

random.choice([k for k in d for x in d[k]])
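On Python 3.6 or later, the standard library's `random.choices` accepts per-item weights, so the same pick can be made without building the expanded list. This is a swapped-in alternative rather than part of the original answer, and it assumes, as above, that a key's weight is the length of its list; `select_weighted_choices` is a hypothetical name.

```python
import random


def select_weighted_choices(d):
    # Weight each key by the number of values in its list
    keys = list(d)
    weights = [len(d[key]) for key in keys]
    # random.choices returns a list of k picks (k defaults to 1)
    return random.choices(keys, weights=weights)[0]


d = {'a': [1, 3, 2], 'b': [6], 'c': [0, 0]}
print(select_weighted_choices(d))
```

Passing `k=n` to `random.choices` returns n independent picks in a single call, which is convenient for a tallying loop like the one in the first answer.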

Without constructing a new, possibly big list with repeated values:

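import random

def select_weighted(d):
    # pick a position among all values, then find the key that owns it
    offset = random.randint(0, sum(len(v) for v in d.values()) - 1)
    for k, v in d.items():
        if offset < len(v):
            return k
        offset -= len(v)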
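When you need many draws from the same, unchanging dictionary, one common refinement (not from the original answers) is to precompute the running totals once and binary-search them with `bisect`, so each draw costs O(log K) instead of a linear scan over the keys. A minimal sketch, using the hypothetical helper name `make_weighted_sampler` and again weighting each key by the length of its list:

```python
import bisect
import itertools
import random


def make_weighted_sampler(d):
    """Return a zero-argument function that picks a key of d, weighted by
    len(d[key]). Assumes d has at least one value and is not modified later."""
    keys = list(d)
    # cum[i] = total number of values in keys[0] .. keys[i]
    cum = list(itertools.accumulate(len(d[key]) for key in keys))
    total = cum[-1]

    def sample():
        # offset is uniform over 0 .. total - 1, like randint(0, total - 1)
        offset = random.randrange(total)
        # index of the first running total strictly greater than offset;
        # keys with empty lists never satisfy this, so they are never picked
        return keys[bisect.bisect_right(cum, offset)]

    return sample


d = {'a': [1, 3, 2], 'b': [6], 'c': [0, 0]}
pick = make_weighted_sampler(d)
print([pick() for _ in range(5)])
```

The setup is one O(K) pass, so this only pays off when the same dictionary is sampled repeatedly; if it changes between draws, the linear scan above is simpler.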