pele
Python energy landscape explorer
00001 #ifndef _PELE_MATRIX_H_ 00002 #define _PELE_MATRIX_H_ 00003 00004 namespace pele{ 00005 00012 template<class dtype> 00013 class MatrixAdapter : public pele::Array<dtype> { 00014 public: 00021 size_t _dim2; 00022 00026 MatrixAdapter(size_t dim1, size_t dim2, dtype val=0) 00027 : pele::Array<dtype>(dim1 * dim2, val), 00028 _dim2(dim2) 00029 {} 00030 00036 MatrixAdapter(pele::Array<double> v, size_t dim2) 00037 : pele::Array<dtype>(v), 00038 _dim2(dim2) 00039 { 00040 if (v.size() % dim2 != 0) { 00041 throw std::invalid_argument("v.size() is not divisible by dim2"); 00042 } 00043 } 00044 00048 MatrixAdapter(double * data, size_t dim1, size_t dim2) 00049 : pele::Array<dtype>(data, dim1*dim2), 00050 _dim2(dim2) 00051 {} 00052 00056 inline dtype const & operator()(size_t i, size_t j) const 00057 { 00058 return this->operator[](i * _dim2 + j); 00059 } 00060 inline dtype & operator()(size_t i, size_t j) 00061 { 00062 return this->operator[](i * _dim2 + j); 00063 } 00064 00068 inline std::pair<size_t, size_t> shape() const 00069 { 00070 return std::pair<size_t, size_t>(this->size() / _dim2, _dim2); 00071 } 00072 }; 00073 00079 template<class dtype> 00080 MatrixAdapter<dtype> hacky_mat_mul(MatrixAdapter<dtype> const & A, MatrixAdapter<dtype> const & B) 00081 { 00082 assert(A.shape().second == B.shape().first); 00083 size_t const L = A.shape().second; 00084 size_t const N = A.shape().first; 00085 size_t const M = B.shape().second; 00086 00087 MatrixAdapter<dtype> C(N, M, 0); 00088 for (size_t i = 0; i<N; ++i){ 00089 for (size_t j = 0; j<M; ++j){ 00090 double val = 0; 00091 for (size_t k = 0; k<L; ++k){ 00092 val += A(i,k) * B(k,j); 00093 } 00094 C(i,j) = val; 00095 } 00096 } 00097 return C; 00098 } 00099 00101 // * multiply a matrix times an vector 00102 // */ 00103 //template<class dtype> 00104 //pele::Array<dtype> hacky_mat_mul(MatrixAdapter<dtype> const & A, pele::Array<dtype> const & v) 00105 //{ 00106 // assert(A.shape().second == v.size()); 00107 // size_t const L 
= A.shape().second; 00108 // size_t const n = A.shape().first; 00109 // 00110 // pele::Array<dtype> C(n, 0); 00111 // for (size_t i = 0; i<n; ++i){ 00112 // dtype val = 0; 00113 // for (size_t k = 0; k<L; ++k){ 00114 // val += A(i,k) * v[k]; 00115 // } 00116 // C(i) = val; 00117 // } 00118 // return C; 00119 //} 00120 00121 // for matrix printing 00122 template<class dtype> 00123 std::ostream &operator<<(std::ostream &out, const pele::MatrixAdapter<dtype> &a) { 00124 out << "[ "; 00125 size_t const N = a.shape().first; 00126 size_t const M = a.shape().second; 00127 for(size_t n=0; n<N;++n) { 00128 for(size_t m=0; m<M;++m) { 00129 if(m>0) out << ", "; 00130 out << a(n,m); 00131 } 00132 if (n < N-1) out << ",\n "; 00133 } 00134 out << " ]"; 00135 return out; 00136 } 00137 00138 00139 } 00140 #endif