Scala Saturday – Pattern Matching, Part 4: Extractors

This week we take a look under the hood, as it were, of pattern matching in Scala. An extractor is any object that defines an unapply() method (or, as we'll see later, an unapplySeq() method). A match expression calls that method to decide whether the input value matches and, optionally, which values to pull out of it.

One reason case classes are so convenient for pattern matching is that Scala automatically gives them an unapply() method, along with the other handy tools it generates when you define a case class. You don't need a case class to get an unapply() method, though. You can write one yourself and enjoy the same benefits.
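To make that concrete, here is a minimal sketch of a hand-written extractor for a plain (non-case) class. Point and describe are hypothetical names chosen for illustration. The unapply() in the companion object behaves much like the extractor a case class gets for free, though it is not a copy of what the compiler generates:

```scala
// A plain class, not a case class, so Scala generates no unapply() for it.
class Point(val x: Int, val y: Int)

// Hand-written extractor in the companion object. Returning the fields
// wrapped in a Some lets patterns such as Point(x, y) destructure a Point.
object Point {
  def unapply(p: Point): Option[(Int, Int)] = Some((p.x, p.y))
}

def describe(p: Point): String = p match {
  case Point(0, 0) => "origin"
  case Point(x, 0) => s"on the x-axis at x = $x"
  case Point(x, y) => s"at ($x, $y)"
}

println(describe(new Point(0, 0))) // origin
println(describe(new Point(3, 0))) // on the x-axis at x = 3
println(describe(new Point(2, 5))) // at (2, 5)
```

Note that the patterns can mix literals (0) with variables (x, y): Scala compares the literals against the extracted values and binds the variables.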

Boolean Extractors

Extractors let you give a case a readable name that communicates the nature of the match, while hiding clutter that would otherwise hurt the readability of your code.

A classic example is determining whether an integer is even or odd. Now you could do that this way:

n % 2 match {
  case 0 => n -> "even"
  case _ => n -> "odd"
}

Now, n % 2 is a simple expression, and it is common enough in programming that most of us recognize it right away: “Oh, right, even or odd.” But there is still a slight context switch between thinking mathematically and thinking conceptually, i.e., determining intent. That context switch slows us down, and a more complex expression slows us down even more.

Compare the above (admittedly simple) match expression to what you can do with a couple of extractors. First, define the pair of extractors this way:

object Even {
  def unapply(n: Int): Boolean = n % 2 == 0
}
object Odd {
  // != 0 rather than == 1: on the JVM, -3 % 2 is -1, not 1
  def unapply(n: Int): Boolean = n % 2 != 0
}

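This illustrates one way you can create an extractor: define unapply() so that it returns a Boolean. True indicates a match; false means no match, and the match expression moves on to the next case. Because a Boolean extractor binds no values, you write it in a pattern with empty parentheses.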

Now after defining the extractors, you can use them in a match expression like this:

def oddOrEven(n: Int) = {
  n match {
    case Even() => n -> "even"
    case Odd() => n -> "odd"
  }
}

(1 to 10) map oddOrEven foreach println

// (1,odd)
// (2,even)
// (3,odd)
// (4,even)
// (5,odd)
// (6,even)
// (7,odd)
// (8,even)
// (9,odd)
// (10,even)

Look at the difference in readability. There is no context switch, and it reads much more smoothly than the original match expression. You can see that oddOrEven takes a value, n, and returns a result based on whether n is even or odd. You don't have to leave the realm of the conceptual to think mathematically and then turn right back around to think conceptually again.
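Because a Boolean extractor captures nothing, you may wonder how to get at the matched value when the thing being matched is an expression rather than a named variable. Scala's binder syntax, name @ pattern, handles that. A small sketch; describeTriple is a hypothetical function:

```scala
// Boolean extractor: true means "this case matches".
object Even {
  def unapply(n: Int): Boolean = n % 2 == 0
}

// Even() binds nothing, so `t @ Even()` attaches the name t to the
// value being matched, here the result of n * 3.
def describeTriple(n: Int): String = (n * 3) match {
  case t @ Even() => s"$t is even"
  case t          => s"$t is odd"
}

println(describeTriple(4)) // 12 is even
println(describeTriple(3)) // 9 is odd
```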

A point of emphasis: when you build a match expression, you need to cover all the bases and produce a result for every possible input. If no case matches, Scala throws a scala.MatchError at run time. That is easy for the even/odd test: there are only two cases.

But what if there are several cases? Take the wavelengths of colors in the spectrum of visible light. Each color corresponds to a range of wavelengths:

Color Wavelength Ranges

Color     Wavelength
Red       620–750 nm
Orange    590–620 nm
Yellow    570–590 nm
Green     495–570 nm
Blue      450–495 nm
Violet    380–450 nm

No problem, right? Just define a quick little set of extractors:

object Red {
  def unapply(λ: Int) = 620 <= λ && λ < 750
}
object Orange {
  def unapply(λ: Int) = 590 <= λ && λ < 620
}
object Yellow {
  def unapply(λ: Int) = 570 <= λ && λ < 590
}
object Green {
  def unapply(λ: Int) = 495 <= λ && λ < 570
}
object Blue {
  def unapply(λ: Int) = 450 <= λ && λ < 495
}
object Violet {
  def unapply(λ: Int) = 380 <= λ && λ < 450
}

Then put those extractors to use in a function:

def colorOfLight(λ: Int) = {
  λ match {
    case Red() => s"$λ nm" -> "red"
    case Orange() => s"$λ nm" -> "orange"
    case Yellow() => s"$λ nm" -> "yellow"
    case Green() => s"$λ nm" -> "green"
    case Blue() => s"$λ nm" -> "blue"
    case Violet() => s"$λ nm" -> "violet"
  }
}

Now this should run like a charm, right?

List(800,700,600,580,500,475,400,350)
  .map(colorOfLight)
  .foreach(println)

// Exception in thread "main" scala.MatchError: 
//   800 (of class java.lang.Integer)
// ...

Whoa, what happened? There is radiation outside the visible spectrum: λ can be 750 nm or more (as with 800 nm here) or less than 380 nm. You therefore need a catch-all case to cover the values that fall outside the explicit cases:

def colorOfLight(λ: Int) = {
  λ match {
    case Red() => s"$λ nm" -> "red"
    case Orange() => s"$λ nm" -> "orange"
    case Yellow() => s"$λ nm" -> "yellow"
    case Green() => s"$λ nm" -> "green"
    case Blue() => s"$λ nm" -> "blue"
    case Violet() => s"$λ nm" -> "violet"
    case _ => s"$λ nm" -> "invisible"
  }
}

Now your little three-liner really does run like a charm:

List(800,700,600,580,500,475,400,350)
  .map(colorOfLight)
  .foreach(println)

// (800 nm,invisible)
// (700 nm,red)
// (600 nm,orange)
// (580 nm,yellow)
// (500 nm,green)
// (475 nm,blue)
// (400 nm,violet)
// (350 nm,invisible)
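The six color objects differ only in their bounds. If that repetition bothers you, one option is to make the extractor a class and create one instance per color. This is a sketch under that assumption; Band is a hypothetical name, and the ranges are half-open, just as in the objects above:

```scala
// Each Band instance is an extractor for wavelengths in [low, high) nm.
class Band(low: Int, high: Int) {
  def unapply(λ: Int): Boolean = low <= λ && λ < high
}

// Patterns need stable identifiers, so the extractor instances are vals.
val Red    = new Band(620, 750)
val Orange = new Band(590, 620)
val Yellow = new Band(570, 590)
val Green  = new Band(495, 570)
val Blue   = new Band(450, 495)
val Violet = new Band(380, 450)

def colorOfLight(λ: Int): (String, String) = λ match {
  case Red()    => s"$λ nm" -> "red"
  case Orange() => s"$λ nm" -> "orange"
  case Yellow() => s"$λ nm" -> "yellow"
  case Green()  => s"$λ nm" -> "green"
  case Blue()   => s"$λ nm" -> "blue"
  case Violet() => s"$λ nm" -> "violet"
  case _        => s"$λ nm" -> "invisible"
}
```

The match expression reads exactly as before; only the way the extractors are defined changes.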

Option Extractors

Another way to indicate a match with an extractor is to return an Option. If the input meets the criteria, return the value (or values) you want the pattern to capture, wrapped in a Some; if not, return None. The captured value can be the input itself or the result of calculations the extractor performs on it.

If you need to capture just one value, return an Option[A]. For example, you could modify the even/odd extractors above so that each binds the matched value to a new variable:

object Even {
  def unapply(n: Int): Option[Int] = if (n % 2 == 0) Some(n) else None
}
object Odd {
  // != 0 so that negative odd numbers (where n % 2 is -1) also match
  def unapply(n: Int): Option[Int] = if (n % 2 != 0) Some(n) else None
}

def oddOrEven(n: Int) = {
  n match {
    case Even(m) => m -> "even"
    case Odd(m) => m -> "odd"
  }
}

(1 to 10) map oddOrEven foreach println

// (1,odd)
// (2,even)
// (3,odd)
// (4,even)
// (5,odd)
// (6,even)
// (7,odd)
// (8,even)
// (9,odd)
// (10,even)
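The value inside the Some doesn't have to be the input itself; it can be anything the extractor computes. And because the captured value can itself be matched against another pattern, extractors nest. A sketch with a hypothetical Half extractor and classify function:

```scala
// Matches even numbers and binds half the input: a value the extractor
// computes, not the input itself.
object Half {
  def unapply(n: Int): Option[Int] = if (n % 2 == 0) Some(n / 2) else None
}

// Half(Half(q)) succeeds when n is even and n / 2 is also even,
// i.e. when n is divisible by 4, and binds q = n / 4.
def classify(n: Int): String = n match {
  case Half(Half(q)) => s"$n is divisible by 4 (a quarter is $q)"
  case Half(h)       => s"$n is even (half is $h)"
  case _             => s"$n is odd"
}

println(classify(12)) // 12 is divisible by 4 (a quarter is 3)
println(classify(6))  // 6 is even (half is 3)
println(classify(7))  // 7 is odd
```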

To capture multiple values, return an Option containing a tuple with one element per value, for example Option[(Int, Int)] to capture two Int values.

Perhaps you are a clerk in a department store. You recommend that customers who are at least 6′ (72 inches) tall and have at least a 40-inch waist go to the Big & Tall section. Others you greet according to their proportions.

First, define a Measurements type with height and waist properties:

class Measurements(val height: Int, val waist: Int)

Then define three extractors:

  1. one to detect whether a customer meets the “big” criterion,
  2. a second to detect whether he meets the “tall” criterion, and
  3. a third to detect whether he meets both criteria.

object Big {
  def unapply(m: Measurements): Option[Int] =
    if (m.waist >= 40) Some(m.waist) else None
}
object Tall {
  def unapply(m: Measurements): Option[Int] =
    if (m.height >= 72) Some(m.height) else None
}
object BigAndTall {
  // Reuse the two simpler extractors; succeed only if both do.
  def unapply(m: Measurements): Option[(Int, Int)] =
    (Big.unapply(m), Tall.unapply(m)) match {
      case (Some(w), Some(h)) => Some((w, h))
      case _ => None
    }
}

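Now that you have the extractors defined, you can use them in a match expression that extracts the waist and height on a successful match. Order matters: BigAndTall must come before Tall and Big, because a customer who is both big and tall also matches each of those on its own, and Scala uses the first case that matches: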

def sizeUp(m: Measurements) = {
  m match {
    case BigAndTall(w,h) =>
      s"$w-inch waist and $h inches tall: " +
        "Let me show you to our big & tall section"
    case Tall(h) => s"$h inches tall: How's the weather up there?"
    case Big(w) => s"$w-inch waist: Big fella, ain'tcha?"
    case _ => "How may I help you?"
  }
}

Now run some customers through the sizeUp function:

val me = new Measurements(76, 36)
val shrimp = new Measurements(58, 28)
val hoss = new Measurements(80, 46)
val tubby = new Measurements(63, 42)

List(me, shrimp, hoss, tubby)
  .map(sizeUp)
  .foreach(println)
// 76 inches tall: How's the weather up there?
// How may I help you?
// 46-inch waist and 80 inches tall: 
//   Let me show you to our big & tall section
// 42-inch waist: Big fella, ain'tcha?
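For comparison, here is a sketch of the same logic without custom extractors, assuming Measurements were a case class instead. The compiler-generated extractor lets you destructure it, but the sizing criteria move into guards (the if clauses), and the names Big, Tall, and BigAndTall no longer appear in the patterns:

```scala
// Hypothetical variant: as a case class, Measurements gets an extractor for free.
case class Measurements(height: Int, waist: Int)

// The sizing criteria now live in guards instead of named extractors.
def sizeUp(m: Measurements): String = m match {
  case Measurements(h, w) if h >= 72 && w >= 40 =>
    s"$w-inch waist and $h inches tall: " +
      "Let me show you to our big & tall section"
  case Measurements(h, _) if h >= 72 => s"$h inches tall: How's the weather up there?"
  case Measurements(_, w) if w >= 40 => s"$w-inch waist: Big fella, ain'tcha?"
  case _ => "How may I help you?"
}
```

Both versions produce the same results for the four customers above. The guard version is shorter, but the extractor version gives names to the concepts the guards merely compute.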

Sequence Extractors

Finally, you can extract a variable number of values and match only the elements you care about while ignoring the rest. Sequence extractors define an unapplySeq() method rather than unapply(). In its classic form, unapplySeq() returns an Option[Seq[A]]: a Some of the extracted elements on a match, None otherwise.
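As a warm-up, here is a small sketch of a sequence extractor. Words and respond are hypothetical names. The same extractor can match inputs with zero, one, or many elements, and name @ _* binds whatever elements remain:

```scala
// Splits a string into whitespace-separated words; a blank string yields no words.
object Words {
  def unapplySeq(s: String): Option[Seq[String]] =
    Some(s.trim.split("\\s+").toSeq.filter(_.nonEmpty))
}

def respond(s: String): String = s match {
  case Words()                   => "(silence)"
  case Words("hello", rest @ _*) => s"a greeting plus ${rest.length} more word(s)"
  case Words(only)               => s"one word: $only"
  case Words(first, _*)          => s"starts with $first"
}

println(respond("hello there world")) // a greeting plus 2 more word(s)
println(respond("goodbye"))           // one word: goodbye
println(respond("good night"))        // starts with good
```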

You could build an extractor that returns all the factors of an integer:

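import scala.annotation.tailrec

object Factors {
  // Matches a positive integer and yields all of its factors in ascending order.
  def unapplySeq(n: Int): Option[Seq[Int]] = {
    @tailrec
    def go(factors: List[Int], candidates: Seq[Int]): List[Int] = {
      if (candidates.isEmpty) {
        factors
      } else {
        val head = candidates.head
        val tail = candidates.tail
        if (n % head == 0) {
          go(head :: factors, tail)
        } else {
          // A multiple of a non-factor cannot be a factor, so drop those too.
          go(factors, tail.filter(_ % head != 0))
        }
      }
    }

    if (n < 1) None                // no factor list for zero or negatives
    else if (n == 1) Some(List(1)) // avoid listing 1 twice
    else Some((n :: go(List(1), 2 to n / 2)).reverse)
  }
}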

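I won’t explain the (rather brute force) factorization method above in detail. Suffice it to say that it returns all the factors of n, not just the prime ones, in ascending order. For example, for 15, it returns {1,3,5,15}. Now let’s say that we want to match on numbers that are divisible by three, but not two. Because the list is sorted, its second element is the smallest factor greater than 1, which is always prime. A number is divisible by three but not two exactly when that smallest prime factor is 3, so we’re looking for factor lists that follow this pattern: {1,3,…}. Here is how you use Factors to match that pattern: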

def divBy3Not2(n: Int) = n match {
  case Factors(1,3,_*) =>
    s"$n: Divisible by three, but not two"
  case _ => n.toString
}

The _* wildcard matches any remaining elements, including none, so you don’t have to care what follows: if the first two elements match, it’s a match. Now you can put divBy3Not2 to use:

List(2,6,9,10,15,54)
  .map(divBy3Not2)
  .foreach(println)
// 2
// 6
// 9: Divisible by three, but not two
// 10
// 15: Divisible by three, but not two
// 54
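Because the factors come back in ascending order, a pattern can also bind the second element instead of testing it against a literal 3, which gives you the smallest prime factor of any n greater than 1. To keep this sketch self-contained, it uses a deliberately simple stand-in for Factors that filters 1 to n rather than the tail-recursive version above; smallestPrimeFactor is a hypothetical name:

```scala
// Simplified stand-in for Factors: every factor of n, in ascending order.
// Zero and negative numbers have no factor list here, so they do not match.
object Factors {
  def unapplySeq(n: Int): Option[Seq[Int]] =
    if (n < 1) None else Some((1 to n).filter(n % _ == 0))
}

// The smallest factor above 1 must be prime: if it were composite, one of
// its own smaller factors would also divide n and appear earlier in the list.
def smallestPrimeFactor(n: Int): Option[Int] = n match {
  case Factors(1, p, _*) => Some(p)
  case _ => None // 1 has only one factor; 0 and negatives don't match at all
}

println(smallestPrimeFactor(15)) // Some(3)
println(smallestPrimeFactor(49)) // Some(7)
println(smallestPrimeFactor(1))  // None
```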
