angusjf

@functools.cache-style Memoization in Haskell

If you've ever attempted Advent of Code in Haskell, you have probably had to memoize a function to complete a challenge. Looking through submissions, you'll find yourself seething with jealousy at how Python users can simply add the @functools.cache decorator to their functions to magically memoize them.

A similar level of magic (though admittedly not quite as ergonomic) can be achieved in Haskell without any external libraries, using just the built-in Data.Map and Control.Monad.State.

Note: This is not the most efficient memoization possible, just 'good enough' to solve AOC problems.

The first step is to ensure your function is unary, meaning it takes only one argument. If it takes more than one, for example:

f :: Int -> Int -> Int

Pass them in as a tuple:

f :: (Int, Int) -> Int

The argument to the function will be stored in the map as the key.
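As a rough illustration of this step, here is a sketch of the same conversion on a concrete function. The names paths and pathsT are hypothetical and chosen only for this example; they count lattice paths through a grid. Tuples of Ord types are themselves Ord, which is what lets them serve as Data.Map keys later on.

```haskell
-- Hypothetical two-argument function: number of lattice paths
-- from the corner of an r-by-c grid, moving one step at a time.
paths :: Int -> Int -> Int
paths 0 _ = 1
paths _ 0 = 1
paths r c = paths (r - 1) c + paths r (c - 1)

-- The same function made unary by taking both arguments as one tuple.
pathsT :: (Int, Int) -> Int
pathsT (0, _) = 1
pathsT (_, 0) = 1
pathsT (r, c) = pathsT (r - 1, c) + pathsT (r, c - 1)
```

The Prelude's uncurry does the same thing at the type level (uncurry paths :: (Int, Int) -> Int), but the recursive calls inside paths would still go through the two-argument version, so rewriting the definition itself is what matters for memoization.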

Next, define these two functions:

(I've used the type variables x and y to represent the input and output of the function f :: x -> y, respectively.)

import qualified Data.Map as M
import Control.Monad.State

memoized :: Ord x => (x -> State (M.Map x y) y) -> x -> State (M.Map x y) y
memoized f x = do
    cache <- get                     -- get the 'cache' from the state monad
    case M.lookup x cache of         -- check if the key `x` exists in the cache
        Just hit -> return hit       -- if it exists, return it
        Nothing -> do
            res <- f x               -- else, call f(x)
            modify (M.insert x res)  -- set cache[x] = f(x)
            return res               -- return f(x)

runMemoized :: (x -> State (M.Map x y) y) -> x -> y
runMemoized f x = evalState (f x) M.empty
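To see what memoized actually stores, one option is to run a few calls with execState, which returns the final state (the cache) instead of the result. This is only a sketch: double and cacheAfter are hypothetical names made up for the illustration, and memoized is repeated so the block stands alone.

```haskell
import qualified Data.Map as M
import Control.Monad.State (State, execState, get, modify)

-- Same definition as above, repeated so this block is self-contained.
memoized :: Ord x => (x -> State (M.Map x y) y) -> x -> State (M.Map x y) y
memoized f x = do
    cache <- get
    case M.lookup x cache of
        Just hit -> return hit
        Nothing -> do
            res <- f x
            modify (M.insert x res)
            return res

-- Hypothetical non-recursive function, used only for illustration.
double :: Int -> State (M.Map Int Int) Int
double n = return (n * 2)

-- Call the memoized function on 3, 4, and 3 again, then return the cache.
-- The second call with 3 is a cache hit, so only two entries are stored.
cacheAfter :: M.Map Int Int
cacheAfter = execState (mapM_ (memoized double) [3, 4, 3]) M.empty
```

Here cacheAfter should equal M.fromList [(3, 6), (4, 8)]: one entry per distinct argument.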

Now, you must replace all recursive calls to the function you wish to memoize.

For example, given a function f:

f :: Int -> Int
f 0 = 1
f 1 = 1
f x = f (x - 1) + f (x - 2)

Wrap the value of each base case in return, and replace every recursive call to f with memoized f.

-- (M.Map Int Int) is the type of our cache
f :: Int -> State (M.Map Int Int) Int
f 0 = return 1
f 1 = return 1
f x =
  do
    a <- memoized f (x - 1)
    b <- memoized f (x - 2)
    return (a + b)
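The same recipe should carry over to functions of several arguments once they take a tuple. Below is a minimal sketch, assuming a hypothetical lattice-path counter called paths (not part of the original post); the cache key is simply the (Int, Int) pair.

```haskell
import qualified Data.Map as M
import Control.Monad.State (State, evalState, get, modify)

-- memoized and runMemoized as defined earlier, repeated for self-containment.
memoized :: Ord x => (x -> State (M.Map x y) y) -> x -> State (M.Map x y) y
memoized f x = do
    cache <- get
    case M.lookup x cache of
        Just hit -> return hit
        Nothing -> do
            res <- f x
            modify (M.insert x res)
            return res

runMemoized :: (x -> State (M.Map x y) y) -> x -> y
runMemoized f x = evalState (f x) M.empty

-- Hypothetical example: count lattice paths through an r-by-c grid.
-- The tuple (r, c) is the cache key.
paths :: (Int, Int) -> State (M.Map (Int, Int) Int) Int
paths (0, _) = return 1
paths (_, 0) = return 1
paths (r, c) = do
    up   <- memoized paths (r - 1, c)
    left <- memoized paths (r, c - 1)
    return (up + left)
```

With this in scope, runMemoized paths (16, 16) evaluates to 601080390, and each grid cell is computed only once rather than once per path that reaches it.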

Now, to run the function, call:

ghci> runMemoized f 100
1298777728820984005
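One caveat about that number: assuming a 64-bit Int, the result has silently wrapped around, because the true value of f 100 does not fit in 64 bits. If the exact answer matters, a possible fix is to keep Int as the key but store Integer values in the cache. The sketch below does that; fInteger is a hypothetical name for the adjusted function.

```haskell
import qualified Data.Map as M
import Control.Monad.State (State, evalState, get, modify)

-- memoized and runMemoized as defined earlier, repeated for self-containment.
memoized :: Ord x => (x -> State (M.Map x y) y) -> x -> State (M.Map x y) y
memoized f x = do
    cache <- get
    case M.lookup x cache of
        Just hit -> return hit
        Nothing -> do
            res <- f x
            modify (M.insert x res)
            return res

runMemoized :: (x -> State (M.Map x y) y) -> x -> y
runMemoized f x = evalState (f x) M.empty

-- Same recurrence as f, but with arbitrary-precision Integer results,
-- so large values do not overflow.
fInteger :: Int -> State (M.Map Int Integer) Integer
fInteger 0 = return 1
fInteger 1 = return 1
fInteger x = do
    a <- memoized fInteger (x - 1)
    b <- memoized fInteger (x - 2)
    return (a + b)
```

With this version, runMemoized fInteger 100 gives 573147844013817084101.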

Happy memoizing!