Skip to content

Commit

Permalink
VRT expressions: Add compile-time options to enable/disable dialects
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston committed Jan 7, 2025
1 parent ab8fcf4 commit 0ec86a7
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 72 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ubuntu_20.04/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ cmake "${GDAL_SOURCE_DIR:=..}" \
-DCMAKE_INSTALL_PREFIX=/tmp/install-gdal \
-DGDAL_USE_TIFF_INTERNAL=OFF \
-DGDAL_USE_GEOTIFF_INTERNAL=OFF \
-DGDAL_VRT_ENABLE_EXPRTK=ON \
-DECW_ROOT=/opt/libecwj2-3.3 \
-DMRSID_ROOT=/usr/local \
-DFileGDB_ROOT=/usr/local/FileGDB_API \
Expand Down
30 changes: 24 additions & 6 deletions autotest/gdrivers/vrtderived.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,51 +1202,69 @@ def vrt_expression_xml(tmpdir, expression, dialect, sources):
)
@pytest.mark.parametrize("dialect", ("exprtk", "muparser"))
def test_vrt_pixelfn_expression(
tmp_path, expression, sources, result, dialect, dialects
tmp_vsimem, expression, sources, result, dialect, dialects
):
pytest.importorskip("numpy")

if not gdaltest.gdal_has_vrt_expression_dialect(dialect):
pytest.skip(f"Expression dialect {dialect} is not available")

if dialects and dialect not in dialects:
pytest.skip(f"Expression not supported for dialect {dialect}")

xml = vrt_expression_xml(tmp_path, expression, dialect, sources)
xml = vrt_expression_xml(tmp_vsimem, expression, dialect, sources)

with gdal.Open(xml) as ds:
assert pytest.approx(ds.ReadAsArray()[0][0], nan_ok=True) == result


@pytest.mark.parametrize(
"expression,sources,exception",
"expression,sources,dialect,exception",
[
pytest.param(
"A*B + C",
[("A", 77), ("B", 63)],
"exprtk",
"Undefined symbol",
id="undefined variable",
id="exprtk undefined variable",
),
pytest.param(
"A*B + C",
[("A", 77), ("B", 63)],
"muparser",
"Unexpected token",
id="muparser undefined variable",
),
pytest.param(
"(".join(["asin", "sin", "acos", "cos"] * 100) + "(X" + 100 * 4 * ")",
[("X", 0.5)],
"exprtk",
"exceeds maximum allowed stack depth",
id="expression is too complex",
),
pytest.param(
" ".join(["sin(x) + cos(x)"] * 10000),
[("x", 0.5)],
"exprtk",
"exceeds maximum of 100000 set by GDAL_EXPRTK_MAX_EXPRESSION_LENGTH",
id="expression is too long",
),
],
)
def test_vrt_pixelfn_expression_invalid(tmp_path, expression, sources, exception):
def test_vrt_pixelfn_expression_invalid(
tmp_vsimem, expression, sources, dialect, exception
):
pytest.importorskip("numpy")

if not gdaltest.gdal_has_vrt_expression_dialect(dialect):
pytest.skip(f"Expression dialect {dialect} is not available")

messages = []

def handle(ecls, ecode, emsg):
messages.append(emsg)

xml = vrt_expression_xml(tmp_path, expression, "exprtk", sources)
xml = vrt_expression_xml(tmp_vsimem, expression, dialect, sources)

with gdaltest.error_handler(handle):
ds = gdal.Open(xml)
Expand Down
8 changes: 7 additions & 1 deletion autotest/gdrivers/vrtprocesseddataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,6 +1347,9 @@ def test_vrtprocesseddataset_trimming_errors(tmp_vsimem):
def test_vrtprocesseddataset_expression(
request, tmp_vsimem, expression, src, expected, env, error
):
if not gdaltest.gdal_has_vrt_expression_dialect("exprtk"):
pytest.skip("exprtk not available")

if "timeout" in request.node.name and "debug" not in gdal.VersionInfo(""):
pytest.skip("Timeout tests only work on debug builds")

