Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 101 additions & 35 deletions src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <functional>
#include <limits>
#include <type_traits>

#include "storm/adapters/RationalNumberForward.h"
#include "storm/solver/IterativeMinMaxLinearEquationSolver.h"
#include "storm/solver/LinearEquationSolver.h"
#include "storm/solver/OptimizationDirection.h"
Expand Down Expand Up @@ -116,24 +118,44 @@ bool IterativeMinMaxLinearEquationSolver<ValueType, SolutionType>::internalSolve

template<typename ValueType, typename SolutionType>
void IterativeMinMaxLinearEquationSolver<ValueType, SolutionType>::setUpViOperator() const {
if (!viOperator) {
viOperator = std::make_shared<helper::ValueIterationOperator<ValueType, std::is_same_v<ValueType, storm::Interval>, SolutionType>>();
viOperator->setMatrixBackwards(*this->A);
if (!viOperatorTriv && !viOperatorNontriv) {
if (this->A->hasTrivialRowGrouping()) {
// The trivial row grouping minmax operator makes sense over intervals.
viOperatorTriv = std::make_shared<helper::ValueIterationOperator<ValueType, true, SolutionType>>();
viOperatorTriv->setMatrixBackwards(*this->A);
if constexpr (!std::is_same_v<ValueType, storm::Interval>) {
// It might be that someone is using a minmaxlinearequationsolver with an advanced VI algorithm
// but is just passing a DTMC over doubles. In this case we need to populate this VI operator.
// It behaves exactly the same as the trivial row grouping operator, but it is currently hardcoded
// to be used by, e.g., optimistic value iteration.
viOperatorNontriv = std::make_shared<helper::ValueIterationOperator<ValueType, false, SolutionType>>();
viOperatorNontriv->setMatrixBackwards(*this->A);
}
} else {
// The nontrivial row grouping minmax operator makes sense for MDPs.
viOperatorNontriv = std::make_shared<helper::ValueIterationOperator<ValueType, false, SolutionType>>();
viOperatorNontriv->setMatrixBackwards(*this->A);
}
}
if (this->choiceFixedForRowGroup) {
// Ignore those rows that are not selected
assert(this->initialScheduler);
auto callback = [&](uint64_t groupIndex, uint64_t localRowIndex) {
return this->choiceFixedForRowGroup->get(groupIndex) && this->initialScheduler->at(groupIndex) != localRowIndex;
};
viOperator->setIgnoredRows(true, callback);
if (viOperatorTriv) {
viOperatorTriv->setIgnoredRows(true, callback);
}
if (viOperatorNontriv) {
viOperatorNontriv->setIgnoredRows(true, callback);
}
}
}

template<typename ValueType, typename SolutionType>
void IterativeMinMaxLinearEquationSolver<ValueType, SolutionType>::extractScheduler(std::vector<SolutionType>& x, std::vector<ValueType> const& b,
OptimizationDirection const& dir, bool updateX, bool robust) const {
if constexpr (std::is_same_v<ValueType, storm::Interval>) {
if (std::is_same_v<ValueType, storm::Interval> && this->A->hasTrivialRowGrouping()) {
// Create robust scheduler index if it doesn't exist yet
if (!this->robustSchedulerIndex) {
this->robustSchedulerIndex = std::vector<uint64_t>(x.size(), 0);
Expand Down Expand Up @@ -161,12 +183,23 @@ void IterativeMinMaxLinearEquationSolver<ValueType, SolutionType>::extractSchedu
}

// Set the correct choices.
STORM_LOG_WARN_COND(viOperator, "Expected VI operator to be initialized for scheduler extraction. Initializing now, but this is inefficient.");
if (!viOperator) {
STORM_LOG_WARN_COND(!viOperatorTriv && !viOperatorNontriv,
"Expected VI operator to be initialized for scheduler extraction. Initializing now, but this is inefficient.");
if (!viOperatorTriv && !viOperatorNontriv) {
setUpViOperator();
}
storm::solver::helper::SchedulerTrackingHelper<ValueType, SolutionType, std::is_same_v<ValueType, storm::Interval>> schedHelper(viOperator);
schedHelper.computeScheduler(x, b, dir, *this->schedulerChoices, robust, updateX ? &x : nullptr, this->robustSchedulerIndex);
if (viOperatorTriv) {
if constexpr (std::is_same<ValueType, storm::Interval>() && std::is_same<SolutionType, double>()) {
storm::solver::helper::SchedulerTrackingHelper<ValueType, SolutionType, true> schedHelper(viOperatorTriv);
schedHelper.computeScheduler(x, b, dir, *this->schedulerChoices, robust, updateX ? &x : nullptr, this->robustSchedulerIndex);
} else {
STORM_LOG_ERROR("SchedulerTrackingHelper not implemented for this setting (trivial row grouping but not Interval->double).");
}
}
if (viOperatorNontriv) {
storm::solver::helper::SchedulerTrackingHelper<ValueType, SolutionType, false> schedHelper(viOperatorNontriv);
schedHelper.computeScheduler(x, b, dir, *this->schedulerChoices, robust, updateX ? &x : nullptr, this->robustSchedulerIndex);
}
}

template<typename ValueType, typename SolutionType>
Expand Down Expand Up @@ -525,7 +558,7 @@ bool IterativeMinMaxLinearEquationSolver<ValueType, SolutionType>::solveEquation

setUpViOperator();

helper::OptimisticValueIterationHelper<ValueType, false> oviHelper(viOperator);
helper::OptimisticValueIterationHelper<ValueType, false> oviHelper(viOperatorNontriv);
auto prec = storm::utility::convertNumber<ValueType>(env.solver().minMax().getPrecision());
std::optional<ValueType> lowerBound, upperBound;
if (this->hasLowerBound()) {
Expand Down Expand Up @@ -619,28 +652,56 @@ bool IterativeMinMaxLinearEquationSolver<ValueType, SolutionType>::solveEquation
}
}

storm::solver::helper::ValueIterationHelper<ValueType, std::is_same_v<ValueType, storm::Interval>, SolutionType> viHelper(viOperator);
uint64_t numIterations{0};
auto viCallback = [&](SolverStatus const& current) {
this->showProgressIterative(numIterations);
return this->updateStatus(current, x, guarantee, numIterations, env.solver().minMax().getMaximalNumberOfIterations());
};
this->startMeasureProgress();
auto status = viHelper.VI(x, b, numIterations, env.solver().minMax().getRelativeTerminationCriterion(),
storm::utility::convertNumber<SolutionType>(env.solver().minMax().getPrecision()), dir, viCallback,
env.solver().minMax().getMultiplicationStyle(), this->isUncertaintyRobust());
this->reportStatus(status, numIterations);

// If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulerSet()) {
this->extractScheduler(x, b, dir, this->isUncertaintyRobust());
}
// This code duplication is necessary because the helper class is different for the two cases.
if (this->A->hasTrivialRowGrouping()) {
storm::solver::helper::ValueIterationHelper<ValueType, true, SolutionType> viHelper(viOperatorTriv);

if (!this->isCachingEnabled()) {
clearCache();
}
uint64_t numIterations{0};
auto viCallback = [&](SolverStatus const& current) {
this->showProgressIterative(numIterations);
return this->updateStatus(current, x, guarantee, numIterations, env.solver().minMax().getMaximalNumberOfIterations());
};
this->startMeasureProgress();
auto status = viHelper.VI(x, b, numIterations, env.solver().minMax().getRelativeTerminationCriterion(),
storm::utility::convertNumber<SolutionType>(env.solver().minMax().getPrecision()), dir, viCallback,
env.solver().minMax().getMultiplicationStyle(), this->isUncertaintyRobust());
this->reportStatus(status, numIterations);

// If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulerSet()) {
this->extractScheduler(x, b, dir, this->isUncertaintyRobust());
}

if (!this->isCachingEnabled()) {
clearCache();
}

return status == SolverStatus::Converged || status == SolverStatus::TerminatedEarly;
} else {
storm::solver::helper::ValueIterationHelper<ValueType, false, SolutionType> viHelper(viOperatorNontriv);

uint64_t numIterations{0};
auto viCallback = [&](SolverStatus const& current) {
this->showProgressIterative(numIterations);
return this->updateStatus(current, x, guarantee, numIterations, env.solver().minMax().getMaximalNumberOfIterations());
};
this->startMeasureProgress();
auto status = viHelper.VI(x, b, numIterations, env.solver().minMax().getRelativeTerminationCriterion(),
storm::utility::convertNumber<SolutionType>(env.solver().minMax().getPrecision()), dir, viCallback,
env.solver().minMax().getMultiplicationStyle(), this->isUncertaintyRobust());
this->reportStatus(status, numIterations);

// If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulerSet()) {
this->extractScheduler(x, b, dir, this->isUncertaintyRobust());
}

if (!this->isCachingEnabled()) {
clearCache();
}

return status == SolverStatus::Converged || status == SolverStatus::TerminatedEarly;
return status == SolverStatus::Converged || status == SolverStatus::TerminatedEarly;
}
}

