pele
Python energy landscape explorer
/home/js850/projects/pele/source/pele/python_potential_wrapper.h
00001 #ifndef _PELE_POTENTIAL_FUNCTION_H
00002 #define _PELE_POTENTIAL_FUNCTION_H
00003 #include "array.h"
00004 #include "base_potential.h"
00005 #include <Python.h>
00006 #include <numpy/arrayobject.h>
00007 #include <iostream>
00008 #include <stdexcept>
00009 
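// NOTE: NPY_NO_DEPRECATED_API only takes effect if it is defined *before*
// <numpy/arrayobject.h> is included (see the includes above); uncommenting
// it here, after that include, would have no effect.
//#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION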
00011 
00012 namespace pele {
00013 
00021 /*
00022 class PotentialFunction : public BasePotential
00023 {
00024     public:
00025         typedef double EnergyCallback(Array<double>, void *);
00026         typedef double EnergyGradientCallback(Array<double>, Array<double>, void *);
00027 
00028         PotentialFunction(EnergyCallback *get_energy, EnergyGradientCallback *get_energy_gradient, void *userdata)
00029             :    _get_energy(get_energy), _get_energy_gradient(get_energy_gradient), _userdata(userdata) {}
00030 
00031         virtual double get_energy(Array<double> x) { return (*_get_energy)(x, _userdata); } ;
00032         virtual double get_energy_gradient(Array<double> x, Array<double> grad) {  return (*_get_energy_gradient)(x, grad, _userdata); }
00033 
00034     private:
00035         EnergyCallback *_get_energy;
00036         EnergyGradientCallback *_get_energy_gradient;
00037         void *_userdata;
00038 };
00039 */
00040 
00041 
00047 class PythonPotential : public BasePotential
00048 {
00049     PyObject * _potential;
00050 
00051 public:
00052     PythonPotential(PyObject * potential)
00053         :    _potential(potential)
00054     {
00055         Py_XINCREF(_potential);
00056 
00057         // import the the numpy array API.  This is commented because
00058         // it is now done in the cython module file.  It is possible I
00059         // need to define some preprocessor variables like
00060         // NO_IMPORT_ARRAY and PY_ARRAY_UNIQUE_SYMBOL, but the
00061         // documentation is a bit confusing.
00062         // http://docs.scipy.org/doc/numpy/reference/c-api.array.html
00063         //import_array();
00064     }
00065 
00066     virtual ~PythonPotential() 
00067     { 
00068         Py_XDECREF(_potential); 
00069     }
00070 
00074     virtual double get_energy(Array<double> x) 
00075     { 
00076         // create a numpy array from x
00077         // copy the data from x because becase the python object might
00078         // live longer than the data in x.data
00079         npy_intp N = (npy_intp) x.size();
00080         PyObject * numpyx = PyArray_SimpleNew(1, &N, NPY_DOUBLE);
00081         if (!numpyx){
00082             std::cerr << "created numpy object is NULL\n";
00083             throw std::runtime_error("created numpy object is NULL\n");
00084         }
00085         double * xdata = (double*) PyArray_DATA(numpyx);
00086         for (size_t i = 0; i < x.size(); ++i){
00087             xdata[i] = x[i];
00088         }
00089 
00090 
00091         // call the function getEnergy
00092         PyObject * name = PyString_FromString("getEnergy");
00093         PyObject * returnval = PyObject_CallMethodObjArgs(_potential, name, numpyx, NULL);
00094         Py_XDECREF(name); 
00095         Py_XDECREF(numpyx); 
00096         if (!returnval){
00097             //parse error
00098             throw std::runtime_error("getEnergy returned NULL");
00099         }
00100         //std::cout << "    done calling get energy\n";
00101 
00102         // parse the returned tuple
00103         double energy = PyFloat_AsDouble(returnval);
00104         Py_XDECREF(returnval);
00105         //TODO: error checking
00106         if (PyErr_Occurred()){
00107             PyErr_Clear();
00108             PyErr_SetString(PyExc_TypeError, "return value of getEnergy could not be converted to float");
00109             throw std::runtime_error("return value of getEnergy could not be converted to float");
00110         }
00111 
00112         return energy;
00113     }
00114 
00118     virtual double get_energy_gradient(Array<double> x, Array<double> grad)
00119     {
00120         if (x.size() != grad.size()) {
00121             throw std::invalid_argument("grad.size() be the same as x.size()");
00122         }
00123 
00124         // create a numpy array from x
00125         // copy the data from x because becase the python object might
00126         // live longer than the data in x.data
00127         npy_intp N = (npy_intp) x.size();
00128         PyObject * numpyx = PyArray_SimpleNew(1, &N, NPY_DOUBLE);
00129         if (!numpyx){
00130             std::cerr << "created numpy object is NULL\n";
00131             throw std::runtime_error("created numpy object is NULL\n");
00132         }
00133         double * numpyx_data = (double*) PyArray_DATA(numpyx);
00134         for (size_t i = 0; i < x.size(); ++i){
00135             numpyx_data[i] = x[i];
00136         }
00137         
00138         // call the function getEnergy
00139         PyObject * name = PyString_FromString("getEnergyGradient");
00140         PyObject * returnval = PyObject_CallMethodObjArgs(_potential, name, numpyx, NULL);
00141         Py_XDECREF(numpyx); 
00142         Py_XDECREF(name); 
00143         if (!returnval){
00144             //parse error
00145             throw std::runtime_error("getEnergyGradient return is NULL");
00146         }
00147 
00148         // parse the returned tuple into a doulbe and a numpy array
00149         double energy;
00150         PyObject * npgrad_returned; //the reference count for this does not need to be decreased
00151         if (!PyArg_ParseTuple(returnval, "dO", &energy, &npgrad_returned)){
00152             Py_XDECREF(returnval);
00153             throw std::runtime_error("failed to parse the tuple");
00154         }
00155 
00156         // convert the returned gradient into an array which I know I
00157         // can safely use as a double array.
00158         // note: NPY_CARRAY is for numpy version 1.6, for later version use NPY_ARRAY_CARRAY
00159         PyObject * npgrad_safe = PyArray_FromAny(npgrad_returned,
00160                 PyArray_DescrFromType(NPY_DOUBLE), 1, 1, NPY_CARRAY,
00161                 NULL);
00162         if (!npgrad_safe){
00163             Py_XDECREF(returnval);
00164             throw std::runtime_error("gradient returned by getEnergyGradient could not be converted to numpy double array");
00165         }
00166         // check the size of the gradient array
00167         if (static_cast<size_t>(PyArray_Size(npgrad_safe)) != grad.size()){
00168             PyErr_SetString(PyExc_IndexError, "gradient returned by getEnergyGradient has wrong size.");
00169             Py_XDECREF(returnval);
00170             Py_XDECREF(npgrad_safe);
00171             throw std::runtime_error("return value of getEnergy could not be converted to float");
00172         }
00173 
00174         //copy the gradient into grad
00175         double * gdata = (double*) PyArray_DATA(npgrad_safe);
00176         for (size_t i = 0; i < grad.size(); ++i){
00177             grad[i] = gdata[i];
00178         }
00179         //std::cout << "    done copying grad\n";
00180 
00181         // decrease referenece counts on Python objects
00182         Py_XDECREF(returnval);
00183         Py_XDECREF(npgrad_safe);
00184 
00185         return energy;
00186     }
00187 };
00188 }
00189 #endif
 All Classes Namespaces Functions Variables Typedefs