pele
Python energy landscape explorer
/home/js850/projects/pele/source/pele/matrix.h
00001 #ifndef _PELE_MATRIX_H_
00002 #define _PELE_MATRIX_H_
00003 
00004 namespace pele{
00005 
00012 template<class dtype>
00013 class MatrixAdapter : public pele::Array<dtype> {
00014 public:
00021     size_t _dim2;
00022 
00026     MatrixAdapter(size_t dim1, size_t dim2, dtype val=0)
00027         : pele::Array<dtype>(dim1 * dim2, val),
00028           _dim2(dim2)
00029     {}
00030 
00036     MatrixAdapter(pele::Array<double> v, size_t dim2)
00037         : pele::Array<dtype>(v),
00038           _dim2(dim2)
00039     {
00040         if (v.size() % dim2 != 0) {
00041             throw std::invalid_argument("v.size() is not divisible by dim2");
00042         }
00043     }
00044 
00048     MatrixAdapter(double * data, size_t dim1, size_t dim2)
00049         : pele::Array<dtype>(data, dim1*dim2),
00050           _dim2(dim2)
00051     {}
00052 
00056     inline dtype const & operator()(size_t i, size_t j) const
00057     {
00058         return this->operator[](i * _dim2 + j);
00059     }
00060     inline dtype & operator()(size_t i, size_t j)
00061     {
00062         return this->operator[](i * _dim2 + j);
00063     }
00064 
00068     inline std::pair<size_t, size_t> shape() const
00069     {
00070         return std::pair<size_t, size_t>(this->size() / _dim2, _dim2);
00071     }
00072 };
00073 
00079 template<class dtype>
00080 MatrixAdapter<dtype> hacky_mat_mul(MatrixAdapter<dtype> const & A, MatrixAdapter<dtype> const & B)
00081 {
00082     assert(A.shape().second == B.shape().first);
00083     size_t const L = A.shape().second;
00084     size_t const N = A.shape().first;
00085     size_t const M = B.shape().second;
00086 
00087     MatrixAdapter<dtype> C(N, M, 0);
00088     for (size_t i = 0; i<N; ++i){
00089         for (size_t j = 0; j<M; ++j){
00090             double val = 0;
00091             for (size_t k = 0; k<L; ++k){
00092                 val += A(i,k) * B(k,j);
00093             }
00094             C(i,j) = val;
00095         }
00096     }
00097     return C;
00098 }
00099 
00101 // * multiply a matrix by a vector
00102 // */
00103 //template<class dtype>
00104 //pele::Array<dtype> hacky_mat_mul(MatrixAdapter<dtype> const & A, pele::Array<dtype> const & v)
00105 //{
00106 //    assert(A.shape().second == v.size());
00107 //    size_t const L = A.shape().second;
00108 //    size_t const n = A.shape().first;
00109 //
00110 //    pele::Array<dtype> C(n, 0);
00111 //    for (size_t i = 0; i<n; ++i){
00112 //        dtype val = 0;
00113 //        for (size_t k = 0; k<L; ++k){
00114 //            val += A(i,k) * v[k];
00115 //        }
00116 //        C(i) = val;
00117 //    }
00118 //    return C;
00119 //}
00120 
00121 // for matrix printing
00122 template<class dtype>
00123 std::ostream &operator<<(std::ostream &out, const pele::MatrixAdapter<dtype> &a) {
00124     out << "[ ";
00125     size_t const N = a.shape().first;
00126     size_t const M = a.shape().second;
00127     for(size_t n=0; n<N;++n) {
00128         for(size_t m=0; m<M;++m) {
00129             if(m>0) out << ", ";
00130             out << a(n,m);
00131         }
00132         if (n < N-1) out << ",\n  ";
00133     }
00134     out << " ]";
00135     return out;
00136 }
00137 
00138 
00139 }
00140 #endif
 All Classes Namespaces Functions Variables Typedefs