template<typename ValueType, typename SolutionType>
Expand All @@ -663,7 +724,7 @@ bool IterativeMinMaxLinearEquationSolver<ValueType, SolutionType>::solveEquation
return false;
} else {
setUpViOperator();
helper::IntervalIterationHelper<ValueType, false> iiHelper(viOperator);
helper::IntervalIterationHelper<ValueType, false> iiHelper(viOperatorNontriv);
auto prec = storm::utility::convertNumber<ValueType>(env.solver().minMax().getPrecision());
auto lowerBoundsCallback = [&](std::vector<SolutionType>& vector) { this->createLowerBoundsVector(vector); };
auto upperBoundsCallback = [&](std::vector<SolutionType>& vector) { this->createUpperBoundsVector(vector); };
Expand Down Expand Up @@ -727,7 +788,7 @@ bool IterativeMinMaxLinearEquationSolver<ValueType, SolutionType>::solveEquation
numIterations, env.solver().minMax().getMaximalNumberOfIterations());
};
this->startMeasureProgress();
helper::SoundValueIterationHelper<ValueType, false> sviHelper(viOperator);
helper::SoundValueIterationHelper<ValueType, false> sviHelper(viOperatorNontriv);
std::optional<storm::storage::BitVector> optionalRelevantValues;
if (this->hasRelevantValues()) {
optionalRelevantValues = this->getRelevantValues();
Expand Down Expand Up @@ -806,14 +867,14 @@ bool IterativeMinMaxLinearEquationSolver<ValueType, SolutionType>::solveEquation
}

