add runtime API to control multithreading
This commit is contained in:
parent
842b54fe80
commit
5b192930b6
@ -241,8 +241,8 @@ struct ei_gemm_functor
|
|||||||
|
|
||||||
Index sharedBlockBSize() const
|
Index sharedBlockBSize() const
|
||||||
{
|
{
|
||||||
int maxKc, maxMc;
|
Index maxKc, maxMc, maxNc;
|
||||||
getBlockingSizes<Scalar>(maxKc,maxMc);
|
getBlockingSizes<Scalar>(maxKc, maxMc, maxNc);
|
||||||
return std::min<Index>(maxKc,m_rhs.rows()) * m_rhs.cols();
|
return std::min<Index>(maxKc,m_rhs.rows()) * m_rhs.cols();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -25,6 +25,50 @@
|
|||||||
#ifndef EIGEN_PARALLELIZER_H
|
#ifndef EIGEN_PARALLELIZER_H
|
||||||
#define EIGEN_PARALLELIZER_H
|
#define EIGEN_PARALLELIZER_H
|
||||||
|
|
||||||
|
/** \internal */
|
||||||
|
inline void ei_manage_multi_threading(Action action, int* v)
|
||||||
|
{
|
||||||
|
static int m_maxThreads = -1;
|
||||||
|
|
||||||
|
if(action==SetAction)
|
||||||
|
{
|
||||||
|
ei_internal_assert(v!=0);
|
||||||
|
m_maxThreads = *v;
|
||||||
|
}
|
||||||
|
else if(action==GetAction)
|
||||||
|
{
|
||||||
|
ei_internal_assert(v!=0);
|
||||||
|
#ifdef EIGEN_HAS_OPENMP
|
||||||
|
if(m_maxThreads>0)
|
||||||
|
*v = m_maxThreads;
|
||||||
|
else
|
||||||
|
*v = omp_get_max_threads();
|
||||||
|
#else
|
||||||
|
*v = 1;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
ei_internal_assert(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** \returns the max number of threads reserved for Eigen
|
||||||
|
* \sa setNbThreads */
|
||||||
|
inline int nbThreads()
|
||||||
|
{
|
||||||
|
int ret;
|
||||||
|
ei_manage_multi_threading(GetAction, &ret);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Sets the max number of threads reserved for Eigen
|
||||||
|
* \sa nbThreads */
|
||||||
|
inline void setNbThreads(int v)
|
||||||
|
{
|
||||||
|
ei_manage_multi_threading(SetAction, &v);
|
||||||
|
}
|
||||||
|
|
||||||
template<typename BlockBScalar, typename Index> struct GemmParallelInfo
|
template<typename BlockBScalar, typename Index> struct GemmParallelInfo
|
||||||
{
|
{
|
||||||
GemmParallelInfo() : sync(-1), users(0), rhs_start(0), rhs_length(0), blockB(0) {}
|
GemmParallelInfo() : sync(-1), users(0), rhs_start(0), rhs_length(0), blockB(0) {}
|
||||||
@ -57,10 +101,10 @@ void ei_parallelize_gemm(const Functor& func, Index rows, Index cols)
|
|||||||
|
|
||||||
// 2- compute the maximal number of threads from the size of the product:
|
// 2- compute the maximal number of threads from the size of the product:
|
||||||
// FIXME this has to be fine tuned
|
// FIXME this has to be fine tuned
|
||||||
Index max_threads = std::max(1,rows / 32);
|
Index max_threads = std::max<Index>(1,rows / 32);
|
||||||
|
|
||||||
// 3 - compute the number of threads we are going to use
|
// 3 - compute the number of threads we are going to use
|
||||||
Index threads = std::min<Index>(omp_get_max_threads(), max_threads);
|
Index threads = std::min<Index>(nbThreads(), max_threads);
|
||||||
|
|
||||||
if(threads==1)
|
if(threads==1)
|
||||||
return func(0,rows, 0,cols);
|
return func(0,rows, 0,cols);
|
||||||
|
|||||||
@ -112,7 +112,8 @@ int main(int argc, char ** argv)
|
|||||||
if(procs>1)
|
if(procs>1)
|
||||||
{
|
{
|
||||||
BenchTimer tmono;
|
BenchTimer tmono;
|
||||||
omp_set_num_threads(1);
|
//omp_set_num_threads(1);
|
||||||
|
Eigen::setNbThreads(1);
|
||||||
BENCH(tmono, tries, rep, gemm(a,b,c));
|
BENCH(tmono, tries, rep, gemm(a,b,c));
|
||||||
std::cout << "eigen mono cpu " << tmono.best(CPU_TIMER)/rep << "s \t" << (double(m)*n*p*rep*2/tmono.best(CPU_TIMER))*1e-9 << " GFLOPS \t(" << tmono.total(CPU_TIMER) << "s)\n";
|
std::cout << "eigen mono cpu " << tmono.best(CPU_TIMER)/rep << "s \t" << (double(m)*n*p*rep*2/tmono.best(CPU_TIMER))*1e-9 << " GFLOPS \t(" << tmono.total(CPU_TIMER) << "s)\n";
|
||||||
std::cout << "eigen mono real " << tmono.best(REAL_TIMER)/rep << "s \t" << (double(m)*n*p*rep*2/tmono.best(REAL_TIMER))*1e-9 << " GFLOPS \t(" << tmono.total(REAL_TIMER) << "s)\n";
|
std::cout << "eigen mono real " << tmono.best(REAL_TIMER)/rep << "s \t" << (double(m)*n*p*rep*2/tmono.best(REAL_TIMER))*1e-9 << " GFLOPS \t(" << tmono.total(REAL_TIMER) << "s)\n";
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user