Skip to content

Commit

Permalink
VRT expressions: Add loop runtime checks
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston committed Nov 19, 2024
1 parent 1d8819d commit 57f8062
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 73 deletions.
13 changes: 13 additions & 0 deletions autotest/gdrivers/vrtprocesseddataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,19 @@ def test_vrtprocesseddataset_trimming_errors(tmp_vsimem):
{"GDAL_EXPRTK_ENABLE_LOOPS": "NO"},
id="loops disabled",
),
pytest.param(
"""
for (var i := 0; i < 20000; i += 1) {
sleep(0.2/20000);
}
return [B1];
""",
np.array([[[1, 2]]]),
np.array([[1, 2]]),
"Loop run-time exceeded maximum",
{"GDAL_EXPRTK_MAX_LOOP_ITERATION_SECONDS": "0.1", "CPL_DEBUG": "ON"},
id="loop evaluation timeout",
),
],
)
def test_vrtprocesseddataset_expression(
Expand Down
246 changes: 173 additions & 73 deletions frmts/vrt/vrtexpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
#include <cstdint>
#include <sstream>

#include <thread>
#include <chrono>

struct vector_access_check final : public exprtk::vector_access_runtime_check
{
bool handle_runtime_violation(violation_context &context) override
Expand All @@ -44,6 +47,85 @@ struct vector_access_check final : public exprtk::vector_access_runtime_check
}
};

struct loop_timeout_check final : public exprtk::loop_runtime_check
{
using time_point_t = std::chrono::time_point<std::chrono::steady_clock>;

loop_timeout_check() : exprtk::loop_runtime_check()
{
double dfMaxLoopIterationSeconds = CPLAtofM(
CPLGetConfigOption("GDAL_EXPRTK_MAX_LOOP_ITERATION_SECONDS", "1"));
max_duration = std::chrono::microseconds(
static_cast<size_t>(dfMaxLoopIterationSeconds * 1e6));
}

void start_timer()
{
timeout_t = std::chrono::steady_clock::now() + max_duration;
}

bool check() override
{

if (++iterations >= max_iters_per_check)
{
if (std::chrono::steady_clock::now() > timeout_t)
{
return false;
}

iterations = 0;
}

return true;
}

void handle_runtime_violation(const violation_context &context) override
{
std::ostringstream oss;

if (context.violation == violation_type::e_iteration_count)
{
oss << "Exceeded maximium of " << max_loop_iterations
<< " loop iterations.";
}
else if (context.violation == violation_type::e_timeout)
{
oss << "Loop run-time exceeded maximum of "
<< static_cast<double>(max_duration.count() / 1e6)
<< " seconds. You can increase this threshold by setting the "
<< "GDAL_EXPRTK_MAX_LOOP_ITERATION_SECONDS configuration "
<< "option.";
}

throw std::runtime_error(oss.str());
}

static constexpr size_t max_iters_per_check = 10000;
size_t iterations = 0;
time_point_t timeout_t{};
std::chrono::microseconds max_duration{};
};

namespace
{

struct sleep_fn : public exprtk::ifunction<double>
{
sleep_fn() : exprtk::ifunction<double>(1)
{
}

double operator()(const double &seconds) override
{
std::this_thread::sleep_for(
std::chrono::microseconds(static_cast<int>(seconds * 1e6)));
return 0;
}
};

} // namespace

