Thursday, March 7, 2013

Efficient Memoization using Partial Function Application

Memoization is any technique of storing previously computed values for fast retrieval, rather than recomputing the same value multiple times:


bool f(int); // defined elsewhere
f(2); // compute f(2)
f(2); // retrieve previously stored value for f(2) instead of recomputing it


This is useful when calling f is expensive.

In C++, the most obvious way to implement memoization is perhaps using a std::map:

(I'm assuming correctness matters here, so storing only hashes of the arguments alongside the return values won't do: two different arguments could hash to the same value.)


#include <map>
#include <utility>

bool f(int); // defined elsewhere
bool memo_f(std::map<int,bool>& memo_tbl, int i)
{
    auto at = memo_tbl.find(i);
    if (at == memo_tbl.end()) {
        // Not previously computed, compute and store
        at = memo_tbl.insert(std::make_pair(i,f(i))).first;
    }
    return at->second;
}
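A minimal usage sketch of the same scheme. It assumes a hypothetical f that only tests parity and counts its own calls in a hypothetical f_calls. Repeated lookups of the same argument are served from the table, so f runs once per distinct argument.

```cpp
#include <map>
#include <utility>

// Hypothetical stand-in for the expensive f: counts how often it actually runs.
int f_calls = 0;
bool f(int i)
{
    ++f_calls;
    return i % 2 == 0; // placeholder for an expensive computation
}

// Same memoization scheme as memo_f above.
bool memo_f(std::map<int,bool>& memo_tbl, int i)
{
    auto at = memo_tbl.find(i);
    if (at == memo_tbl.end()) {
        // Not previously computed: compute once and store
        at = memo_tbl.insert(std::make_pair(i, f(i))).first;
    }
    return at->second;
}
```

Calling memo_f(tbl, 2) twice and memo_f(tbl, 3) once leaves f_calls at 2.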


This works fine for any function from a domain X (int above) to another domain Y (bool above).
But what if we want to compute a function that takes multiple arguments?
Say we have:

bool f(int,int); // defined elsewhere

We could then use a std::map<std::pair<int,int>,bool>, where the pair represents the two arguments (more generally, we could use a std::tuple, or even a std::vector for variable size arguments).
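For two int arguments, the pair-keyed table might look like the following sketch. The concrete f and the f_calls counter are hypothetical placeholders. std::pair compares lexicographically, so std::map needs no custom comparator, and a std::tuple key would work the same way.

```cpp
#include <map>
#include <utility>

// Hypothetical two-argument function standing in for the expensive f(int,int).
int f_calls = 0;
bool f(int a, int b)
{
    ++f_calls;
    return (a + b) % 3 == 0;
}

typedef std::map<std::pair<int,int>,bool> pair_memo;

bool memo_f(pair_memo& memo_tbl, int a, int b)
{
    // Build the key from both arguments: cheap for ints
    const std::pair<int,int> key(a, b);
    auto at = memo_tbl.find(key);
    if (at == memo_tbl.end()) {
        at = memo_tbl.insert(std::make_pair(key, f(a, b))).first;
    }
    return at->second;
}
```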
But these approaches all have one problem that becomes apparent if we replace our ints with heavier objects:


bool f(const Matrix&, const Graph&);

bool memo_f(std::map<std::pair<Matrix,Graph>,bool>& memo_tbl, const Matrix& m, const Graph& g)
{
    std::pair<Matrix,Graph> p(m,g); // <-- ouch! Copies the entire matrix and graph!
    auto at = memo_tbl.find(p);
    if (at == memo_tbl.end()) {
        // Not previously computed, compute and store
        at = memo_tbl.insert(std::make_pair(p,f(m,g))).first;
    }
    return at->second;
}


I'm assuming that all objects can be ordered (i.e., have an operator<). The argument works equally well with std::unordered_map, given a hash function for every argument type. The concern here is not how to order or hash objects, but how to avoid unnecessary copies of large ones.
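With std::unordered_map the only extra ingredient is a hash for the key type. The standard library specializes std::hash for int but not for std::pair, so this sketch supplies a hypothetical pair_hash whose mixing step is modeled on Boost's hash_combine. f and f_calls are again placeholders. The key pair is still built on every call, so the copying problem is unchanged.

```cpp
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <utility>

// Hypothetical hash for std::pair<int,int> (no std::hash specialization exists for pairs).
struct pair_hash {
    std::size_t operator()(const std::pair<int,int>& p) const
    {
        std::size_t h1 = std::hash<int>()(p.first);
        std::size_t h2 = std::hash<int>()(p.second);
        // Mixing step in the style of boost::hash_combine
        return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
    }
};

// Hypothetical function being memoized.
int f_calls = 0;
bool f(int a, int b) { ++f_calls; return a < b; }

// Equality of keys uses std::pair's operator== via the default std::equal_to.
typedef std::unordered_map<std::pair<int,int>,bool,pair_hash> hashed_memo;

bool memo_f(hashed_memo& memo_tbl, int a, int b)
{
    const std::pair<int,int> key(a, b);
    auto at = memo_tbl.find(key);
    if (at == memo_tbl.end()) {
        at = memo_tbl.insert(std::make_pair(key, f(a, b))).first;
    }
    return at->second;
}
```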

Using pointers to solve the problem isn't a good idea either:


bool f(const Matrix&, const Graph&);
typedef std::pair<const Matrix*,const Graph*> inputs;
// pair_ptr_cmp defined elsewhere, dereferences the pointers inside the pair and compares them

bool memo_f(std::map<inputs,bool,pair_ptr_cmp>& memo_tbl, const Matrix& m, const Graph& g)
{
    inputs p(&m,&g);
    auto at = memo_tbl.find(p);
    if (at == memo_tbl.end()) {
        // Not previously computed, compute and store
        bool ans = f(m,g);
        // Not safe to store pointers to the caller's objects (they may be destroyed), so we need to copy
        p = inputs(new Matrix(m),new Graph(g));
        at = memo_tbl.insert(std::make_pair(p,ans)).first;
    }
    return at->second;
}

This works, except that we need to define the awkward pair_ptr_cmp, and the entire block from the declaration of memo_tbl (not shown above) to its last use has to be wrapped in a try/catch (...) block so that the heap-allocated copies are freed:


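std::map<inputs,bool,pair_ptr_cmp> memo_tbl;
try {
    // Want to call f with Matrix m and Graph g
    bool b = memo_f(memo_tbl, m,g);
    // ...
} catch (...) {
    // Each key (p.first) owns a heap-allocated Matrix and Graph; p.second is the bool result
    for (auto& p : memo_tbl) {
        delete p.first.first;
        delete p.first.second;
    }
    throw;
}
// ...and the same cleanup loop is needed again here on the normal path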

Needless to say, this is hard to read, error-prone, and just bad C++.

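A good solution I've found is to use partial function application. The idea is simple: we can curry a two-argument function (X,Y) -> f(X,Y) into X -> (Y -> f(X,Y)). That is, instead of taking two arguments of types X and Y, the curried function takes one argument of type X and returns a function. This new function takes the second argument, of type Y, and returns the desired f(X,Y). For memoization, this means nesting maps: the outer map is keyed on X, and the value stored for each x is an inner map from Y to the result, which is a table for the function Y -> f(x,Y). Because std::map::find takes its key by const reference, looking up both arguments copies neither of them; they are copied only when a new entry is stored.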
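The currying step on its own can be sketched with C++11 lambdas. same_size and curry_same_size are hypothetical names, with same_size mirroring the test function used later. A nested-map table is the memoized analogue of this: looking up x in the outer map plays the role of calling curry_same_size(x), and the inner map it yields plays the role of the returned function of y.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical two-argument function (X,Y) -> f(X,Y).
bool same_size(const std::string& x, const std::vector<int>& y)
{
    return x.size() == y.size();
}

// Curried form X -> (Y -> f(X,Y)): given x, return a function of y alone.
// The lambda captures x by value, so it stays valid after this call returns.
std::function<bool(const std::vector<int>&)> curry_same_size(const std::string& x)
{
    return [x](const std::vector<int>& y) { return same_size(x, y); };
}
```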
Some code will make the point clear, and also demonstrate its elegance:


#include <map>
#include <string>
#include <utility>
#include <vector>

typedef std::string arg1_type;
typedef std::vector<int> arg2_type;
typedef bool ret_type;

// Function to call
bool function(const arg1_type& a1, const arg2_type& a2);

// Memoization of function
typedef std::map<arg1_type,std::map<arg2_type,ret_type>> memo_tbl;
ret_type memo_call(memo_tbl& f, const arg1_type& a1, const arg2_type& a2);

ret_type memo_call(memo_tbl& f, const arg1_type& a1, const arg2_type& a2)
{
 ret_type ans;
 auto at = f.find(a1);
 if (at != f.end()) {
  auto at2 = at->second.find(a2);
  if (at2 != at->second.end()) {
   ans = at2->second;
  } else {
   // map for a1 exists but value of a2 not stored
   ans = function(a1,a2);
   at->second.insert(std::make_pair(a2,ans));
  }
 } else {
  // map for a1 (and hence the entry for a2) does not exist
  ans = function(a1,a2);
  std::map<arg2_type,ret_type> m;
  m.insert(std::make_pair(a2,ans));
  f.insert(std::make_pair(a1,std::move(m)));
 }
 return ans;
}

We can now test it:


#include <iostream>

static int counter = 0; // counts how many times function is actually called
bool function(const arg1_type& a1, const arg2_type& a2)
{
 ++counter;
 return a1.size() == a2.size();
}

int main()
{
 memo_tbl tbl;
 std::cerr << memo_call(tbl,arg1_type("Hello World!"),arg2_type(10)) << "\n";
 std::cerr << memo_call(tbl,arg1_type("Hello World!"),arg2_type(12)) << "\n";
 std::cerr << memo_call(tbl,arg1_type("Hello World!"),arg2_type(10)) << "\n";
 std::cerr << memo_call(tbl,arg1_type("Bye World!"),arg2_type(10)) << "\n";

 std::cerr << "Counter: " << counter << "\n";
}

Output is:
0
1
0
1
Counter: 3
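To check the copying claim directly, this sketch counts copies of a hypothetical Heavy argument type on a cache hit, comparing a pair-keyed table with nested maps. For brevity memo_nested uses std::map::operator[], which copies the key only when it inserts a new entry; that is a variation on the find/insert style of memo_call, not the code above.

```cpp
#include <map>
#include <utility>
#include <vector>

// Hypothetical stand-in for a large argument such as Matrix or Graph; counts its copies.
struct Heavy {
    static int copies;
    std::vector<int> data;
    explicit Heavy(int n) : data(n) {}
    Heavy(const Heavy& other) : data(other.data) { ++copies; }
    Heavy& operator=(const Heavy& other) { data = other.data; ++copies; return *this; }
    bool operator<(const Heavy& other) const { return data < other.data; }
};
int Heavy::copies = 0;

// Hypothetical function being memoized.
bool f(const Heavy& a, const Heavy& b) { return a.data.size() < b.data.size(); }

typedef std::map<std::pair<Heavy,Heavy>,bool> pair_tbl;
typedef std::map<Heavy,std::map<Heavy,bool>> nested_tbl;

// Pair-keyed table: the key is built (copying both arguments) even on a hit.
bool memo_pair(pair_tbl& tbl, const Heavy& a, const Heavy& b)
{
    std::pair<Heavy,Heavy> key(a, b);
    auto at = tbl.find(key);
    if (at == tbl.end())
        at = tbl.insert(std::make_pair(key, f(a, b))).first;
    return at->second;
}

// Nested tables: lookups take the arguments by const reference.
bool memo_nested(nested_tbl& tbl, const Heavy& a, const Heavy& b)
{
    std::map<Heavy,bool>& inner = tbl[a]; // copies a only if it is a new key
    auto at = inner.find(b);              // no copy
    if (at == inner.end())
        at = inner.insert(std::make_pair(b, f(a, b))).first;
    return at->second;
}
```

After priming both tables with one call, a second memo_pair(pt, a, b) adds 2 to Heavy::copies, while a second memo_nested(nt, a, b) adds nothing.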

The method generalizes to n-ary functions in the obvious way: nest one std::map per argument.
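For three arguments, a sketch with one std::map level per argument might look like this. h, h_calls, memo3_tbl and memo_h are hypothetical names. operator[] is used for the outer two levels, so x and y are copied only the first time each is seen at its level.

```cpp
#include <map>
#include <string>
#include <utility>

// Hypothetical three-argument function standing in for an expensive f(X,Y,Z).
int h_calls = 0;
std::string h(int x, const std::string& y, char z)
{
    ++h_calls;
    return std::to_string(x) + y + z;
}

// X -> (Y -> (Z -> result)): one std::map level per argument.
typedef std::map<int, std::map<std::string, std::map<char, std::string>>> memo3_tbl;

std::string memo_h(memo3_tbl& tbl, int x, const std::string& y, char z)
{
    // operator[] copies a key only when that key is new at its level
    std::map<char, std::string>& by_z = tbl[x][y];
    auto at = by_z.find(z);
    if (at == by_z.end())
        at = by_z.insert(std::make_pair(z, h(x, y, z))).first;
    return at->second;
}
```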


1 comment:

  1. Why not turn memo_call into a template function?
    I'm not yet good at C++11, but I bet you could exploit variadic templates here.
    Also I think it's better to have a class with a functor and avoid passing the map to the memo_call each time.