if constexpr (std::is_same_v<ValueType, storm::RationalNumber>) {
exactOp = viOperator;
exactOp = viOperatorNontriv;
impreciseOp = std::make_shared<helper::ValueIterationOperator<double, false>>();
impreciseOp->setMatrixBackwards(this->A->template toValueType<double>(), &this->A->getRowGroupIndices());
if (this->choiceFixedForRowGroup) {
impreciseOp->setIgnoredRows(true, fixedChoicesCallback);
}
} else if constexpr (std::is_same_v<ValueType, double>) {
impreciseOp = viOperator;
impreciseOp = viOperatorNontriv;
exactOp = std::make_shared<helper::ValueIterationOperator<storm::RationalNumber, false>>();
exactOp->setMatrixBackwards(this->A->template toValueType<storm::RationalNumber>(), &this->A->getRowGroupIndices());
if (this->choiceFixedForRowGroup) {
Expand Down Expand Up @@ -848,7 +909,12 @@ bool IterativeMinMaxLinearEquationSolver<ValueType, SolutionType>::solveEquation
template<typename ValueType, typename SolutionType>
void IterativeMinMaxLinearEquationSolver<ValueType, SolutionType>::clearCache() const {
auxiliaryRowGroupVector.reset();
viOperator.reset();
if (viOperatorTriv) {
viOperatorTriv.reset();
}
if (viOperatorNontriv) {
viOperatorNontriv.reset();
}
StandardMinMaxLinearEquationSolver<ValueType, SolutionType>::clearCache();
}

Expand Down
6 changes: 5 additions & 1 deletion src/storm/solver/IterativeMinMaxLinearEquationSolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ class IterativeMinMaxLinearEquationSolver : public StandardMinMaxLinearEquationS
std::unique_ptr<LinearEquationSolverFactory<SolutionType>> linearEquationSolverFactory;

// possibly cached data
mutable std::shared_ptr<storm::solver::helper::ValueIterationOperator<ValueType, std::is_same_v<ValueType, storm::Interval>, SolutionType>> viOperator;

// two different VI operators, one for trivialrowgrouping, one without
mutable std::shared_ptr<storm::solver::helper::ValueIterationOperator<ValueType, true, SolutionType>> viOperatorTriv;
mutable std::shared_ptr<storm::solver::helper::ValueIterationOperator<ValueType, false, SolutionType>> viOperatorNontriv;

mutable std::unique_ptr<std::vector<ValueType>> auxiliaryRowGroupVector; // A.rowGroupCount() entries
};

Expand Down
1 change: 0 additions & 1 deletion src/storm/solver/helper/SchedulerTrackingHelper.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#pragma once
#include <memory>
#include <optional>
#include <vector>

#include "storm/solver/OptimizationDirection.h"
Expand Down
13 changes: 7 additions & 6 deletions src/storm/solver/helper/ValueIterationOperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include <optional>

#include "storm/adapters/RationalNumberAdapter.h"
#include "storm/adapters/RationalNumberForward.h"
#include "storm/storage/BitVector.h"
#include "storm/storage/SparseMatrix.h"
Expand Down Expand Up @@ -31,6 +30,13 @@ void ValueIterationOperator<ValueType, TrivialRowGrouping, SolutionType>::setMat
matrixColumns.clear();
matrixValues.reserve(matrix.getNonzeroEntryCount());
matrixColumns.reserve(matrix.getNonzeroEntryCount() + numRows + 1); // matrixColumns also contain indications for when a row(group) starts

// hasOnlyConstants is only used for Interval matrices, currently only populated for iMCs
if constexpr (std::is_same<ValueType, storm::Interval>::value) {
applyCache.hasOnlyConstants.clear();
applyCache.hasOnlyConstants.grow(matrix.getRowCount());
}

if constexpr (!TrivialRowGrouping) {
matrixColumns.push_back(StartOfRowGroupIndicator); // indicate start of first row(group)
for (auto groupIndex : indexRange<Backward>(0, this->rowGroupIndices->size() - 1)) {
Expand All @@ -47,11 +53,6 @@ void ValueIterationOperator<ValueType, TrivialRowGrouping, SolutionType>::setMat
}
} else {
if constexpr (std::is_same<ValueType, storm::Interval>::value) {
applyCache.hasOnlyConstants.clear();
applyCache.hasOnlyConstants.grow(matrix.getRowCount());
// TODO Implement hasTwoSuccessors
// applyCache.hasTwoSuccessors.clear();
// applyCache.hasTwoSuccessors.grow(matrix.getRowCount());
matrixColumns.push_back(StartOfRowIndicator); // Indicate start of first row
for (auto rowIndex : indexRange<Backward>(0, numRows)) {
bool hasOnlyConstants = true;
Expand Down
4 changes: 0 additions & 4 deletions src/storm/solver/helper/ValueIterationOperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,6 @@ class ValueIterationOperator {
return result;
}

// TODO I think this is a problem if we have probabilities and a state that is going to the vector, we don't count that
// Currently "fixed in preprocessing"
// It's different for rewards (same problem somewhere else, search for word "octopus" in codebase)
SolutionType remainingValue{storm::utility::one<SolutionType>()};
uint64_t orderCounter = 0;
for (++matrixColumnIt; *matrixColumnIt < StartOfRowIndicator; ++matrixColumnIt, ++matrixValueIt, ++orderCounter) {
Expand Down Expand Up @@ -460,7 +457,6 @@ class ValueIterationOperator {
struct ApplyCache<storm::Interval, Dummy> {
mutable std::vector<std::pair<SolutionType, std::pair<SolutionType, uint64_t>>> robustOrder;
storage::BitVector hasOnlyConstants;
storage::BitVector hasTwoSuccessors;
};

/*!
Expand Down
Loading