pele
Python energy landscape explorer
/home/js850/projects/pele/source/pele/ngt.hpp
00001 #ifndef _NGT_HPP_
00002 #define _NGT_HPP_
00003 /*
00004  *
00005  * This implements the New Graph Transformation method (NGT) described in
00006  *
00007  * David Wales, J. Chem. Phys., 2009 http://dx.doi.org/10.1063/1.3133782
00008  *
00009  * This procedure computes transition rates and committor probabilities for a
00010  * transition network (kinetic Monte Carlo).
00011  */
00012 
00013 
00014 #include <cstdlib>
00015 #include <iostream>
00016 #include <list>
00017 #include <queue>
00018 #include <assert.h>
00019 #include <stdexcept>
00020 #include <memory>
00021 
00022 #include "graph.hpp"
00023 
00024 using std::cout;
00025 
00026 namespace pele
00027 {
00028 
00029 bool compare_degree(node_ptr u, node_ptr v){
00030     return u->in_out_degree() < v->in_out_degree();
00031 }
00032 
00033 class NGT {
00034 public:
00035     typedef std::map<std::pair<node_id, node_id>, double> rate_map_t;
00036 
00037     std::shared_ptr<Graph> _graph;
00038     std::set<node_ptr> _A; // the source nodes
00039     std::set<node_ptr> _B; // the sink nodes
00040     std::list<node_ptr> intermediates; //this will an up to date list of nodes sorted by the node degree
00041     bool debug;
00042 
00047     std::map<node_id, double> initial_tau;
00051     std::map<node_id, double> final_omPxx;
00055     std::map<node_id, double> final_tau;
00056     std::map<node_id, double> final_committors;
00057     std::map<node_id, double> weights; // normally these are equilibrium occupation probabilities
00058 
00059 
00060     
00061     ~NGT()
00062     {
00063     }
00064 
00065     /*
00066      * construct the NGT from an existing graph.
00067      *
00068      * The graph will be used directly, without copying.  Any modifications
00069      * will be reflected in the passed graph
00070      */
00071     template<class Acontainer, class Bcontainer>
00072     NGT(std::shared_ptr<Graph> graph, Acontainer const &A, Bcontainer const &B) :
00073         _graph(graph),
00074         debug(false)
00075     {
00076         for (auto u : A){
00077             _A.insert(_graph->get_node(u));
00078         }
00079         for (auto u : B){
00080             _B.insert(_graph->get_node(u));
00081         }
00082 
00083         // make intermediates
00084         for (auto const & mapval : _graph->node_map_){
00085             node_ptr u = mapval.second;
00086             if (_A.find(u) == _A.end() and _B.find(u) == _B.end()){
00087                 intermediates.push_back(u);
00088             }
00089         }
00090 
00091 //        std::cout << "number of nodes " << _graph->number_of_nodes() << "\n";
00092 //        std::cout << "A.size() " << _A.size() << "\n";
00093 //        std::cout << "B.size() " << _B.size() << "\n";
00094 //        std::cout << "intermediates.size() " << intermediates.size() << "\n";
00095         assert(intermediates.size() + _A.size() + _B.size() == _graph->number_of_nodes());
00096 
00097     }
00098 
00099     void set_debug() { debug=true; }
00100     std::map<node_id, double> const & get_committors() { return final_committors; }
00101 
00102     /*
00103      * construct the NGT from a map of rate constants.
00104      */
00105     template<class Acontainer, class Bcontainer>
00106     NGT(rate_map_t &rate_constants, Acontainer const &A, Bcontainer const &B) :
00107         _graph(new Graph()),
00108         debug(false)
00109     {
00110         std::set<node_ptr> nodes;
00111 
00112         // add nodes to the graph and sum the rate constants for all out edges for each node.
00113         std::map<node_ptr, double> sum_out_rates;
00114         for (auto const & mapvals : rate_constants){
00115             node_ptr u = _graph->add_node(mapvals.first.first);
00116             node_ptr v = _graph->add_node(mapvals.first.second);
00117             double k = mapvals.second;
00118             nodes.insert(u);
00119             nodes.insert(v);
00120 
00121             try {
00122                 sum_out_rates.at(u) += k;
00123             } catch (std::out_of_range & e) {
00124                 sum_out_rates[u] = k;
00125             }
00126         }
00127 
00128         // set tau_x for each node
00129         // add edge Pxx for each node and initialize P to 0.
00130         for (auto x : nodes){
00131             double tau_x = 1. / sum_out_rates[x];
00132             set_tau(x, tau_x);
00133             initial_tau[x->id()] = tau_x;
00134             edge_ptr xx = _graph->_add_edge(x, x);
00135             set_P(xx, 0.);
00136         }
00137 
00138         // set Puv for each edge
00139         for (auto const & mapval : rate_constants){
00140             node_ptr u = _graph->get_node(mapval.first.first);
00141             node_ptr v = _graph->get_node(mapval.first.second);
00142             double k = mapval.second;
00143 
00144             edge_ptr uv = _graph->_add_edge(u, v);
00145             double tau_u = get_tau(u);
00146             double Puv = k * tau_u;
00147             set_P(uv, Puv);
00148 
00149             try {
00150                 sum_out_rates.at(u) += k;
00151             } catch (std::out_of_range & e) {
00152                 sum_out_rates[u] = k;
00153             }
00154         }
00155 
00156 
00157         // make the set of A and B
00158         for (auto a : A){
00159             _A.insert(_graph->get_node(a));
00160         }
00161         for (auto b : B){
00162             _B.insert(_graph->get_node(b));
00163         }
00164 
00165         // make a list of intermediates
00166         for (auto a : _A){
00167             nodes.erase(a);
00168         }
00169         for (auto b : _B){
00170             nodes.erase(b);
00171         }
00172         intermediates.assign(nodes.begin(), nodes.end());
00173 
00174 
00175 //        std::cout << _graph->number_of_nodes() << "\n";
00176 //        std::cout << _A.size() << "\n";
00177 //        std::cout << _B.size() << "\n";
00178 //        std::cout << intermediates.size() << "\n";
00179 //        std::cout << nodes.size() << "\n";
00180         assert(intermediates.size() + _A.size() + _B.size() == _graph->number_of_nodes());
00181     }
00182 
00183     void set_node_occupation_probabilities(std::map<node_id, double> &Peq){
00184         weights.insert(Peq.begin(), Peq.end());
00185     }
00186 
00187     /*
00188      * Sort the list of intermediates.
00189      *
00190      * This is done because it is faster to remove nodes with fewer connections first.
00191      */
00192     void sort_intermediates(){
00193         node_ptr x = *intermediates.begin();
00194         if (debug){
00195             std::cout << "smallest node degree " << x->in_out_degree() << "\n";
00196         }
00197         if (x->in_out_degree() > 4) {
00198             intermediates.sort(compare_degree);
00199         }
00200     }
00201     
00202     /*
00203      * accessors for graph properties P and tau attached to the edges and nodes.
00204      */
00205     inline double get_tau(node_ptr u){ return u->tau; }
00206     inline double get_P(edge_ptr edge){ return edge->P; }
00207     inline void set_tau(node_ptr u, double tau){ u->tau = tau; }
00208     inline void set_P(edge_ptr edge, double P){ edge->P = P; }
00209 
00210     /*
00211      * This returns P for the edge u->u.  This is slow because the edge must first be found.
00212      */
00213     double get_node_P(node_ptr u){ return get_P(u->get_successor_edge(u)); }
00214 
00215     /*
00216      * This returns 1.-P for the edge u->u.
00217      *
00218      * If P is close to one compute 1.-P directly by summing P over all the out edges of u.
00219      * This is extremely important for numerical precision.  It is this ability to deal precisely
00220      * with both P and 1.-P that makes this method more stable then linear algebra methods.
00221      */
00222     double get_node_one_minus_P(node_ptr u){
00223         edge_ptr uu = u->get_successor_edge(u);
00224         double Puu = get_P(uu);
00225         if (Puu < 0.99){
00226             return 1. - Puu;
00227         } else {
00228             // sum the contributions from all other edges
00229             double omPuu = 0.;
00230             for (auto eiter = u->out_edge_begin(); eiter != u->out_edge_end(); ++eiter){
00231                 node_ptr v = (*eiter)->head();
00232                 if (v != u){
00233                     omPuu += (*eiter)->P;
00234                 }
00235             }
00236             return omPuu;
00237         }
00238     }
00239 
00240 
00241     /*
00242      * node x is being deleted, so update tau for node u
00243      *
00244      * tau_u -> tau_u + Pux * tau_x / (1-Pxx)
00245      */
00246     void update_node(edge_ptr ux, double omPxx, double tau_x){
00247         node_ptr u = ux->tail();
00248         double Pux = get_P(ux);
00249         double tau_u = get_tau(u);
00250         double new_tau_u = tau_u + Pux * tau_x / omPxx;
00251         if (debug){
00252             std::cout << "updating node " << u->id() << " tau " << tau_u << " -> " << new_tau_u << "\n";
00253         }
00254         set_tau(u, new_tau_u);
00255     }
00256 
00257     /*
00258      * add an edge to the graph and set P to 0
00259      */
00260     edge_ptr add_edge(node_ptr u, node_ptr v){
00261        edge_ptr edge = _graph->_add_edge(u, v);
00262        set_P(edge, 0.);
00263        return edge;
00264     }
00265 
00266     /*
00267      * Node x is being deleted, so update P for the edge u -> v
00268      *
00269      * Puv -> Puv + Pux * Pxv / (1-Pxx)
00270      */
00271     void update_edge(node_ptr u, node_ptr v, edge_ptr ux, edge_ptr xv, double omPxx){
00272         edge_ptr uv = u->get_successor_edge(v);  // this is slow
00273         if (uv == NULL){
00274             uv = add_edge(u, v);
00275         }
00276 
00277         double Pux = get_P(ux);
00278         double Pxv = get_P(xv);
00279         double Puv = get_P(uv);
00280 
00281         double newPuv = Puv + Pux * Pxv / omPxx;
00282         if (debug) {
00283             std::cout << "updating edge " << u->id() << " -> " << v->id() << " Puv " << Puv << " -> " << newPuv
00284                     << " 1-Pxx " << omPxx
00285                     << " Pux " << Pux
00286                     << " Pxv " << Pxv
00287                     << "\n";
00288         }
00289         set_P(uv, newPuv);
00290     }
00291 
00292     /*
00293      * remove node x from the graph and update its neighbors
00294      */
00295     void remove_node(node_ptr x){
00296         if (debug){
00297             std::cout << "removing node " << x->id() << "\n";
00298         }
00299         double taux = get_tau(x);
00300 //        double Pxx = get_node_P(x);
00301         double omPxx = get_node_one_minus_P(x);
00302 
00303         // update the node data for all the neighbors
00304         for (auto eiter = x->in_edge_begin(); eiter != x->in_edge_end(); eiter++){
00305             edge_ptr edge = *eiter;
00306             if (edge->tail() != edge->head()){
00307                 update_node(edge, omPxx, taux);
00308             }
00309         }
00310 
00311         std::set<node_ptr> neibs = x->in_out_neighbors();
00312         neibs.erase(x);
00313 
00314         //
00315         for (auto uxiter = x->in_edge_begin(); uxiter != x->in_edge_end(); ++uxiter){
00316             edge_ptr ux = *uxiter;
00317             node_ptr u = ux->tail();
00318             if (u == x) continue;
00319             for (auto xviter = x->out_edge_begin(); xviter != x->out_edge_end(); ++xviter){
00320                 edge_ptr xv = *xviter;
00321                 node_ptr v = xv->head();
00322                 if (v == x) continue;
00323 //                if (u == v){
00324 //                    continue;
00325 //                }
00326                 update_edge(u, v, ux, xv, omPxx);
00327             }
00328         }
00329 
00330         // remove the node from the graph
00331         _graph->_remove_node(x);
00332 
00333     }
00334 
00335     /*
00336      * remove all intermediates from the graph
00337      */
00338     void remove_intermediates(){
00339         while (intermediates.size() > 0){
00340             sort_intermediates();
00341 
00342             node_ptr x = intermediates.front();
00343             intermediates.pop_front();
00344 
00345             remove_node(x);
00346         }
00347     }
00348 
00349     /*
00350      * phase one of the rate calculation is to remove all intermediate nodes
00351      */
00352     void phase_one(){
00353         remove_intermediates();
00354     }
00355 
00356     /*
00357      * Compute final_tau and final_omPxx for each node in to_remove
00358      *
00359      * For each node x in to_remove, this involves removing all other nodes in to_remove, and
00360      * getting the results from this reduced graph.
00361      */
00362     void reduce_all_in_group(std::set<node_ptr> &to_remove, std::set<node_ptr> & to_keep){
00363         std::list<node_id> Aids, Bids;
00364         // copy the ids of the nodes in to_remove into Aids
00365         for (auto u : to_remove){
00366             Aids.push_back(u->id());
00367         }
00368         // copy the ids of the nodes in to_keep into Bids
00369         for (auto u : to_keep){
00370             Bids.push_back(u->id());
00371         }
00372 
00373         // note: should we sort the minima in to_remove?
00374 
00375         if (Aids.size() > 1){
00376             // make a copy of _graph called working_graph
00377             auto working_graph = std::make_shared<Graph> (*_graph);
00378             std::list<node_id> empty_list;
00379             // make an ngt object for working_graph
00380             NGT working_ngt(working_graph, std::list<node_id>(), Bids);
00381             while (Aids.size() > 1){
00382                 /*
00383                  * Create a new graph and a new NGT object new_ngt.  Pass x as A and Bids as B.  new_ngt will
00384                  * remove all `intermediates`, i.e. everything in Aids except x.  Then save the final
00385                  * value of 1-Pxx and tau_x.
00386                  */
00387                 // choose an element x and remove it from the list
00388                 node_id x = Aids.back();
00389                 Aids.pop_back();
00390                 std::list<node_id> newAids;
00391                 newAids.push_back(x);
00392 
00393                 // make a new graph from the old graph
00394                 auto new_graph = std::make_shared<Graph>(*working_graph);
00395 
00396                 // remove all nodes from new_graph except x
00397                 NGT new_ngt(new_graph, newAids, Bids);
00398                 new_ngt.remove_intermediates();
00399                 node_ptr xptr = new_graph->get_node(x);
00400                 final_omPxx[x] = new_ngt.get_node_one_minus_P(xptr);
00401                 final_tau[x] = new_ngt.get_tau(xptr);
00402 
00403                 // delete node x from the old_graph
00404                 working_ngt.remove_node(working_graph->get_node(x));
00405             }
00406             // there is one node left. we can just read off the results
00407             assert(Aids.size() == 1);
00408             node_id x = Aids.back();
00409             Aids.pop_back();
00410             node_ptr xptr = working_graph->get_node(x);
00411             final_omPxx[x] = working_ngt.get_node_one_minus_P(xptr);
00412             final_tau[x] = working_ngt.get_tau(xptr);
00413 
00414         } else if (Aids.size() == 1) {
00415             // if there is only one node in A then we can just read off the results.
00416             node_id x = Aids.back();
00417             Aids.pop_back();
00418             node_ptr xptr = _graph->get_node(x);
00419             final_omPxx[x] = get_node_one_minus_P(xptr);
00420             final_tau[x] = get_tau(xptr);
00421         }
00422         assert(Aids.size() == 0);
00423     }
00424 
00425     /*
00426      * Phase two, compute final_tau and final_omPxx for each x separately in _A and in _B
00427      */
00428     void phase_two(){
00429         reduce_all_in_group(_A, _B);
00430         reduce_all_in_group(_B, _A);
00431     }
00432 
00433     /*
00434      * do phase one and phase two of the rate calculation
00435      */
00436     void compute_rates(){
00437         phase_one();
00438         phase_two();
00439     }
00440 
00441     /*
00442      * compute the final rate A->B or B->A from final_tau and final_omPxx
00443      */
00444     double _get_rate_final(std::set<node_ptr> &A){
00445         double rate_sum = 0.;
00446         double norm = 0.;
00447         for (auto a : A){
00448             double omPxx = final_omPxx.at(a->id());
00449             double tau_a = final_tau.at(a->id());
00450             double weight = 1.;
00451             if (weights.size() > 0){
00452                 weight = weights.at(a->id());
00453             }
00454             rate_sum += weight * omPxx / tau_a;
00455             norm += weight;
00456         }
00457         return rate_sum / norm;
00458     }
00459 
00460     /*
00461      * Return the rate A->B
00462      */
00463     double get_rate_AB(){
00464         return _get_rate_final(_A);
00465     }
00466 
00467     /*
00468      * Return the rate B->A
00469      */
00470     double get_rate_BA(){
00471         return _get_rate_final(_B);
00472     }
00473 
00474     double _get_rate_SS(std::set<node_ptr> & A, std::set<node_ptr> & B){
00475         double kAB = 0.;
00476         double norm = 0.;
00477         for (auto a : A){
00478             // compute PaB the probability that this node goes directly to B
00479             double PaB = 0.;
00480             for (auto eiter = a->out_edge_begin(); eiter != a->out_edge_end(); ++eiter){
00481                 edge_ptr ab = *eiter;
00482                 node_ptr b = ab->head();
00483                 if (B.find(b) != B.end()){
00484                     PaB += get_P(ab);
00485                 }
00486             }
00487             double weight = 1.;
00488             if (weights.size() > 0){
00489                 weight = weights.at(a->id());
00490             }
00491             kAB += weight * PaB / initial_tau.at(a->id());
00492             norm += weight;
00493         }
00494         return kAB / norm;
00495     }
00496 
00497     /*
00498      * Return the steady state rate A->B
00499      *
00500      * this must be called after calling phase_one
00501      */
00502     double get_rate_AB_SS(){
00503         return _get_rate_SS(_A, _B);
00504     }
00505 
00506     /*
00507      * Return the steady state rate B->A
00508      *
00509      * this must be called after calling phase_one
00510      */
00511     double get_rate_BA_SS(){
00512         return _get_rate_SS(_B, _A);
00513     }
00514 
00515     /*
00516      * sum the probabilities of the out edges of x that end in B normalized by 1-Pxx
00517      */
00518     double get_PxB(node_ptr x, std::set<node_id> & B){
00519         double PxB = 0.;
00520         double Pxx = 0.;
00521         double omPxx = 0.;
00522         for (auto eiter = x->out_edge_begin(); eiter != x->out_edge_end(); ++eiter){
00523             edge_ptr xb = *eiter;
00524             node_ptr b = xb->head();
00525             double Pxb = get_P(xb);
00526             if (b == x){
00527                 Pxx = Pxb;
00528             } else {
00529                 omPxx += Pxb;
00530             }
00531             if (B.find(b->id()) != B.end()){
00532                 PxB += Pxb;
00533             }
00534         }
00535         if (Pxx < 0.9){
00536             omPxx = 1. - Pxx;
00537         }
00538         return PxB / omPxx;
00539     }
00540 
00541     /*
00542      * compute the committors for all intermediates
00543      *
00544      * \param to_remove a list of nodes that will be removed.  Committor values will
00545      *     be computed for these nodes
00546      * \param to_keep a list of nodes that should not be deleted.
00547      * \param committor_targets a list of nodes that should not be deleted.  These
00548      *     nodes will be the targets in the committor calucation.
00549      *
00550      * All nodes should be in one of the three passed groups of nodes.  Duplicates
00551      * between to_keep and committor_targets are OK.
00552      */
00553     void _remove_nodes_and_compute_committors(std::list<node_ptr> &to_remove,
00554             std::set<node_ptr> &to_keep, std::set<node_ptr> const &committor_targets)
00555     {
00556         // make a copy of to_remove.  Store the id's
00557         std::list<node_id> to_remove_cp;
00558         for (auto u : to_remove){
00559             to_remove_cp.push_back(u->id());
00560         }
00561 
00562         // copy the nodes from to_keep and committor_target into a new set Bids;
00563         // make a copy of committor_target
00564         std::set<node_id> Bids;
00565         std::set<node_id> targets;
00566         // create a set of Bids
00567         for (auto u : to_keep){
00568             Bids.insert(u->id());
00569         }
00570         for (auto u : committor_targets){
00571             Bids.insert(u->id());
00572             targets.insert(u->id());
00573         }
00574 
00575         // ensure there are no unaccounted for nodes
00576         assert(to_remove_cp.size() + Bids.size() == _graph->number_of_nodes());
00577 
00578         // note: should we sort the nodes in to_remove?
00579 
00580         while (to_remove_cp.size() > 0){
00581             /*
00582              * Create a new graph and a new NGT object new_ngt.  Pass x as A and Bids as B.  new_ngt will
00583              * remove all `intermediates`, i.e. everything in to_remove except x.  Then save the final
00584              * value of 1-Pxx and tau_x.
00585              */
00586             // choose an element x and remove it from the list
00587             node_id x = to_remove_cp.back();
00588             to_remove_cp.pop_back();
00589             std::list<node_id> Aids;
00590             Aids.push_back(x);
00591 
00592             // make a copy of _graph
00593             auto new_graph = std::make_shared<Graph>(*_graph);
00594 
00595             // remove all to_remove nodes from new_graph except x
00596             NGT new_ngt(new_graph, Aids, Bids);
00597             new_ngt.remove_intermediates();
00598             node_ptr xptr = new_graph->get_node(x);
00599             final_omPxx[x] = new_ngt.get_node_one_minus_P(xptr);
00600             final_tau[x] = new_ngt.get_tau(xptr);
00601             if (! targets.empty()){
00602                 final_committors[x] = new_ngt.get_PxB(xptr, targets);
00603             }
00604 
00605             // delete node x from _graph
00606             this->remove_node(_graph->get_node(x));
00607         }
00608     }
00609 
00610     /*
00611      * Compute the rate from A->B and committor probabilities for all intermediates
00612      *
00613      * This is much slower than compute_rates.
00614      * If you don't want committors use that function instead
00615      */
00616     void compute_rates_and_committors(){
00617         _remove_nodes_and_compute_committors(intermediates, _A, _B);
00618         intermediates.clear();
00619 
00620         phase_two();
00621 
00622         // set the committor for nodes in A to 0
00623         for (auto a : _A){
00624             final_committors[a->id()] = 0.;
00625         }
00626         // set the committor for nodes in B to 1
00627         for (auto b : _B){
00628             final_committors[b->id()] = 1.;
00629         }
00630     }
00631 
00632 
00633 };
00634 
00635 }
00636 #endif
 All Classes Namespaces Functions Variables Typedefs