Expand All @@ -1362,7 +1365,7 @@ def test_vrtprocesseddataset_expression(
src_ds.SetGeoTransform([0, 1, 0, 0, 0, 1])

output_band_xml = "".join(
f"""<VRTRasterBand band="{i+1}" dataType="Float32" subClass="VRTProcessedRasterBand"/>"""
f"""<VRTRasterBand band="{i + 1}" dataType="Float32" subClass="VRTProcessedRasterBand"/>"""
for i in range(expected_output_bands)
)

Expand Down Expand Up @@ -1403,6 +1406,9 @@ def test_vrtprocesseddataset_expression(
)
def test_vrtprocesseddataset_expression_batchsize(tmp_vsimem, batch_size):

if not gdaltest.gdal_has_vrt_expression_dialect("muparser"):
pytest.skip("muparser not available")

src_filename = tmp_vsimem / "in.tif"

inputs = np.arange(12)
Expand Down
26 changes: 26 additions & 0 deletions autotest/pymod/gdaltest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
###############################################################################

import contextlib
import functools
import io
import json
import math
Expand Down Expand Up @@ -2131,3 +2132,28 @@ def handler(lvl, no, msg):
assert any(
[err["level"] == type and match in err["message"] for err in errors]
), f'Did not receive an error of type {err_levels[type]} matching "{match}"'


###############################################################################
# Check VRT capabilities


@functools.lru_cache()
def gdal_has_vrt_expression_dialect(dialect):
with disable_exceptions(), gdal.quiet_errors():
vrt = f"""<VRTDataset rasterXSize="20" rasterYSize="20">
<VRTRasterBand dataType="Float64" band="1" subClass="VRTDerivedRasterBand">
<PixelFunctionType>expression</PixelFunctionType>
<PixelFunctionArguments expression="B1 + 5" dialect="{dialect}"/>
<ArraySource>
<Array name="test">
<DataType>Float64</DataType>
<Dimension name="Y" size="20"/>
<Dimension name="X" size="20"/>
<ConstantValue>10</ConstantValue>
</Array>
</ArraySource>
</VRTRasterBand>
</VRTDataset>"""
ds = gdal.Open(vrt)
return ds.ReadRaster() is not None
50 changes: 30 additions & 20 deletions frmts/vrt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ add_gdal_driver(
vrtderivedrasterband.cpp
vrtdriver.cpp
vrtexpression.h
vrtexpression.cpp
vrtexpression_muparser.cpp
vrtfilters.cpp
vrtrasterband.cpp
vrtrawrasterband.cpp
Expand All @@ -26,26 +24,38 @@ gdal_standard_includes(gdal_vrt)
target_include_directories(gdal_vrt PRIVATE ${GDAL_RASTER_FORMAT_SOURCE_DIR}/raw
$<TARGET_PROPERTY:ogrsf_generic,SOURCE_DIR>)

target_include_directories(gdal_vrt SYSTEM PRIVATE ${PROJECT_SOURCE_DIR}/third_party/exprtk)
if (MSVC)
set_source_files_properties(vrtexpression.cpp PROPERTIES COMPILE_FLAGS "/bigobj")
elseif(MINGW)
set_source_files_properties(vrtexpression.cpp PROPERTIES COMPILE_FLAGS "-Wa,-mbig-obj")
option(GDAL_VRT_ENABLE_EXPRTK "Enable exprtk library for VRT expressions" OFF)
option(GDAL_VRT_ENABLE_MUPARSER "Enable muparser library for VRT expressions" ON)

if (GDAL_VRT_ENABLE_EXPRTK)
target_sources(gdal_vrt PRIVATE vrtexpression_exprtk.cpp)
target_include_directories(gdal_vrt SYSTEM PRIVATE ${PROJECT_SOURCE_DIR}/third_party/exprtk)
if (MSVC)
set_source_files_properties(vrtexpression_exprtk.cpp PROPERTIES COMPILE_FLAGS "/bigobj")
elseif(MINGW)
set_source_files_properties(vrtexpression.cpp PROPERTIES COMPILE_FLAGS "-Wa,-mbig-obj")
endif()
target_compile_definitions(gdal_vrt PRIVATE GDAL_VRT_ENABLE_EXPRTK)
endif()

if (GDAL_VRT_ENABLE_MUPARSER)
target_sources(gdal_vrt PRIVATE vrtexpression_muparser.cpp)
add_library(muparser OBJECT
${PROJECT_SOURCE_DIR}/third_party/muparser/src/muParser.cpp
${PROJECT_SOURCE_DIR}/third_party/muparser/src/muParserBase.cpp
${PROJECT_SOURCE_DIR}/third_party/muparser/src/muParserBytecode.cpp
${PROJECT_SOURCE_DIR}/third_party/muparser/src/muParserCallback.cpp
${PROJECT_SOURCE_DIR}/third_party/muparser/src/muParserError.cpp
${PROJECT_SOURCE_DIR}/third_party/muparser/src/muParserTokenReader.cpp
)
target_compile_definitions(muparser PRIVATE MUPARSERLIB_EXPORTS)
set_property(TARGET muparser PROPERTY POSITION_INDEPENDENT_CODE ${GDAL_OBJECT_LIBRARIES_POSITION_INDEPENDENT_CODE})
target_include_directories(muparser SYSTEM PRIVATE ${PROJECT_SOURCE_DIR}/third_party/muparser/include)
target_include_directories(gdal_vrt SYSTEM PRIVATE ${PROJECT_SOURCE_DIR}/third_party/muparser/include)
target_sources(${GDAL_LIB_TARGET_NAME} PRIVATE $<TARGET_OBJECTS:muparser>)
target_compile_definitions(gdal_vrt PRIVATE GDAL_VRT_ENABLE_MUPARSER)
endif()

add_library(muparser OBJECT
${PROJECT_SOURCE_DIR}/third_party/muparser/src/muParser.cpp
${PROJECT_SOURCE_DIR}/third_party/muparser/src/muParserBase.cpp
${PROJECT_SOURCE_DIR}/third_party/muparser/src/muParserBytecode.cpp
${PROJECT_SOURCE_DIR}/third_party/muparser/src/muParserCallback.cpp
${PROJECT_SOURCE_DIR}/third_party/muparser/src/muParserError.cpp
${PROJECT_SOURCE_DIR}/third_party/muparser/src/muParserTokenReader.cpp
)
target_compile_definitions(muparser PRIVATE MUPARSERLIB_EXPORTS)
set_property(TARGET muparser PROPERTY POSITION_INDEPENDENT_CODE ${GDAL_OBJECT_LIBRARIES_POSITION_INDEPENDENT_CODE})
target_include_directories(muparser SYSTEM PRIVATE ${PROJECT_SOURCE_DIR}/third_party/muparser/include)
target_include_directories(gdal_vrt SYSTEM PRIVATE ${PROJECT_SOURCE_DIR}/third_party/muparser/include)
target_sources(${GDAL_LIB_TARGET_NAME} PRIVATE $<TARGET_OBJECTS:muparser>)

set(GDAL_DATA_FILES
${CMAKE_CURRENT_SOURCE_DIR}/data/gdalvrt.xsd
Expand Down
33 changes: 13 additions & 20 deletions frmts/vrt/pixelfunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1654,28 +1654,17 @@ static CPLErr ExprPixelFunc(void **papoSources, int nSources, void *pData,
std::vector<double> adfValuesForPixel(nSources);

const char *pszDialect = CSLFetchNameValue(papszArgs, "dialect");
if (pszDialect)
if (!pszDialect)
{
if (EQUAL(pszDialect, "muparser"))
{
poExpression =
std::make_unique<gdal::MuParserExpression>(pszExpression);
}
else if (EQUAL(pszDialect, "exprtk"))
{
poExpression =
std::make_unique<gdal::ExprtkExpression>(pszExpression);
}
else
{
CPLError(CE_Failure, CPLE_AppDefined,
"unknown expression dialect: %s", pszDialect);
return CE_Failure;
}
pszDialect = "muparser";
}
else

poExpression = gdal::MathExpression::Create(pszExpression, pszDialect);

// cppcheck-suppress knownConditionTrueFalse
if (!poExpression)
{
poExpression = std::make_unique<gdal::ExprtkExpression>(pszExpression);
return CE_Failure;
}

{
Expand All @@ -1686,7 +1675,11 @@ static CPLErr ExprPixelFunc(void **papoSources, int nSources, void *pData,
&adfValuesForPixel[iSource++]);
}
}
poExpression->RegisterVector("BANDS", &adfValuesForPixel);
CPLString osExpression(pszExpression);
if (osExpression.find("BANDS") != std::string::npos)
{
poExpression->RegisterVector("BANDS", &adfValuesForPixel);
}

double *padfResults =
static_cast<double *>(CPLMalloc(nXSize * sizeof(double)));
Expand Down
34 changes: 33 additions & 1 deletion frmts/vrt/vrtexpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class MathExpression
public:
virtual ~MathExpression() = default;

static std::unique_ptr<MathExpression> Create(const char *pszExpression,
const char *pszDialect);

/**
* Register a variable to be used in the expression.
*
Expand Down Expand Up @@ -63,7 +66,8 @@ class MathExpression
* the expression is evaluated.
*
* @return CE_None if the expression can be successfully parsed and all
* symbols have been registered, CE_Failure otherwise.
* symbols have been registe vrtexpression_exprtk.cpp
vrtexpression_muparser.cppred, CE_Failure otherwise.
*
* @since 3.11
*/
Expand Down Expand Up @@ -92,6 +96,8 @@ class MathExpression

/*! @cond Doxygen_Suppress */

#if GDAL_VRT_ENABLE_EXPRTK

/**
* Class to support evaluation of an expression using the exprtk library.
*/
Expand Down Expand Up @@ -120,6 +126,10 @@ class ExprtkExpression : public MathExpression
std::unique_ptr<Impl> m_pImpl;
};

#endif

#if GDAL_VRT_ENABLE_MUPARSER

/**
* Class to support evaluation of an expression using the muparser library.
*/
Expand Down Expand Up @@ -148,6 +158,28 @@ class MuParserExpression : public MathExpression
std::unique_ptr<Impl> m_pImpl;
};

#endif

inline std::unique_ptr<MathExpression>
MathExpression::Create(const char *pszExpression, const char *pszDialect)
{
#if GDAL_VRT_ENABLE_EXPRTK
if (EQUAL(pszDialect, "exprtk"))
{
return std::make_unique<gdal::ExprtkExpression>(pszExpression);
}
#endif
#if GDAL_VRT_ENABLE_MUPARSER
if (EQUAL(pszDialect, "muparser"))
{
return std::make_unique<gdal::MuParserExpression>(pszExpression);
}
#endif
CPLError(CE_Failure, CPLE_IllegalArg, "Unknown expression dialect: %s",
pszDialect);
return nullptr;
}

/*! @endcond */

} // namespace gdal
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define exprtk_disable_rtl_io_file
#define exprtk_disable_rtl_vecops
#define exprtk_disable_string_capabilities
#define exprtk_disable_advanced_features

