Saturday, 12 March 2016

Struggling with the State Monad

After spending hours searching for articles that explain the damn thing without resorting to "let's take this contrived example..." or "let's take a real-life example: generating pseudo-random numbers" (seriously?!), I'm still left frustrated. Eventually I reached out to Kaloz, who pointed me to one of his earlier explorations of the topic. His example is here, but I'm still in the dark as to why it is any better than a simple foldLeft. In the code below I slightly refactored Kaloz's code to keep the generic part separate from the specific solutions. A second foldLeft-based function is provided to match the signature of the function that uses the State monad, although I don't see any additional value in that.

package org.bluecollar.scalaz
import scalaz.Scalaz._
import scalaz._
object StateMonadExamples extends App {
  // ---------------------------------------------------------------------------
  // Domain and test values
  // ---------------------------------------------------------------------------

  /** An input that can be fed to a [[Machine]]. */
  sealed trait Input
  /** Inserting a coin. */
  case object Coin extends Input
  /** Turning the knob. */
  case object Turn extends Input

  /** Candy-machine state.
    *
    * @param locked  whether the machine is locked
    * @param candies number of candies left
    * @param coins   number of coins collected so far
    */
  case class Machine(locked: Boolean, candies: Int, coins: Int)

  /** Returns the transition function for a single input.
    *
    * Rules, checked in this order:
    *  - a machine with exactly 0 candies ignores every input;
    *  - a coin into an unlocked machine, or a turn on a locked machine, changes nothing;
    *  - a coin into a locked machine unlocks it and increments `coins`;
    *  - a turn on an unlocked machine decrements `candies` and locks it again.
    */
  private def applyInput(i: Input): Machine => Machine =
    (m: Machine) =>
      (i, m) match {
        case (_, Machine(_, 0, _))               => m
        case (Coin, Machine(false, _, _))        => m
        case (Turn, Machine(true, _, _))         => m
        case (Coin, Machine(true, candy, coin))  => Machine(false, candy, coin + 1)
        case (Turn, Machine(false, candy, coin)) => Machine(true, candy - 1, coin)
      }

  /** Sample input sequence used by the demo below: four coin/turn pairs. */
  val inputs = List(Coin, Turn, Coin, Turn, Coin, Turn, Coin, Turn)
  /** Sample starting machine used by the demo below: locked, 5 candies, 10 coins. */
  val machine = Machine(true, 5, 10)

  // Implementation with STATE MONAD starts =============

  /** `MonadState` instance for `State[Machine, *]`, used to sequence the per-input steps.
    *
    * Declared `lazy`: `App` defers plain `val` initialisers until `main` runs
    * (DelayedInit). With a strict `val`, calling [[simulateMachine]] from outside
    * the app body (a test, the REPL) would dereference a still-null `state`.
    */
  lazy val state = scalaz.StateT.stateMonad[Machine]

  /** One step: applies `i` to the current machine, then yields `(coins, candies)` of the result. */
  def rules(i: Input): State[Machine, (Int, Int)] = for {
    _ <- modify(applyInput(i))
    m <- get
  } yield (m.coins, m.candies)

  /** Runs every input in order and yields `(coins, candies)` of the final machine.
    * The intermediate per-step results produced by [[rules]] are discarded.
    */
  def simulateMachine(inputs: List[Input]): State[Machine, (Int, Int)] = for {
    _ <- state.sequence(inputs.map(rules))
    m <- get[Machine]
  } yield (m.coins, m.candies)

  /** Runs [[simulateMachine]] from `machine`, returning `(finalMachine, (coins, candies))`. */
  def simulationWithStateMonad(inputs: List[Input], machine: Machine): (Machine, (Int, Int)) =
    simulateMachine(inputs)(machine)

  // Implementation with STATE MONAD ends =============

  /** Implementation with foldLeft: threads the machine through every input and returns the final machine. */
  def simulationWithFoldLeft(inputs: List[Input], machine: Machine): Machine =
    inputs.foldLeft(machine)((m, input) => applyInput(input)(m))

  /** Implementation with foldLeft that returns the same shape as [[simulationWithStateMonad]]:
    * `(finalMachine, (coins, candies))`.
    */
  def simulationWithFoldLeft2(inputs: List[Input], machine: Machine): (Machine, (Int, Int)) = {
    def convert(m: Machine): (Machine, (Int, Int)) = (m, (m.coins, m.candies))
    inputs.foldLeft(convert(machine)) { case ((m, _), input) =>
      convert(applyInput(input)(m))
    }
  }

  // Test
  println(simulationWithStateMonad(inputs, machine))
  // (Machine(true,1,14),(14,1))
  println(simulationWithFoldLeft(inputs, machine))
  // Machine(true,1,14)
  println(simulationWithFoldLeft2(inputs, machine))
  // (Machine(true,1,14),(14,1))
}
Kaloz? What am I missing?