Skip to content

Commit

Permalink
VRT expressions: Unify pixel function and processed dataset impementa…
Browse files Browse the repository at this point in the history
…tions
  • Loading branch information
dbaston committed Nov 12, 2024
1 parent d17ee41 commit f001051
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 140 deletions.
4 changes: 3 additions & 1 deletion autotest/gdrivers/vrtderived.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,8 @@ def vrt_expression_xml(tmpdir, expression, sources):
nx = 1
ny = 1

expression = expression.replace("<", "&lt;").replace(">", "&gt;")

xml = f"""<VRTDataset rasterXSize="{nx}" rasterYSize="{ny}">
<VRTRasterBand dataType="Float64" band="1" subClass="VRTDerivedRasterBand">
<PixelFunctionType>expression</PixelFunctionType>
Expand Down Expand Up @@ -1183,7 +1185,7 @@ def test_vrt_pixelfn_expression(tmp_path, expression, sources, result):
pytest.param(
"A*B + C",
[("A", 77), ("B", 63)],
"failed to parse expression",
"Failed to parse expression",
id="undefined variable",
),
],
Expand Down
29 changes: 25 additions & 4 deletions autotest/gdrivers/vrtprocesseddataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,23 @@ def test_vrtprocesseddataset_trimming_errors(tmp_vsimem):
None,
id="multiple bands in, multiple bands out (2)",
),
pytest.param(
"""
var chunk[10];
var out[5];
for (var i := 0; i < out[]; i += 1) {
for (var j := 0; j < chunk[]; j += 1) {
chunk[j] := ALL_BANDS[i * chunk[] + j];
};
out[i] := avg(chunk);
};
return [out];
""",
np.arange(100).reshape(50, 1, 2),
np.array([[[9, 10]], [[29, 30]], [[49, 50]], [[69, 70]], [[89, 90]]]),
None,
id="procedural",
),
pytest.param(
"B1",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
Expand Down Expand Up @@ -1259,12 +1276,16 @@ def test_vrtprocesseddataset_trimming_errors(tmp_vsimem):
def test_vrtprocesseddataset_expression(tmp_vsimem, expression, src, expected, error):

src_filename = tmp_vsimem / "src.tif"
with gdal.GetDriverByName("GTiff").Create(src_filename, 2, 1, 3) as src_ds:
src_ds.WriteArray(src)
src_ds.SetGeoTransform([0, 1, 0, 0, 0, 1])

num_input_bands = 1 if len(src.shape) == 2 else src.shape[0]
expected_output_bands = 1 if len(expected.shape) == 2 else expected.shape[0]

with gdal.GetDriverByName("GTiff").Create(
src_filename, 2, 1, num_input_bands
) as src_ds:
src_ds.WriteArray(src)
src_ds.SetGeoTransform([0, 1, 0, 0, 0, 1])

output_band_xml = "".join(
f"""<VRTRasterBand band="{i+1}" dataType="Float32" subClass="VRTProcessedRasterBand"/>"""
for i in range(expected_output_bands)
Expand All @@ -1278,7 +1299,7 @@ def test_vrtprocesseddataset_expression(tmp_vsimem, expression, src, expected, e
<ProcessingSteps>
<Step>
<Algorithm>Expression</Algorithm>
<Argument name="expression">{expression}</Argument>
<Argument name="expression">{expression.replace('<', '&lt;').replace('>', '&gt;')}</Argument>
</Step>
</ProcessingSteps>
{output_band_xml}
Expand Down
4 changes: 4 additions & 0 deletions frmts/vrt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ add_gdal_driver(
vrtdataset.h
vrtderivedrasterband.cpp
vrtdriver.cpp
vrtexpression.h
vrtexpression.cpp
vrtfilters.cpp
vrtrasterband.cpp
vrtrawrasterband.cpp
Expand All @@ -27,6 +29,7 @@ if (MSVC)
target_compile_options(gdal_vrt PRIVATE /bigobj)
endif()


set(GDAL_DATA_FILES
${CMAKE_CURRENT_SOURCE_DIR}/data/gdalvrt.xsd
)
Expand All @@ -49,3 +52,4 @@ set_property(SOURCE vrtwarped.cpp PROPERTY SKIP_UNITY_BUILD_INCLUSION ON)
if (NOT GDAL_ENABLE_DRIVER_VRT)
target_compile_definitions(gdal_vrt PRIVATE -DNO_OPEN)
endif()

32 changes: 12 additions & 20 deletions frmts/vrt/pixelfunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,8 @@
#include <cmath>
#include "gdal.h"
#include "vrtdataset.h"
#include "vrtexpression.h"

#define exprtk_disable_caseinsensitivity
#define exprtk_disable_rtl_io
#define exprtk_disable_rtl_io_file
#define exprtk_disable_rtl_vecops
#define exprtk_disable_string_capabilities
#include <exprtk.hpp>
#include <limits>

template <typename T>
Expand Down Expand Up @@ -1649,25 +1644,15 @@ static CPLErr ExprPixelFunc(void **papoSources, int nSources, void *pData,

std::vector<double> adfValuesForPixel(nSources);

exprtk::symbol_table<double> oSymbolTable;
GDALExpressionEvaluator oEvaluator(pszExpression);
{
int iSource = 0;
for (const auto &osName : aosSourceNames)
{
oSymbolTable.add_variable(osName, adfValuesForPixel[iSource++]);
oEvaluator.RegisterVariable(osName, &adfValuesForPixel[iSource++]);
}
}
oSymbolTable.add_vector("ALL", adfValuesForPixel);

exprtk::expression<double> oExpression;
oExpression.register_symbol_table(oSymbolTable);

exprtk::parser<double> oParser;
if (!oParser.compile(pszExpression, oExpression))
{
CPLError(CE_Failure, CPLE_AppDefined, "failed to parse expression");
return CE_Failure;
}
oEvaluator.RegisterVector("ALL", &adfValuesForPixel);

double *padfResults =
static_cast<double *>(CPLMalloc(nXSize * sizeof(double)));
Expand All @@ -1685,7 +1670,14 @@ static CPLErr ExprPixelFunc(void **papoSources, int nSources, void *pData,
GetSrcVal(papoSources[iSrc], eSrcType, ii);
}

padfResults[iCol] = oExpression.value();
if (auto eErr = oEvaluator.Evaluate(); eErr != CE_None)
{
return CE_Failure;
}
else
{
padfResults[iCol] = oEvaluator.Results()[0];
}
}

GDALCopyWords(padfResults, GDT_Float64, sizeof(double),
Expand Down
163 changes: 163 additions & 0 deletions frmts/vrt/vrtexpression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
#include "cpl_error.h"
#include "vrtexpression.h"

#define exprtk_disable_caseinsensitivity
#define exprtk_disable_rtl_io
#define exprtk_disable_rtl_io_file
//#define exprtk_disable_rtl_vecops
#define exprtk_disable_string_capabilities
#include <exprtk.hpp>

class GDALExpressionEvaluator::Impl
{
public:
exprtk::expression<double> m_oExpression{};
exprtk::parser<double> m_oParser{};
exprtk::symbol_table<double> m_oSymbolTable{};
std::string m_osExpression{};

std::vector<std::pair<std::string, double *>> m_aoVariables{};
std::vector<std::pair<std::string, std::vector<double> *>> m_aoVectors{};
std::vector<double> m_adfResults{};

bool m_bIsCompiled{false};

CPLErr compile()
{
for (const auto &[osVariable, pdfValueLoc] : m_aoVariables)
{
m_oSymbolTable.add_variable(osVariable, *pdfValueLoc);
}

for (const auto &[osVariable, padfVectorLoc] : m_aoVectors)
{
m_oSymbolTable.add_vector(osVariable, *padfVectorLoc);
}

m_oExpression.register_symbol_table(m_oSymbolTable);
bool bSuccess = m_oParser.compile(m_osExpression, m_oExpression);

if (!bSuccess)
{
for (size_t i = 0; i < m_oParser.error_count(); i++)
{
const auto &oError = m_oParser.get_error(i);

CPLError(CE_Warning, CPLE_AppDefined,
"Position: %02d "
"Type: [%s] "
"Message: %s\n",
static_cast<int>(oError.token.position),
exprtk::parser_error::to_str(oError.mode).c_str(),
oError.diagnostic.c_str());
}

CPLError(CE_Failure, CPLE_AppDefined,
"Failed to parse expression.");
return CE_Failure;
}

m_bIsCompiled = true;

return CE_None;
}
};

GDALExpressionEvaluator::GDALExpressionEvaluator(std::string_view osExpression)
: m_pImpl(std::make_unique<Impl>())
{
m_pImpl->m_osExpression = osExpression;
}

GDALExpressionEvaluator::~GDALExpressionEvaluator()
{
}

void GDALExpressionEvaluator::RegisterVariable(std::string_view osVariable,
double *pdfValue)
{
m_pImpl->m_aoVariables.emplace_back(osVariable, pdfValue);
}

void GDALExpressionEvaluator::RegisterVector(std::string_view osVariable,
std::vector<double> *padfValue)
{
m_pImpl->m_aoVectors.emplace_back(osVariable, padfValue);
}

CPLErr GDALExpressionEvaluator::Compile()
{
return m_pImpl->compile();
}

const std::vector<double> &GDALExpressionEvaluator::Results() const
{
return m_pImpl->m_adfResults;
}

CPLErr GDALExpressionEvaluator::Evaluate()
{
if (!m_pImpl->m_bIsCompiled)
{
auto eErr = m_pImpl->compile();
if (eErr != CE_None)
{
return eErr;
}
}

m_pImpl->m_adfResults.clear();
double value = m_pImpl->m_oExpression.value(); // force evaluation

const auto &results = m_pImpl->m_oExpression.results();

// We follow a different method to get the result depending on
// how the expression was formed. If a "return" statement was
// used, the result will be accessible via the "result" object.
// If no "return" statement was used, the result is accessible
// from the "value" variable (and must not be a vector.)
if (results.count() == 0)
{
m_pImpl->m_adfResults.resize(1);
m_pImpl->m_adfResults[0] = value;
}
else if (results.count() == 1)
{

if (results[0].type == exprtk::type_store<double>::e_scalar)
{
m_pImpl->m_adfResults.resize(1);
results.get_scalar(0, m_pImpl->m_adfResults[0]);
}
else if (results[0].type == exprtk::type_store<double>::e_vector)
{
results.get_vector(0, m_pImpl->m_adfResults);
}
else
{
CPLError(CE_Failure, CPLE_AppDefined,
"Expression returned an unexpected type.");
return CE_Failure;
}
}
else
{
m_pImpl->m_adfResults.resize(results.count());
for (size_t i = 0; i < results.count(); i++)
{
if (results[i].type != exprtk::type_store<double>::e_scalar)
{
CPLError(CE_Failure, CPLE_AppDefined,
"Expression must return a vector or a list of "
"scalars.");
return CE_Failure;
}
else
{
results.get_scalar(i, m_pImpl->m_adfResults[i]);
}
}
}

return CE_None;
}
30 changes: 30 additions & 0 deletions frmts/vrt/vrtexpression.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include "cpl_error.h"

#include <string_view>
#include <vector>

class GDALExpressionEvaluator
{
public:
GDALExpressionEvaluator(std::string_view osExpression);

~GDALExpressionEvaluator();

void RegisterVariable(std::string_view osVariable, double *pdfLocation);

void RegisterVector(std::string_view osVariable,
std::vector<double> *padfLocation);

CPLErr Compile();

CPLErr Evaluate();

const std::vector<double> &Results() const;

private:
class Impl;

std::unique_ptr<Impl> m_pImpl;
};
Loading

0 comments on commit f001051

Please sign in to comment.