#if defined(__GNUC__)
#pragma GCC diagnostic push
Expand Down
27 changes: 7 additions & 20 deletions frmts/vrt/vrtprocesseddatasetfunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1491,25 +1491,6 @@ class ExpressionData
{
}

static std::unique_ptr<gdal::MathExpression>
CreateExpression(std::string_view osExpression, const char *pszDialect)
{
if (EQUAL(pszDialect, "exprtk"))
{
return std::make_unique<gdal::ExprtkExpression>(osExpression);
}
else if (EQUAL(pszDialect, "muparser"))
{
return std::make_unique<gdal::MuParserExpression>(osExpression);
}
else
{
CPLError(CE_Failure, CPLE_IllegalArg,
"Unknown expression dialect: %s", pszDialect);
return nullptr;
}
}

CPLErr Compile()
{
auto eErr = m_oNominalBatchEnv.Initialize(m_osExpression, m_osDialect,
Expand Down Expand Up @@ -1605,7 +1586,9 @@ class ExpressionData
CPLErr Initialize(const CPLString &osExpression,
const CPLString &osDialect, int nBatchSize)
{
m_poExpression = CreateExpression(osExpression, osDialect.c_str());
m_poExpression =
gdal::MathExpression::Create(osExpression, osDialect.c_str());
// cppcheck-suppress knownConditionTrueFalse
if (m_poExpression == nullptr)
{
return CE_Failure;
Expand Down Expand Up @@ -1676,6 +1659,10 @@ static CPLErr ExpressionInit(const char * /*pszFuncName*/, void * /*pUserData*/,
}

const char *pszDialect = CSLFetchNameValue(papszFunctionArgs, "dialect");
if (pszDialect == nullptr)
{
pszDialect = "muparser";
}

const char *pszExpression =
CSLFetchNameValue(papszFunctionArgs, "expression");
Expand Down
Loading

0 comments on commit 0ec86a7

Please sign in to comment.