Conditional state monad expressions

Doh! I don't know why this didn't occur to me sooner. Sometimes just explaining your problem in simpler terms forces you to look at it afresh, I guess...

One possibility is to handle sequences of transitions, so that the next task is only undertaken if the current task succeeds.

// Run a sequence of transitions, until one fails.
def untilFailure[M](ts: List[Transition[M]]): Transition[M] = State {s =>
  ts match {

    // If we have an empty list, that's an error. (Cannot report a success value.)
    case Nil => (s, Failure(new RuntimeException("Empty transition sequence")))

    // If there's only one transition left, perform it and return the result.
    case t :: Nil => t.run(s).value

    // Otherwise, we have more than one transition remaining.
    //
    // Run the next transition. If it fails, report the failure, otherwise repeat
    // for the tail.
    case t :: tt => {
      val r = t.run(s).value
      if(r._2.isFailure) r
      else untilFailure(tt).run(r._1).value
    }
  }
}

We can then implement counterManip as a sequence.

val counterManip: Transition[Unit] =
  untilFailure(List(decrement, increment, increment, increment))

which gives the correct results:

scala> counterManip.run(Counter(0)).value
res0: (Counter, scala.util.Try[Unit]) = (Counter(0),Failure(java.lang.ArithmeticException: Attempt to make count negative failed))

scala> counterManip.run(Counter(1)).value
res1: (Counter, scala.util.Try[Unit]) = (Counter(3),Success(()))

scala> counterManip.run(Counter(Int.MaxValue - 2)).value
res2: (Counter, scala.util.Try[Unit]) = (Counter(2147483647),Success(()))

scala> counterManip.run(Counter(Int.MaxValue - 1)).value
res3: (Counter, scala.util.Try[Unit]) = (Counter(2147483647),Failure(java.lang.ArithmeticException: Attempt to overflow counter failed))

scala> counterManip.run(Counter(Int.MaxValue)).value
res4: (Counter, scala.util.Try[Unit]) = (Counter(2147483647),Failure(java.lang.ArithmeticException: Attempt to overflow counter failed))

The downside is that all of the transitions in the list must share a common result type, since they live in a single List[Transition[M]] (unless you're OK with an Any result).
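
One possible way around that restriction is to discard each transition's success value before putting it in the list. The sketch below illustrates the idea with plain functions rather than Cats' State, so it runs with only the standard library; Step, discard, readCount and manip are hypothetical names, not part of the code above. Note one deliberate difference: because every step now yields Unit, this version treats an empty list as a trivial success instead of an error.

```scala
import scala.util.{Failure, Success, Try}

object UntilFailureSketch {
  final case class Counter(count: Int)

  // Hypothetical stand-in for Transition[M]: state in, (new state, result) out.
  type Step[M] = Counter => (Counter, Try[M])

  // Drop a step's success value so steps with different result types fit in one list.
  def discard[M](step: Step[M]): Step[Unit] = { s =>
    val (next, result) = step(s)
    (next, result.map(_ => ()))
  }

  // Run the steps in order; once one fails, skip the rest and keep that failure.
  def untilFailure(steps: List[Step[Unit]]): Step[Unit] = { s =>
    val start: (Counter, Try[Unit]) = (s, Success(()))
    steps.foldLeft(start) {
      case ((state, Success(_)), step) => step(state)
      case (failed, _)                 => failed
    }
  }

  val decrement: Step[Unit] = { c =>
    if (c.count == 0) (c, Failure(new ArithmeticException("Attempt to make count negative failed")))
    else (c.copy(count = c.count - 1), Success(()))
  }

  val increment: Step[Unit] = { c =>
    if (c.count == Int.MaxValue) (c, Failure(new ArithmeticException("Attempt to overflow counter failed")))
    else (c.copy(count = c.count + 1), Success(()))
  }

  // A step with a different result type: report the current count.
  val readCount: Step[Int] = c => (c, Success(c.count))

  // Mixed result types in one sequence, thanks to discard.
  val manip: Step[Unit] =
    untilFailure(List(decrement, discard(readCount), increment, increment))
}
```

With these definitions, UntilFailureSketch.manip(Counter(1)) ends at Counter(2) with a success, while starting from Counter(0) stops at the first step and leaves the counter untouched. The cost is that intermediate results such as the count are thrown away, so this only helps when you care about the final state and whether the sequence succeeded.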


From what I understand, your computation has two states, which you can define as an ADT:

sealed trait CompState[A]
case class Ok[A](value: A) extends CompState[A]
case class Err[A](lastValue: A, cause: Exception) extends CompState[A]

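The next step is to define an update method on CompState that encapsulates what should happen when computations are chained: an Ok value is transformed (with any exception captured as an Err holding the last good value), while an Err passes through unchanged. The method belongs inside the CompState trait, since it matches on this.

def update(f: A => A): CompState[A] = this match {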
  case Ok(a) => 
    try Ok(f(a))
    catch { case e: Exception => Err(a, e) }
  case Err(a, e) => Err(a, e)
}
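
For reference, here is how the pieces above might fit together, with update as a member of the trait. This is a minimal sketch using only the standard library; the Counter case class is assumed to match the one in the question, and the CompStateSketch wrapper object and the demo value are hypothetical, there only so the snippet compiles and can be tried on its own.

```scala
object CompStateSketch {
  sealed trait CompState[A] {
    // Apply f only while the computation is still Ok; an Err passes through unchanged.
    def update(f: A => A): CompState[A] = this match {
      case Ok(a) =>
        try Ok(f(a))
        catch { case e: Exception => Err(a, e) }
      case Err(a, e) => Err(a, e)
    }
  }
  final case class Ok[A](value: A) extends CompState[A]
  final case class Err[A](lastValue: A, cause: Exception) extends CompState[A]

  // Assumed to mirror the Counter from the question.
  final case class Counter(count: Int)

  // Chain three updates; the second throws, so the third is skipped.
  val demo: CompState[Counter] =
    Ok(Counter(1))
      .update(c => c.copy(count = c.count - 1))
      .update(_ => throw new ArithmeticException("boom"))
      .update(c => c.copy(count = c.count + 1))
}
```

Here demo ends up as an Err holding Counter(0): the first update succeeds, the second throws and is captured along with the last good value, and the third is never applied.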

From there, redefine Transition and the two operations in terms of CompState:

type Transition[M] = State[CompState[Counter], M]

// Operation to increment a counter.
// note: using `State.modify` instead of `.apply`
val increment: Transition[Unit] = State.modify { cs =>
  // use the new `update` method to take advantage of your chaining semantics
  cs update{ c =>
    // If the count is at its maximum, incrementing it must fail.
    if(c.count == Int.MaxValue) {
      throw new ArithmeticException("Attempt to overflow counter failed")
    }

    // Otherwise, increment the count and indicate success.
    else c.copy(count = c.count + 1)
  }
}

// Operation to decrement a counter.
val decrement: Transition[Unit] = State.modify { cs =>
  cs update { c =>
    // If the count is zero, decrementing it must fail.
    if(c.count == 0) {
      throw new ArithmeticException("Attempt to make count negative failed")
    }

    // Otherwise, decrement the count and indicate success.
    else c.copy(count = c.count - 1)
  }
}

Note that the updated increment/decrement transitions above use State.modify, which changes the state but does not produce a result. It looks like the "idiomatic" way to obtain the current state at the end of your transitions is to use State.get, i.e.

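val counterManip: State[CompState[Counter], CompState[Counter]] = for {
  _ <- decrement
  _ <- increment
  _ <- increment
  _ <- increment
  // Explicit type argument, in case the state type is not inferred for a bare State.get.
  r <- State.get[CompState[Counter]]
} yield r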

You can then run this, keeping the result and discarding the final state, with the runA helper, i.e.

counterManip.runA(Ok(Counter(0))).value
// Err(Counter(0),java.lang.ArithmeticException: Attempt to make count negative failed)