class GDALExpressionEvaluator::Impl
{
public:
Expand All @@ -56,14 +138,26 @@ class GDALExpressionEvaluator::Impl
std::vector<std::pair<std::string, std::vector<double> *>> m_aoVectors{};
std::vector<double> m_adfResults{};
vector_access_check m_oVectorAccessCheck{};
loop_timeout_check m_oLoopRuntimeCheck{};

bool m_bIsCompiled{false};

Impl()
sleep_fn sleep{};

explicit Impl()
{
using settings_t = std::decay_t<decltype(m_oParser.settings())>;

m_oLoopRuntimeCheck.loop_set = loop_timeout_check::e_all_loops;
m_oLoopRuntimeCheck.max_loop_iterations = std::numeric_limits<
decltype(m_oLoopRuntimeCheck.max_loop_iterations)>::max();
m_oParser.register_vector_access_runtime_check(m_oVectorAccessCheck);
m_oParser.register_loop_runtime_check(m_oLoopRuntimeCheck);

if (CPLTestBool(CPLGetConfigOption("CPL_DEBUG", "OFF")))
{
m_oSymbolTable.add_function("sleep", sleep);
}

int nMaxVectorLength = std::atoi(
CPLGetConfigOption("GDAL_EXPRTK_MAX_VECTOR_LENGTH", "100000"));
Expand Down Expand Up @@ -138,6 +232,83 @@ class GDALExpressionEvaluator::Impl

return CE_None;
}

CPLErr evaluate()
{
if (!m_bIsCompiled)
{
auto eErr = compile();
if (eErr != CE_None)
{
return eErr;
}
}

m_adfResults.clear();
double value;
try
{
value = m_oExpression.value(); // force evaluation
}
catch (const std::exception &e)
{
CPLError(CE_Failure, CPLE_AppDefined, "%s", e.what());
return CE_Failure;
}

m_oLoopRuntimeCheck.start_timer();
const auto &results = m_oExpression.results();

// We follow a different method to get the result depending on
// how the expression was formed. If a "return" statement was
// used, the result will be accessible via the "result" object.
// If no "return" statement was used, the result is accessible
// from the "value" variable (and must not be a vector.)
if (results.count() == 0)
{
m_adfResults.resize(1);
m_adfResults[0] = value;
}
else if (results.count() == 1)
{

if (results[0].type == exprtk::type_store<double>::e_scalar)
{
m_adfResults.resize(1);
results.get_scalar(0, m_adfResults[0]);
}
else if (results[0].type == exprtk::type_store<double>::e_vector)
{
results.get_vector(0, m_adfResults);
}
else
{
CPLError(CE_Failure, CPLE_AppDefined,
"Expression returned an unexpected type.");
return CE_Failure;
}
}
else
{
m_adfResults.resize(results.count());
for (size_t i = 0; i < results.count(); i++)
{
if (results[i].type != exprtk::type_store<double>::e_scalar)
{
CPLError(CE_Failure, CPLE_AppDefined,
"Expression must return a vector or a list of "
"scalars.");
return CE_Failure;
}
else
{
results.get_scalar(i, m_adfResults[i]);
}
}
}

return CE_None;
}
};

GDALExpressionEvaluator::GDALExpressionEvaluator(std::string_view osExpression)
Expand Down Expand Up @@ -174,76 +345,5 @@ const std::vector<double> &GDALExpressionEvaluator::Results() const

CPLErr GDALExpressionEvaluator::Evaluate()
{
if (!m_pImpl->m_bIsCompiled)
{
auto eErr = m_pImpl->compile();
if (eErr != CE_None)
{
return eErr;
}
}

m_pImpl->m_adfResults.clear();
double value;
try
{
value = m_pImpl->m_oExpression.value(); // force evaluation
}
catch (const std::exception &e)
{
CPLError(CE_Failure, CPLE_AppDefined, "%s", e.what());
return CE_Failure;
}

const auto &results = m_pImpl->m_oExpression.results();

// We follow a different method to get the result depending on
// how the expression was formed. If a "return" statement was
// used, the result will be accessible via the "result" object.
// If no "return" statement was used, the result is accessible
// from the "value" variable (and must not be a vector.)
if (results.count() == 0)
{
m_pImpl->m_adfResults.resize(1);
m_pImpl->m_adfResults[0] = value;
}
else if (results.count() == 1)
{

if (results[0].type == exprtk::type_store<double>::e_scalar)
{
m_pImpl->m_adfResults.resize(1);
results.get_scalar(0, m_pImpl->m_adfResults[0]);
}
else if (results[0].type == exprtk::type_store<double>::e_vector)
{
results.get_vector(0, m_pImpl->m_adfResults);
}
else
{
CPLError(CE_Failure, CPLE_AppDefined,
"Expression returned an unexpected type.");
return CE_Failure;
}
}
else
{
m_pImpl->m_adfResults.resize(results.count());
for (size_t i = 0; i < results.count(); i++)
{
if (results[i].type != exprtk::type_store<double>::e_scalar)
{
CPLError(CE_Failure, CPLE_AppDefined,
"Expression must return a vector or a list of "
"scalars.");
return CE_Failure;
}
else
{
results.get_scalar(i, m_pImpl->m_adfResults[i]);
}
}
}

return CE_None;
return m_pImpl->evaluate();
}

0 comments on commit 57f8062

Please sign in to comment.