How to get element by Index in Spark RDD (Java)

This should be possible by first indexing the RDD. The transformation zipWithIndex provides a stable indexing, numbering each element in its original order.

Given: rdd = (a,b,c)

val withIndex = rdd.zipWithIndex // ((a,0),(b,1),(c,2))

To look up an element by index, this form is not useful, because lookup searches by key and here the index is the value. First we need to swap each pair so that the index becomes the key:

val indexKey = withIndex.map{case (k,v) => (v,k)}  //((0,a),(1,b),(2,c))

Now, it's possible to use the lookup action in PairRDD to find an element by key:

val b = indexKey.lookup(1) // Array(b)

If you expect to call lookup often on the same RDD, I'd recommend caching the indexKey RDD to improve performance.

How to do this using the Java API is an exercise left for the reader.


I tried this class to fetch an item by index. First, when you construct new IndexedFetcher(rdd, itemClass), it counts the number of elements in each partition of the RDD. Then, when you call indexedFetcher.get(n), it runs a job on only the partition that contains that index.

Note that I needed to compile this using Java 1.7 instead of 1.8; as of Spark 1.1.0, the bundled org.objectweb.asm within com.esotericsoftware.reflectasm cannot read Java 1.8 classes yet (throws IllegalStateException when you try to runJob a Java 1.8 function).

import java.io.Serializable;

import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.RDD;

import scala.reflect.ClassTag;

public static class IndexedFetcher<E> implements Serializable {
    private static final long serialVersionUID = 1L;
    public final RDD<E> rdd;
    public Integer[] elementsPerPartitions;
    private Class<?> clazz;
    public IndexedFetcher(RDD<E> rdd, Class<?> clazz){
        this.rdd = rdd;
        this.clazz = clazz;
        SparkContext context = this.rdd.context();
        ClassTag<Integer> intClassTag = scala.reflect.ClassTag$.MODULE$.<Integer>apply(Integer.class);
        elementsPerPartitions = (Integer[]) context.<E, Integer>runJob(rdd, IndexedFetcher.<E>countFunction(), intClassTag);
    }
    public static class IteratorCountFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, Integer> implements Serializable {
        private static final long serialVersionUID = 1L;
        @Override public Integer apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
            int count = 0;
            while (iterator.hasNext()) {
                count++;
                iterator.next();
            }
            return count;
        }
    }
    static <E> scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> countFunction() {
        scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> function = new IteratorCountFunction<E>();
        return function;
    }
    public E get(long index) {
        long remaining = index;
        long totalCount = 0;
        for (int partition = 0; partition < elementsPerPartitions.length; partition++) {
            if (remaining < elementsPerPartitions[partition]) {
                return getWithinPartition(partition, remaining);
            }
            remaining -= elementsPerPartitions[partition];
            totalCount += elementsPerPartitions[partition];
        }
        throw new IllegalArgumentException(String.format("Get %d within RDD that has only %d elements", index, totalCount));
    }
    public static class FetchWithinPartitionFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, E> implements Serializable {
        private static final long serialVersionUID = 1L;
        private final long indexWithinPartition;
        public FetchWithinPartitionFunction(long indexWithinPartition) {
            this.indexWithinPartition = indexWithinPartition;
        }
        @Override public E apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
            int count = 0;
            while (iterator.hasNext()) {
                E element = iterator.next();
                if (count == indexWithinPartition)
                    return element;
                count++;
            }
            throw new IllegalArgumentException(String.format("Fetch %d within partition that has only %d elements", indexWithinPartition, count));
        }
    }
    public E getWithinPartition(int partition, long indexWithinPartition) {
        System.out.format("getWithinPartition(%d, %d)%n", partition, indexWithinPartition);
        SparkContext context = rdd.context();
        scala.Function2<TaskContext, scala.collection.Iterator<E>, E> function = new FetchWithinPartitionFunction<E>(indexWithinPartition);
        scala.collection.Seq<Object> partitions = new scala.collection.mutable.WrappedArray.ofInt(new int[] {partition});
        ClassTag<E> classTag = scala.reflect.ClassTag$.MODULE$.<E>apply(this.clazz);
        E[] result = (E[]) context.<E, E>runJob(rdd, function, partitions, true, classTag);
        return result[0];
    }
}