Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 7 additions & 32 deletions inst/include/CoDA/glr.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using Eigen::VectorXd;
using Eigen::Map;
using Eigen::MatrixBase;
using Eigen::SparseMatrix;
using Eigen::EigenBase;

// Convention here is that samples are columns and parts are rows
// This is notably different than how Driver was originally implemented. Sorry
Expand All @@ -23,31 +24,19 @@ namespace coda {
//'
//' Calculates Y = GLR(X) or Y=GLRINV(X) defined by
//' @name glr
//'

template <typename TX, typename TV>
Eigen::MatrixXd glr(Eigen::MatrixBase<TX>& X, Eigen::MatrixBase<TV>& V){ // was MatrixBase and SparseMatrixBase
int P = V.rows();
int D = V.cols();
if (X.rows() != D) throw std::invalid_argument("X.rows() != V.cols()");

MatrixXd Y = X.array().log().matrix();
return V*Y;
}

//' @rdname glr
template <typename TX, typename TV>
Eigen::MatrixXd glr(Eigen::MatrixBase<TX>& X, Eigen::SparseMatrixBase<TV>& V){ // was MatrixBase and SparseMatrixBase
Eigen::MatrixXd glr(Eigen::MatrixBase<TX>& X, const TV& V){
int P = V.rows();
int D = V.cols();
if (X.rows() != D) throw std::invalid_argument("X.rows() != V.cols()");

MatrixXd Y = X.array().log().matrix();
return V*Y;
return V*(X.array().log().matrix());
}


//' @rdname glr
template <typename TX, typename TV>
Eigen::MatrixXd glrInv(Eigen::MatrixBase<TX>& X, Eigen::MatrixBase<TV>& V){
Eigen::MatrixXd glrInv(Eigen::MatrixBase<TX>& X, const TV& V){
int P = V.rows();
int D = V.cols();
if (X.rows() != P) throw std::invalid_argument("X.rows() != V.rows()");
Expand All @@ -57,21 +46,7 @@ namespace coda {
O = O.array().exp().matrix();
return clo(O);
}


//' @rdname glr
template <typename TX, typename TV>
Eigen::MatrixXd glrInv(Eigen::MatrixBase<TX>& X, Eigen::SparseMatrixBase<TV>& V){
int P = V.rows();
int D = V.cols();
if (X.rows() != P) throw std::invalid_argument("X.rows() != V.cols()");

MatrixXd O;
O.noalias() = V.transpose()*X;
O = O.array().exp().matrix();
return clo(O);
}


} /* End coda Namespace */


Expand Down