Complicated summary function -- is it possible to solve with R data.table package?

Great question!! The example data is especially well constructed and well explained.
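
In case you want to follow along at the console, here is a ds consistent with the cross-join output further down (a reconstruction of the question's example data, not copied from the question itself):

library(data.table)

# Reconstructed from the cross-join output below: ID = train, Pos = position
# along the track, Obs = the value recorded there (assumed interpretation).
ds = data.table(
  ID  = c(1,    1,    1,    1,    2,    2,    2,    3,    3,    3),
  Pos = c(1,    3,    5,    6,    2,    3,    5,    2,    3,    4),
  Obs = c(1.50, 2.50, 0.00, 1.25, 1.45, 1.50, 2.50, 0.00, 1.25, 1.45))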

First I'll show this answer, then I'll explain it step by step.

> ids = 1:3   # or from the data: unique(ds$ID)
> pos = 1:6   # or from the data: unique(ds$Pos)
> setkey(ds,ID,Pos)

> ds[CJ(ids,pos), roll=-Inf, nomatch=0][, .N, by=Pos]
   Pos N
1:   1 3
2:   2 3
3:   3 3
4:   4 3
5:   5 2
6:   6 1

That should also be very efficient on your large data.

Step by step

First I tried a cross join (CJ); i.e., one row for every train at every position.
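
As an aside, CJ on its own just builds the keyed table of all combinations that is then passed as i:

> CJ(ids, pos)   # every (train, position) combination: 3 * 6 = 18 rows, keyed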

> ds[CJ(ids,pos)]
    ID Pos  Obs
 1:  1   1 1.50
 2:  1   2   NA
 3:  1   3 2.50
 4:  1   4   NA
 5:  1   5 0.00
 6:  1   6 1.25
 7:  2   1   NA
 8:  2   2 1.45
 9:  2   3 1.50
10:  2   4   NA
11:  2   5 2.50
12:  2   6   NA
13:  3   1   NA
14:  3   2 0.00
15:  3   3 1.25
16:  3   4 1.45
17:  3   5   NA
18:  3   6   NA

I see 6 rows per train. I see 3 trains. I've got 18 rows as I expected. I see NA where that train wasn't observed. Good. Check. The cross join seems to be working. Let's now build the query up.

You wrote that if a train is observed at position n, it must have passed the previous positions. Immediately I'm thinking roll. Let's try it.

> ds[CJ(ids,pos), roll=TRUE]
    ID Pos  Obs
 1:  1   1 1.50
 2:  1   2 1.50
 3:  1   3 2.50
 4:  1   4 2.50
 5:  1   5 0.00
 6:  1   6 1.25
 7:  2   1   NA
 8:  2   2 1.45
 9:  2   3 1.50
10:  2   4 1.50
11:  2   5 2.50
12:  2   6 2.50
13:  3   1   NA
14:  3   2 0.00
15:  3   3 1.25
16:  3   4 1.45
17:  3   5 1.45
18:  3   6 1.45

Hm. That rolled the observations forwards for each train. It left some NA at position 1 for trains 2 and 3, but you said that if a train is observed at position 2 it must have passed position 1. It also rolled the last observation for trains 2 and 3 forward to position 6, but you said trains might explode. So we want to roll backwards instead! That's roll=-Inf. The -Inf only looks complicated because roll also lets you limit how far to roll backwards; we don't need that for this question, we just want to roll backwards indefinitely. Let's try roll=-Inf and see what happens.

> ds[CJ(ids,pos), roll=-Inf]
    ID Pos  Obs
 1:  1   1 1.50
 2:  1   2 2.50
 3:  1   3 2.50
 4:  1   4 0.00
 5:  1   5 0.00
 6:  1   6 1.25
 7:  2   1 1.45
 8:  2   2 1.45
 9:  2   3 1.50
10:  2   4 2.50
11:  2   5 2.50
12:  2   6   NA
13:  3   1 0.00
14:  3   2 0.00
15:  3   3 1.25
16:  3   4 1.45
17:  3   5   NA
18:  3   6   NA

That's better. Almost there. All we need to do now is count. But those pesky NA are still there after trains 2 and 3 exploded. Let's remove them with nomatch=0.

> ds[CJ(ids,pos), roll=-Inf, nomatch=0]
    ID Pos  Obs
 1:  1   1 1.50
 2:  1   2 2.50
 3:  1   3 2.50
 4:  1   4 0.00
 5:  1   5 0.00
 6:  1   6 1.25
 7:  2   1 1.45
 8:  2   2 1.45
 9:  2   3 1.50
10:  2   4 2.50
11:  2   5 2.50
12:  3   1 0.00
13:  3   2 0.00
14:  3   3 1.25
15:  3   4 1.45

Btw, data.table likes as much as possible to be inside one single DT[...] call, as that's how it optimizes the query. Internally, it doesn't create the NA and then remove them; it never creates the NA in the first place. This concept is important for efficiency.
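
If you're curious what that single call is doing internally, you can add verbose=TRUE to it; the exact messages depend on your data.table version:

> ds[CJ(ids,pos), roll=-Inf, nomatch=0, verbose=TRUE]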

Finally, all we have to do is count. We can just tack this on the end as a compound query.

> ds[CJ(ids,pos), roll=-Inf, nomatch=0][, .N, by=Pos]
   Pos N
1:   1 3
2:   2 3
3:   3 3
4:   4 3
5:   5 2
6:   6 1
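
If you need this for more than one dataset, it wraps up naturally into a small helper. The function name and arguments are my own choice; note that setkey reorders the table by reference:

countAtPos = function(dt, positions = sort(unique(dt$Pos))) {
  setkey(dt, ID, Pos)                    # reorders dt by reference
  dt[CJ(unique(dt$ID), positions), roll=-Inf, nomatch=0][, .N, by=Pos]
}

countAtPos(ds)   # same result as the compound query above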

data.table sounds like an excellent solution. From the way the data are ordered, one could find the maximum position of each train with

maxPos = ds$Pos[!duplicated(ds$ID, fromLast=TRUE)]

Then tabulate how many trains have each position as their final one

nAtMax = tabulate(maxPos)

and calculate the cumulative sum of trains at each position, counting from the end

rev(cumsum(rev(nAtMax)))
## [1] 3 3 3 3 2 1

I think this will be quite fast for large data, though not entirely memory efficient.
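
For what it's worth, here is one rough way to time both approaches on simulated data. The simulation is my own and assumes each train is observed at every position up to a random final one, which is simpler than the real data but exercises the same code:

library(data.table)
set.seed(1)
n = 1e5                                  # number of simulated trains
final = sample(1:6, n, replace = TRUE)   # final position reached by each train
big = data.table(ID  = rep(seq_len(n), final),
                 Pos = sequence(final),
                 Obs = runif(sum(final)))
setkey(big, ID, Pos)

# rolling-join approach
system.time(big[CJ(unique(big$ID), 1:6), roll = -Inf, nomatch = 0][, .N, by = Pos])

# base R approach (relies on big being ordered by ID, then Pos)
system.time({
  maxPos = big$Pos[!duplicated(big$ID, fromLast = TRUE)]
  rev(cumsum(rev(tabulate(maxPos))))
})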


You can try the approach below. I have purposely split it into a multi-step solution for better understanding. You can probably combine all of the steps into a single expression as well by just chaining [] (one way to do that is sketched after the step-by-step output below).

The logic here is that we first find the final position for each ID. Then we aggregate the data to get the count of IDs at each final position. Since every ID whose final position is 6 should also be counted for final position 5, we use cumsum to add the counts for higher positions onto each lower position.

ds2 <- ds[, list(FinalPos=max(Pos)), by=ID]

ds2 
##    ID FinalPos
## 1:  1        6
## 2:  2        5
## 3:  3        4

ds3 <- ds2[, list(Count = length(ID)), by = FinalPos][
            order(FinalPos, decreasing=TRUE), list(FinalPos, Count = cumsum(Count))]

ds3
##    FinalPos Count
## 1:        4     3
## 2:        5     2
## 3:        6     1

setkey(ds3, FinalPos)

ds3[J(c(1:6)), roll = 'nearest']

##    FinalPos Count
## 1:        1     3
## 2:        2     3
## 3:        3     3
## 4:        4     3
## 5:        5     2
## 6:        6     1
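
As mentioned above, the steps can also be chained into a single expression. This is just one way to write it; it uses on= (data.table 1.9.6 or later) so no separate setkey is needed:

ds[, list(FinalPos = max(Pos)), by = ID
 ][, list(Count = .N), by = FinalPos
 ][order(FinalPos, decreasing = TRUE), list(FinalPos, Count = cumsum(Count))
 ][.(FinalPos = 1:6), on = 'FinalPos', roll = 'nearest']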

Tags: R, data.table