Nerde Nolzda

Persistent Segment Trees in Scala

I’ve always thought that persistent segment trees, known in Chinese as “functional segment trees,” are more elegant than traditional ones, since I can essentially treat them as immutable objects: an update returns a new tree and leaves every earlier version intact. So recently, I decided to code an implementation in Scala, a language with first-class support for functional programming.
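To make the “immutable object” idea concrete, here is a minimal sketch, separate from the code below and using hypothetical names (Tree, Lf, Br, PathCopy, build, add). An update copies only the nodes on the path from the root to the updated leaf and reuses every other subtree, so the old root remains a valid, unchanged version.

```scala
// Hypothetical minimal persistent sum tree over positions lo..hi (inclusive).
sealed trait Tree { def sum: Int }
final case class Lf(sum: Int) extends Tree
final case class Br(sum: Int, left: Tree, right: Tree) extends Tree

object PathCopy {
  // Build an all-zero tree covering lo..hi.
  def build(lo: Int, hi: Int): Tree =
    if (lo == hi) Lf(0)
    else {
      val mid = (lo + hi) >>> 1
      Br(0, build(lo, mid), build(mid + 1, hi))
    }

  // Return a NEW root with v added at pos. Only the nodes on the path to pos
  // are allocated; every untouched subtree is shared with the old version.
  def add(t: Tree, pos: Int, v: Int, lo: Int, hi: Int): Tree = t match {
    case Lf(s) => Lf(s + v)
    case Br(s, l, r) =>
      val mid = (lo + hi) >>> 1
      if (pos <= mid) Br(s + v, add(l, pos, v, lo, mid), r)
      else Br(s + v, l, add(r, pos, v, mid + 1, hi))
  }
}

val v0 = PathCopy.build(1, 8)
val v1 = PathCopy.add(v0, 3, 1, 1, 8)
println(v0.sum) // 0: the old version is untouched
println(v1.sum) // 1
```

Because add allocates one node per level, each update costs O(log n) time and memory, which is what makes keeping one version per array prefix affordable.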

Without further ado, here’s the code:

object PersistentSegTree {
  // Build an empty (all-zero) tree covering positions l..r.
  def init(l: Int, r: Int): TreeNode = {
    val mid = (l + r) >>> 1
    if (l != r) Node(0, init(l, mid), init(mid + 1, r))
    else Leaf(0)
  }

  sealed trait TreeNode {
    // How many inserted values fall into this node's range of positions.
    val sum: Int

    // Return a new version with v added at pos; untouched subtrees are shared.
    def modify(pos: Int, v: Int, l: Int, r: Int): TreeNode

    // Position of the k-th smallest value counted in this version
    // but not in the earlier version old (same tree shape).
    def queryKthSmallest(old: TreeNode, k: Int, l: Int, r: Int): Int
  }

  case class Node(sum: Int, lc: TreeNode, rc: TreeNode) extends TreeNode {
    def modify(pos: Int, v: Int, l: Int, r: Int): TreeNode = {
      val mid = (l + r) >>> 1
      if (pos <= mid) Node(sum + v, lc.modify(pos, v, l, mid), rc)
      else Node(sum + v, lc, rc.modify(pos, v, mid + 1, r))
    }

    def queryKthSmallest(old: TreeNode, k: Int, l: Int, r: Int): Int = {
      // Every version has the same shape, so old is also a Node here.
      val nOld = old.asInstanceOf[Node]
      val (lSum, mid) = (lc.sum - nOld.lc.sum, (l + r) >>> 1)
      if (k <= lSum) lc.queryKthSmallest(nOld.lc, k, l, mid)
      else rc.queryKthSmallest(nOld.rc, k - lSum, mid + 1, r)
    }
  }

  case class Leaf(sum: Int) extends TreeNode {
    def modify(pos: Int, v: Int, l: Int, r: Int): TreeNode = Leaf(sum + v)

    def queryKthSmallest(old: TreeNode, k: Int, l: Int, r: Int): Int = l
  }
}

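object Main extends App {
  val Array(n, m) = io.StdIn.readLine().trim.split("\\s+").map(_.toInt)
  val arr = io.StdIn.readLine().trim.split("\\s+").map(_.toInt)
  // map each value to its rank among the distinct values (0 until sorted.length)
  val sorted = arr.distinct.sorted
  val dict = sorted.zipWithIndex.toMap
  // build the tree: roots(i) holds the counts for the first i elements.
  // A Vector makes roots(j) effectively constant-time (a List lookup is O(j)).
  val roots = (0 until n).foldLeft(List(PersistentSegTree.init(1, n))) { (lis, cur) =>
    lis.head.modify(dict(arr(cur)) + 1, 1, 1, n) :: lis
  }.reverse.toVector
  for (_ <- 1 to m) {
    val Array(i, j, k) = io.StdIn.readLine().trim.split("\\s+").map(_.toInt)
    // the query returns a 1-based rank; translate it back to the original value
    println(sorted(roots(j).queryKthSmallest(roots(i - 1), k, 1, n) - 1))
  }
}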

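As long as you know how persistent segment trees work, the code should be pretty self-explanatory: roots(i) is the tree after the first i elements have been inserted, so the counts in roots(j) minus those in roots(i - 1) describe exactly the subarray i..j, and queryKthSmallest walks both versions together, going left whenever k is at most the difference of the left children's counts.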
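Here is a compact, self-contained sketch of the same prefix-version trick, checked against brute-force sorting. It is a simplified re-implementation rather than the code above, and the names KthDemo, Tree, Tip, Bin, build, insert, kth, solve and bruteKth are all hypothetical.

```scala
// Hypothetical, simplified re-implementation of the prefix-version k-th smallest query.
object KthDemo {
  sealed trait Tree { def cnt: Int }
  final case class Tip(cnt: Int) extends Tree
  final case class Bin(cnt: Int, left: Tree, right: Tree) extends Tree

  // Empty tree over positions lo..hi.
  def build(lo: Int, hi: Int): Tree =
    if (lo == hi) Tip(0)
    else {
      val mid = (lo + hi) >>> 1
      Bin(0, build(lo, mid), build(mid + 1, hi))
    }

  // New version with one more occurrence at position p (path copying).
  def insert(t: Tree, p: Int, lo: Int, hi: Int): Tree = t match {
    case Tip(c) => Tip(c + 1)
    case Bin(c, l, r) =>
      val mid = (lo + hi) >>> 1
      if (p <= mid) Bin(c + 1, insert(l, p, lo, mid), r)
      else Bin(c + 1, l, insert(r, p, mid + 1, hi))
  }

  // k-th smallest position among occurrences counted in newer but not in older.
  def kth(newer: Tree, older: Tree, k: Int, lo: Int, hi: Int): Int =
    (newer, older) match {
      case (Bin(_, newL, newR), Bin(_, oldL, oldR)) =>
        val mid = (lo + hi) >>> 1
        val leftCount = newL.cnt - oldL.cnt
        if (k <= leftCount) kth(newL, oldL, k, lo, mid)
        else kth(newR, oldR, k - leftCount, mid + 1, hi)
      case _ => lo // both are leaves, since all versions share one shape
    }

  // Each query (i, j, k) is 1-based: k-th smallest of arr(i - 1) .. arr(j - 1).
  def solve(arr: Array[Int], queries: Seq[(Int, Int, Int)]): Seq[Int] = {
    require(arr.nonEmpty, "the array must not be empty")
    val sorted = arr.distinct.sorted
    val rank = sorted.zipWithIndex.toMap // value -> 0-based rank
    val size = sorted.length
    // roots(i) counts the ranks of the first i elements; an Array gives O(1) lookup.
    val roots = arr.scanLeft(build(1, size))((root, x) => insert(root, rank(x) + 1, 1, size))
    queries.map { case (i, j, k) => sorted(kth(roots(j), roots(i - 1), k, 1, size) - 1) }
  }

  // Brute-force reference used to check solve.
  def bruteKth(arr: Array[Int], i: Int, j: Int, k: Int): Int =
    arr.slice(i - 1, j).sorted.apply(k - 1)
}
```

For example, KthDemo.solve(Array(1, 5, 2, 6, 3, 7, 4), Seq((2, 5, 3), (4, 4, 1), (1, 7, 3))) evaluates to List(5, 6, 3). Using scanLeft over the input array keeps every version in an Array, so looking up a version is constant-time.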

Once I finished the code, I submitted it to MKTHNUM on SPOJ for testing and got a TLE (Time Limit Exceeded). It turned out that the tree-building process, even after I optimized it, took about 0.6 to 0.7 seconds on my laptop (CPU: Intel i5-2520M), which, unfortunately, is a bit too long.

Recalling that the JVM needs some warmup before its JIT compiler has optimized the hot code, I added a loop that ran the building process about 20 times before the benchmark, and sure enough, the time went down to 0.1 to 0.2 seconds.
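A minimal sketch of that warmup-then-measure approach, assuming a stand-in workload: the names Warmup, timeAfterWarmup and workload are hypothetical, and workload should be replaced with the real tree-building code. For anything beyond a quick check, a dedicated harness such as JMH handles warmup and measurement more rigorously than a hand-written loop.

```scala
object Warmup {
  // Evaluate body warmupRuns times (results discarded) so the JIT compiler
  // has a chance to compile the hot paths, then time one more evaluation.
  // Returns the timed run's result and its wall-clock time in milliseconds.
  def timeAfterWarmup[A](warmupRuns: Int)(body: => A): (A, Double) = {
    for (_ <- 1 to warmupRuns) body
    val start = System.nanoTime()
    val result = body
    val elapsedMs = (System.nanoTime() - start) / 1e6
    (result, elapsedMs)
  }
}

// Hypothetical stand-in workload; swap in the actual tree-building code.
def workload(): Int =
  (1 to 200000).foldLeft(List.empty[Int])((acc, x) => x :: acc).length

val (_, coldMs) = Warmup.timeAfterWarmup(0)(workload())
val (_, warmMs) = Warmup.timeAfterWarmup(20)(workload())
println(f"no warmup: $coldMs%.1f ms, after 20 warmup runs: $warmMs%.1f ms")
```

The exact numbers depend on the machine and JVM, so the sketch prints them rather than promising a particular speedup.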

Of course, this trick can’t be used on online judges, where any warmup runs would count toward the time limit as well. As such, I can sort of understand why most online judges (OJs) give Java programs a longer time limit.
