pele
Python energy landscape explorer
|
00001 #ifndef _PELE_POTENTIAL_FUNCTION_H 00002 #define _PELE_POTENTIAL_FUNCTION_H 00003 #include "array.h" 00004 #include "base_potential.h" 00005 #include <Python.h> 00006 #include <numpy/arrayobject.h> 00007 #include <iostream> 00008 #include <stdexcept> 00009 00010 //#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 00011 00012 namespace pele { 00013 00021 /* 00022 class PotentialFunction : public BasePotential 00023 { 00024 public: 00025 typedef double EnergyCallback(Array<double>, void *); 00026 typedef double EnergyGradientCallback(Array<double>, Array<double>, void *); 00027 00028 PotentialFunction(EnergyCallback *get_energy, EnergyGradientCallback *get_energy_gradient, void *userdata) 00029 : _get_energy(get_energy), _get_energy_gradient(get_energy_gradient), _userdata(userdata) {} 00030 00031 virtual double get_energy(Array<double> x) { return (*_get_energy)(x, _userdata); } ; 00032 virtual double get_energy_gradient(Array<double> x, Array<double> grad) { return (*_get_energy_gradient)(x, grad, _userdata); } 00033 00034 private: 00035 EnergyCallback *_get_energy; 00036 EnergyGradientCallback *_get_energy_gradient; 00037 void *_userdata; 00038 }; 00039 */ 00040 00041 00047 class PythonPotential : public BasePotential 00048 { 00049 PyObject * _potential; 00050 00051 public: 00052 PythonPotential(PyObject * potential) 00053 : _potential(potential) 00054 { 00055 Py_XINCREF(_potential); 00056 00057 // import the the numpy array API. This is commented because 00058 // it is now done in the cython module file. It is possible I 00059 // need to define some preprocessor variables like 00060 // NO_IMPORT_ARRAY and PY_ARRAY_UNIQUE_SYMBOL, but the 00061 // documentation is a bit confusing. 00062 // http://docs.scipy.org/doc/numpy/reference/c-api.array.html 00063 //import_array(); 00064 } 00065 00066 virtual ~PythonPotential() 00067 { 00068 Py_XDECREF(_potential); 00069 } 00070 00074 virtual double get_energy(Array<double> x) 00075 { 00076 // create a numpy array from x 00077 // copy the data from x because becase the python object might 00078 // live longer than the data in x.data 00079 npy_intp N = (npy_intp) x.size(); 00080 PyObject * numpyx = PyArray_SimpleNew(1, &N, NPY_DOUBLE); 00081 if (!numpyx){ 00082 std::cerr << "created numpy object is NULL\n"; 00083 throw std::runtime_error("created numpy object is NULL\n"); 00084 } 00085 double * xdata = (double*) PyArray_DATA(numpyx); 00086 for (size_t i = 0; i < x.size(); ++i){ 00087 xdata[i] = x[i]; 00088 } 00089 00090 00091 // call the function getEnergy 00092 PyObject * name = PyString_FromString("getEnergy"); 00093 PyObject * returnval = PyObject_CallMethodObjArgs(_potential, name, numpyx, NULL); 00094 Py_XDECREF(name); 00095 Py_XDECREF(numpyx); 00096 if (!returnval){ 00097 //parse error 00098 throw std::runtime_error("getEnergy returned NULL"); 00099 } 00100 //std::cout << " done calling get energy\n"; 00101 00102 // parse the returned tuple 00103 double energy = PyFloat_AsDouble(returnval); 00104 Py_XDECREF(returnval); 00105 //TODO: error checking 00106 if (PyErr_Occurred()){ 00107 PyErr_Clear(); 00108 PyErr_SetString(PyExc_TypeError, "return value of getEnergy could not be converted to float"); 00109 throw std::runtime_error("return value of getEnergy could not be converted to float"); 00110 } 00111 00112 return energy; 00113 } 00114 00118 virtual double get_energy_gradient(Array<double> x, Array<double> grad) 00119 { 00120 if (x.size() != grad.size()) { 00121 throw std::invalid_argument("grad.size() be the same as x.size()"); 00122 } 00123 00124 // create a numpy array from x 00125 // copy the data from x because becase the python object might 00126 // live longer than the data in x.data 00127 npy_intp N = (npy_intp) x.size(); 00128 PyObject * numpyx = PyArray_SimpleNew(1, &N, NPY_DOUBLE); 00129 if (!numpyx){ 00130 std::cerr << "created numpy object is NULL\n"; 00131 throw std::runtime_error("created numpy object is NULL\n"); 00132 } 00133 double * numpyx_data = (double*) PyArray_DATA(numpyx); 00134 for (size_t i = 0; i < x.size(); ++i){ 00135 numpyx_data[i] = x[i]; 00136 } 00137 00138 // call the function getEnergy 00139 PyObject * name = PyString_FromString("getEnergyGradient"); 00140 PyObject * returnval = PyObject_CallMethodObjArgs(_potential, name, numpyx, NULL); 00141 Py_XDECREF(numpyx); 00142 Py_XDECREF(name); 00143 if (!returnval){ 00144 //parse error 00145 throw std::runtime_error("getEnergyGradient return is NULL"); 00146 } 00147 00148 // parse the returned tuple into a doulbe and a numpy array 00149 double energy; 00150 PyObject * npgrad_returned; //the reference count for this does not need to be decreased 00151 if (!PyArg_ParseTuple(returnval, "dO", &energy, &npgrad_returned)){ 00152 Py_XDECREF(returnval); 00153 throw std::runtime_error("failed to parse the tuple"); 00154 } 00155 00156 // convert the returned gradient into an array which I know I 00157 // can safely use as a double array. 00158 // note: NPY_CARRAY is for numpy version 1.6, for later version use NPY_ARRAY_CARRAY 00159 PyObject * npgrad_safe = PyArray_FromAny(npgrad_returned, 00160 PyArray_DescrFromType(NPY_DOUBLE), 1, 1, NPY_CARRAY, 00161 NULL); 00162 if (!npgrad_safe){ 00163 Py_XDECREF(returnval); 00164 throw std::runtime_error("gradient returned by getEnergyGradient could not be converted to numpy double array"); 00165 } 00166 // check the size of the gradient array 00167 if (static_cast<size_t>(PyArray_Size(npgrad_safe)) != grad.size()){ 00168 PyErr_SetString(PyExc_IndexError, "gradient returned by getEnergyGradient has wrong size."); 00169 Py_XDECREF(returnval); 00170 Py_XDECREF(npgrad_safe); 00171 throw std::runtime_error("return value of getEnergy could not be converted to float"); 00172 } 00173 00174 //copy the gradient into grad 00175 double * gdata = (double*) PyArray_DATA(npgrad_safe); 00176 for (size_t i = 0; i < grad.size(); ++i){ 00177 grad[i] = gdata[i]; 00178 } 00179 //std::cout << " done copying grad\n"; 00180 00181 // decrease referenece counts on Python objects 00182 Py_XDECREF(returnval); 00183 Py_XDECREF(npgrad_safe); 00184 00185 return energy; 00186 } 00187 }; 00188 } 00189 #endif