25#include "stencil_solver.hpp"
51template<
typename stencil_t,
typename stencil_container_t>
54 m_stencils(stencils), m_stride(stride)
63template<
typename stencil_t,
typename stencil_container_t>
74template<
typename stencil_t,
typename stencil_container_t>
77 return this->m_stencils->size();
86template<
typename stencil_t,
typename stencil_container_t>
89 return (*(this->m_stencils))[rowIndex];
98template<
typename stencil_t,
typename stencil_container_t>
101 return (*(this->m_stencils))[rowRawIndex];
111template<
typename stencil_t,
typename stencil_container_t>
114 return (*(this->m_stencils))[blockIndex * this->m_stride + componentIdx];
124template<
typename stencil_t,
typename stencil_container_t>
127 return (*(this->m_stencils))[blockRawIndex * this->m_stride + componentIdx];
135template<
typename stencil_t>
146template<
typename stencil_t>
149 return this->m_stencils->getKernel()->
size() * this->m_stride;
158template<
typename stencil_t>
161 return this->m_stencils->
at(rowIndex);
170template<
typename stencil_t>
173 return this->m_stencils->
rawAt(rowRawIndex);
183template<
typename stencil_t>
186 return this->m_stencils->
at(blockIndex, componentIdx);
196template<
typename stencil_t>
199 return this->m_stencils->
rawAt(blockRawIndex, componentIdx);
210#if BITPIT_ENABLE_MPI==1
218template<
typename stencil_t,
typename solver_kernel_t>
219template<
typename stencil_container_t,
typename... AssemblerKernelArgs>
221 AssemblerKernelArgs&&... assemblerKernelArgs)
235template<
typename stencil_t,
typename solver_kernel_t>
236template<
typename stencil_container_t,
typename... AssemblerKernelArgs>
238 const stencil_container_t *stencils,
239 AssemblerKernelArgs&&... assemblerKernelArgs)
242 std::forward<AssemblerKernelArgs>(assemblerKernelArgs)...)
253template<
typename stencil_t,
typename solver_kernel_t>
254template<
typename stencil_container_t,
typename... AssemblerKernelArgs>
256 AssemblerKernelArgs&&... assemblerKernelArgs)
258 std::forward<AssemblerKernelArgs>(assemblerKernelArgs)...)
263#if BITPIT_ENABLE_MPI==1
271template<
typename stencil_t,
typename solver_kernel_t>
272template<
typename... AssemblerKernelArgs>
274 AssemblerKernelArgs&&... assemblerKernelArgs)
288template<
typename stencil_t,
typename solver_kernel_t>
289template<
typename... AssemblerKernelArgs>
292 AssemblerKernelArgs&&... assemblerKernelArgs)
302template<typename stencil_t, typename solver_kernel_t>
303template<typename... AssemblerKernelArgs>
305 AssemblerKernelArgs&&... assemblerKernelArgs)
315#if BITPIT_ENABLE_MPI==1
322template<
typename stencil_t,
typename solver_kernel_t>
323template<
typename... AssemblerKernelArgs>
337template<
typename stencil_t,
typename solver_kernel_t>
338template<
typename... AssemblerKernelArgs>
340 AssemblerKernelArgs&&... assemblerKernelArgs)
341 : solver_kernel_type::Assembler(std::forward<AssemblerKernelArgs>(assemblerKernelArgs)...),
342 m_partitioned(partitioned), m_communicator(communicator)
352template<
typename stencil_t,
typename solver_kernel_t>
353template<
typename... AssemblerKernelArgs>
355 : solver_kernel_type::Assembler(std::forward<AssemblerKernelArgs>(assemblerKernelArgs)...)
360#if BITPIT_ENABLE_MPI==1
366template<
typename stencil_t,
typename solver_kernel_t>
369 return m_partitioned;
377template<
typename stencil_t,
typename solver_kernel_t>
380 return m_communicator;
389template<
typename stencil_t,
typename solver_kernel_t>
392 assembly_options_type options;
394 options.sorted =
false;
402template<
typename stencil_t,
typename solver_kernel_t>
403template<typename W, typename V, typename std::enable_if<std::is_fundamental<W>::value>::type *>
412template<
typename stencil_t,
typename solver_kernel_t>
413template<
typename W,
typename V, std::
size_t D,
typename std::enable_if<std::is_same<std::array<V, D>, W>::value>::type *>
416 setBlockSize(
sizeof(
typename StencilVector::weight_type) /
sizeof(
typename StencilVector::weight_type::value_type));
431template<
typename stencil_t,
typename solver_kernel_t>
432template<
typename W,
typename V,
typename std::enable_if<std::is_same<std::vector<V>, W>::value>::type *>
443 for (
long i = 0; i < getRowCount(); ++i) {
445 std::size_t stencilSize = stencil.
size();
446 if (stencilSize == 0) {
450 const StencilBlock::weight_type *weightData = stencil.
weightData();
451 double blockSize = std::sqrt(weightData[0].size());
452 if (blockSize != std::sqrt(weightData[0].size())) {
453 throw std::runtime_error(
"Weights size should be a square.");
455 setBlockSize(
static_cast<int>(blockSize));
459 if (m_blockSize == -1) {
460 throw std::runtime_error(
"All weights should have a size greater than zero.");
467 for (
long i = 0; i < getRowCount(); ++i) {
468 const StencilBlock &stencil = getRowStencil(i);
469 const StencilBlock::weight_type *weightData = stencil.weightData();
470 std::size_t stencilSize = stencil.size();
472 for (std::size_t k = 0; k < stencilSize; ++k) {
473 if (weightData[k].size() != m_blockSize)) {
474 throw std::runtime_error(
"All stencils weights should have the same size.");
478 if (stencil.getConstant().size() != m_blockSize) {
479 throw std::runtime_error(
"The stencils constant should have the same size of the stencil weights.");
490template<
typename stencil_t,
typename solver_kernel_t>
493 m_blockSize = blockSize;
501template<
typename stencil_t,
typename solver_kernel_t>
504 m_stencils = std::move(stencils);
510template<
typename stencil_t,
typename solver_kernel_t>
513 setMatrixSizes(m_stencils->size(), m_stencils->size());
522template<
typename stencil_t,
typename solver_kernel_t>
529#if BITPIT_ENABLE_MPI==1
531 m_nGlobalRows = nRows;
532 m_nGlobalCols = nCols;
534 MPI_Allreduce(MPI_IN_PLACE, &m_nGlobalRows, 1, MPI_LONG, MPI_SUM, m_communicator);
535 MPI_Allreduce(MPI_IN_PLACE, &m_nGlobalCols, 1, MPI_LONG, MPI_SUM, m_communicator);
539 m_globalRowOffset = 0;
540 m_globalColOffset = 0;
543 MPI_Comm_size(m_communicator, &nProcessors);
545 std::vector<long> nRankRows(nProcessors);
546 MPI_Allgather(&m_nRows, 1, MPI_LONG, nRankRows.data(), 1, MPI_LONG, m_communicator);
548 std::vector<long> nRankCols(nProcessors);
549 MPI_Allgather(&m_nCols, 1, MPI_LONG, nRankCols.data(), 1, MPI_LONG, m_communicator);
552 MPI_Comm_rank(m_communicator, &rank);
553 for (
int i = 0; i < rank; ++i) {
554 m_globalRowOffset += nRankRows[i];
555 m_globalColOffset += nRankCols[i];
564template<
typename stencil_t,
typename solver_kernel_t>
568 for (
long n = 0; n < getRowCount(); ++n) {
569 maxRowNZ = std::max(getRowNZCount(n), maxRowNZ);
572 setMaximumRowNZ(maxRowNZ);
580template<
typename stencil_t,
typename solver_kernel_t>
583 m_maxRowNZ = maxRowNZ;
591template<
typename stencil_t,
typename solver_kernel_t>
606template<
typename stencil_t,
typename solver_kernel_t>
621template<
typename stencil_t,
typename solver_kernel_t>
635template<
typename stencil_t,
typename solver_kernel_t>
638 long nRowElements = getBlockSize() * getRowCount();
651template<
typename stencil_t,
typename solver_kernel_t>
654 long nColElements = getBlockSize() * getColCount();
659#if BITPIT_ENABLE_MPI==1
669template<
typename stencil_t,
typename solver_kernel_t>
672 return m_nGlobalRows;
684template<
typename stencil_t,
typename solver_kernel_t>
687 return m_nGlobalCols;
698template<
typename stencil_t,
typename solver_kernel_t>
701 long nElements = getBlockSize() * getRowGlobalCount();
714template<
typename stencil_t,
typename solver_kernel_t>
717 long nElements = getBlockSize() * getColGlobalCount();
727template<
typename stencil_t,
typename solver_kernel_t>
730 return m_globalRowOffset;
738template<
typename stencil_t,
typename solver_kernel_t>
741 return m_globalColOffset;
752template<
typename stencil_t,
typename solver_kernel_t>
755 long offset = getBlockSize() * getRowGlobalOffset();
768template<
typename stencil_t,
typename solver_kernel_t>
771 long offset = getBlockSize() * getColGlobalOffset();
783template<
typename stencil_t,
typename solver_kernel_t>
786 const stencil_t &stencil = getRowStencil(rowIndex);
787 std::size_t stencilSize = stencil.size();
797template<
typename stencil_t,
typename solver_kernel_t>
814template<
typename stencil_t,
typename solver_kernel_t>
818 const stencil_t &stencil = getRowStencil(rowIndex);
821 getPattern(stencil, pattern);
830template<
typename stencil_t,
typename solver_kernel_t>
833 std::size_t stencilSize = stencil.size();
835 const long *patternData = stencil.patternData();
836 pattern->
set(patternData, stencilSize);
852template<
typename stencil_t,
typename solver_kernel_t>
856 const stencil_t &stencil = getRowStencil(rowIndex);
859 getValues(stencil, values);
870template<
typename stencil_t,
typename solver_kernel_t>
871template<typename U, typename std::enable_if<std::is_fundamental<U>::value>::type *>
874 values->
set(stencil.weightData(), m_blockSize * stencil.size());
885template<
typename stencil_t,
typename solver_kernel_t>
886template<typename U, typename std::enable_if<!std::is_fundamental<U>::value>::type *>
889 std::size_t stencilSize = stencil.size();
890 const stencil_weight_type *stencilWeightData = stencil.weightData();
892 int nBlockElements = m_blockSize * m_blockSize;
893 std::size_t nRowValues = m_blockSize * stencilSize;
894 std::size_t nValues = nBlockElements * stencilSize;
899 for (std::size_t k = 0; k < stencilSize; ++k) {
900 const double *weightData = stencilWeightData[k].data();
901 for (
int i = 0; i < m_blockSize; ++i) {
905 std::copy_n(weightData + weightOffset, m_blockSize, expandedValuesStorage + valuesOffset);
919template<
typename stencil_t,
typename solver_kernel_t>
923 const stencil_t &stencil = getRowStencil(rowIndex);
926 getPattern(stencil, pattern);
929 getValues(stencil, values);
940template<
typename stencil_t,
typename solver_kernel_t>
944 const stencil_t &stencil = getRowStencil(rowIndex);
947 getConstant(stencil, constant);
958template<
typename stencil_t,
typename solver_kernel_t>
959template<typename U, typename std::enable_if<std::is_fundamental<U>::value>::type *>
962 const stencil_weight_type &stencilConstant = stencil.getConstant();
966 std::copy_n(&stencilConstant, m_blockSize, constantStorage);
977template<
typename stencil_t,
typename solver_kernel_t>
978template<typename U, typename std::enable_if<!std::is_fundamental<U>::value>::type *>
981 const stencil_weight_type &stencilConstant = stencil.getConstant();
985 for (
int i = 0; i < m_blockSize; ++i) {
988 constantStorage[i] = 0;
989 for (
int j = 0; j < m_blockSize; ++j) {
990 constantStorage[i] += stencil.getWeightManager().
at(stencilConstant, offset_i + j);
1001template<
typename stencil_t,
typename solver_kernel_t>
1004 return m_stencils->at(rowIndex);
1018template<
typename stencil_t,
typename solver_kernel_t>
1021 solver_kernel_t::clear();
1023 std::vector<double>().swap(m_constants);
1026#if BITPIT_ENABLE_MPI==1
1034template<
typename stencil_t,
typename solver_kernel_t>
1035template<
typename stencil_container_t>
1038 assembly(MPI_COMM_SELF,
false, stencils);
1048template<
typename stencil_t,
typename solver_kernel_t>
1049template<
typename stencil_container_t>
1057template<
typename stencil_t,
typename solver_kernel_t>
1058template<
typename stencil_container_t>
1063#if BITPIT_ENABLE_MPI==1
1070 solver_kernel_t::template assembly<DiscretizationStencilSolver<stencil_t, solver_kernel_t>>(assembler,
NaturalSystemMatrixOrdering());
1081template<
typename stencil_t,
typename solver_kernel_t>
1084 solver_kernel_t::template assembly<DiscretizationStencilSolver<stencil_t, solver_kernel_t>>(assembler,
NaturalSystemMatrixOrdering());
1092template<
typename stencil_t,
typename solver_kernel_t>
1096 solver_kernel_t::matrixAssembly(assembler);
1099 assembleConstants(assembler);
1118template<
typename stencil_t,
typename solver_kernel_t>
1122 solver_kernel_t::matrixUpdate(nRows, rows, assembler);
1125 updateConstants(nRows, rows, assembler);
1136template<
typename stencil_t,
typename solver_kernel_t>
1137template<
typename stencil_container_t>
1140 update(this->getRowCount(),
nullptr, stencils);
1152template<
typename stencil_t,
typename solver_kernel_t>
1153template<
typename stencil_container_t>
1156 update(rows.size(), rows.data(), stencils);
1171template<
typename stencil_t,
typename solver_kernel_t>
1172template<
typename stencil_container_t>
1175#if BITPIT_ENABLE_MPI==1
1182 solver_kernel_t::template update<DiscretizationStencilSolver<stencil_t, solver_kernel_t>>(nRows, rows, assembler);
1196template<
typename stencil_t,
typename solver_kernel_t>
1199 solver_kernel_t::template update<DiscretizationStencilSolver<stencil_t, solver_kernel_t>>(nRows, rows, assembler);
1207template<
typename stencil_t,
typename solver_kernel_t>
1214 m_constants.resize(nRows * blockSize);
1217 updateConstants(nRows,
nullptr, assembler);
1229template<
typename stencil_t,
typename solver_kernel_t>
1235 for (std::size_t n = 0; n < nRows; ++n) {
1244 std::copy_n(rowConstant.
data(), blockSize, m_constants.data() + row * blockSize);
1251template<
typename stencil_t,
typename solver_kernel_t>
1255 if (!this->isAssembled()) {
1256 throw std::runtime_error(
"Unable to solve the system. The stencil solver is not yet assembled.");
1260 long nUnknowns = this->getBlockSize() * this->getRowCount();
1261 double *raw_rhs = this->getRHSRawPtr();
1262 for (
long i = 0; i < nUnknowns; ++i) {
1263 raw_rhs[i] -= m_constants[i];
1265 this->restoreRHSRawPtr(raw_rhs);
1268 solver_kernel_t::solve();
Metafunction for generating a discretization stencil.
The DiscretizationStencilProxyBaseStorage class defines a proxy for stencil storage.
DiscretizationStencilProxyBaseStorage(const stencil_container_t *stencils, int stride)
const stencil_t & rawAt(std::size_t rowRawIndex) const override
const stencil_t & at(long rowIndex) const override
std::size_t size() const override
DiscretizationStencilProxyStorage(const stencil_container_t *stencils)
The DiscretizationStencilSolverAssembler class defines an assembler for building the stencil solver.
long getRowNZCount(long rowIndex) const override
virtual void getRowConstant(long rowIndex, bitpit::ConstProxyVector< double > *constant) const
void getRowValues(long rowIndex, ConstProxyVector< double > *values) const override
void getRowPattern(long rowIndex, ConstProxyVector< long > *pattern) const override
long getRowGlobalCount() const override
void setStencils(std::unique_ptr< DiscretizationStencilStorageInterface< stencil_t > > &&stencils)
long getColElementCount() const override
long getColGlobalElementOffset() const override
int getBlockSize() const override
long getMaxRowNZCount() const override
DiscretizationStencilSolverAssembler(const stencil_container_t *stencils, AssemblerKernelArgs &&... assemblerKernelArgs)
long getRowGlobalElementCount() const override
long getRowElementCount() const override
void getValues(const stencil_t &stencil, ConstProxyVector< double > *values) const
long getColCount() const override
void getRowData(long rowIndex, ConstProxyVector< long > *pattern, ConstProxyVector< double > *values) const override
long getRowGlobalOffset() const override
long getRowGlobalElementOffset() const override
long getColGlobalOffset() const override
long getRowCount() const override
const MPI_Comm & getCommunicator() const override
long getColGlobalElementCount() const override
void getPattern(const stencil_t &stencil, ConstProxyVector< long > *pattern) const
assembly_options_type getOptions() const override
virtual const stencil_t & getRowStencil(long rowIndex) const
bool isPartitioned() const override
void getConstant(const stencil_t &stencil, bitpit::ConstProxyVector< double > *constant) const
long getColGlobalCount() const override
void matrixAssembly(const Assembler &assembler)
void assembleConstants(const Assembler &assembler)
void matrixUpdate(long nRows, const long *rows, const Assembler &assembler)
void assembly(const stencil_container_t &stencils)
void update(const stencil_container_t &stencils)
void updateConstants(std::size_t nRows, const long *rows, const Assembler &assembler)
The DiscretizationStencilStorageInterface class defines the interface for stencil storage.
The NaturalSystemMatrixOrdering class allows the use of a natural matrix ordering.
Metafunction for generating a pierced storage.
Metafunction for generating a list of elements that can be either stored in an external vector or,...
container_type::pointer storage_pointer
__PXV_POINTER__ data() noexcept
__PXV_STORAGE_POINTER__ storedData() noexcept
__PXV_REFERENCE__ at(std::size_t n)
void set(__PXV_POINTER__ data, std::size_t size)
int linearIndexRowMajor(int row, int col, int nRows, int nCols)