From 6136116e9d12d644a7c3cec9a5b061e67c602aed Mon Sep 17 00:00:00 2001 From: Michael Maltese Date: Wed, 17 Jul 2024 16:33:38 -0400 Subject: [PATCH] sql: support query strings containing multiple statements --- internal/impl/sql/conn_fields.go | 4 +- internal/impl/sql/input_sql_raw.go | 2 +- internal/impl/sql/multi_statement.go | 136 ++++++++++++++++++++++ internal/impl/sql/multi_statement_test.go | 84 +++++++++++++ internal/impl/sql/output_sql_raw.go | 2 +- internal/impl/sql/processor_sql_raw.go | 4 +- 6 files changed, 226 insertions(+), 6 deletions(-) create mode 100644 internal/impl/sql/multi_statement.go create mode 100644 internal/impl/sql/multi_statement_test.go diff --git a/internal/impl/sql/conn_fields.go b/internal/impl/sql/conn_fields.go index 9c46d92015..1f02392412 100644 --- a/internal/impl/sql/conn_fields.go +++ b/internal/impl/sql/conn_fields.go @@ -170,14 +170,14 @@ func (c *connSettings) apply(ctx context.Context, db *sql.DB, log *service.Logge c.initOnce.Do(func() { for _, fileStmt := range c.initFileStatements { - if _, err := db.ExecContext(ctx, fileStmt[1]); err != nil { + if err := execMultiWithContext(db, ctx, fileStmt[1]); err != nil { log.Warnf("Failed to execute init_file '%v': %v", fileStmt[0], err) } else { log.Debugf("Successfully ran init_file '%v'", fileStmt[0]) } } if c.initStatement != "" { - if _, err := db.ExecContext(ctx, c.initStatement); err != nil { + if err := execMultiWithContext(db, ctx, c.initStatement); err != nil { log.Warnf("Failed to execute init_statement: %v", err) } else { log.Debug("Successfully ran init_statement") diff --git a/internal/impl/sql/input_sql_raw.go b/internal/impl/sql/input_sql_raw.go index fff4663c61..519269132f 100644 --- a/internal/impl/sql/input_sql_raw.go +++ b/internal/impl/sql/input_sql_raw.go @@ -172,7 +172,7 @@ func (s *sqlRawInput) Connect(ctx context.Context) (err error) { } var rows *sql.Rows - if rows, err = db.QueryContext(ctx, s.queryStatic, args...); err != nil { + if rows, err = queryMultiWithContext(db, ctx, s.queryStatic, args...); err != nil { return } else if err = rows.Err(); err != nil { s.logger.With("err", err).Warnf("unexpected error while execute raw query %q", s.queryStatic) diff --git a/internal/impl/sql/multi_statement.go b/internal/impl/sql/multi_statement.go new file mode 100644 index 0000000000..e27706f2b3 --- /dev/null +++ b/internal/impl/sql/multi_statement.go @@ -0,0 +1,136 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "context" + "database/sql" + "strings" +) + +func splitSQLStatements(statement string) []string { + var result []string + startp := 0 + p := 0 + sawNonCommentOrSpace := false + for { + if p == len(statement) || statement[p] == ';' { + if p != len(statement) && statement[p] == ';' { + // include trailing semicolon + p++ + } + statementPart := statement[startp:p] + if sawNonCommentOrSpace { + result = append(result, strings.TrimSpace(statementPart)) + } else { + // coalesce any functionally "empty" statements into the previous statement + // so any configurations that have something like "statement; -- final comment" + // will still work + result[len(result)-1] += statementPart + } + if p == len(statement) { + break + } + startp = p + sawNonCommentOrSpace = false + } else if statement[p] == '\'' || statement[p] == '"' || statement[p] == '`' { + // single-quoted strings, double-quoted identifiers, and backtick-quoted identifiers + sentinel := statement[p] + p++ + for p < len(statement) && statement[p] != sentinel { + p++ + } + sawNonCommentOrSpace = true + } else if statement[p] == '#' || + (p+1 < len(statement) && statement[p:p+2] == "--") || + (p+1 < len(statement) && statement[p:p+2] == "//") { + // single-line comments starting with hash, double-dash, or double-slash + for p < len(statement) && statement[p] != '\n' { + p++ + } + } else if p+1 < len(statement) && statement[p:p+2] == "/*" { + // multi-line comments starting with slash-asterisk + for p+1 < len(statement) && statement[p:p+2] != "*/" { + p++ + } + } else if !(statement[p] == ' ' || statement[p] == '\t' || statement[p] == '\r' || statement[p] == '\n') { + sawNonCommentOrSpace = true + } + if p != len(statement) { + p++ + } + } + + return result +} + +func execMultiWithContext(db *sql.DB, ctx context.Context, query string, args ...any) error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + _ = tx.Rollback() + }() + + statements := splitSQLStatements(query) + for _, part := range statements { + if _, err = tx.ExecContext(ctx, part, args...); err != nil { + return err + } + args = []any{} + } + + if err = tx.Commit(); err != nil { + return err + } + + // TODO: should this return anything for a result? + return nil +} + +func queryMultiWithContext(db *sql.DB, ctx context.Context, query string, args ...any) (*sql.Rows, error) { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer func() { + _ = tx.Rollback() + }() + + statements := splitSQLStatements(query) + var rows *sql.Rows + for i, part := range statements { + // this may not be useful to only give the args to the first query. but, principle of least surprise, + // make it act the same way that execMultiWithContext and the various drivers do. + if i < len(statements)-1 { + if _, err = tx.ExecContext(ctx, part, args...); err != nil { + return nil, err + } + } else { + rows, err = tx.QueryContext(ctx, part, args...) + if err != nil { + return nil, err + } + } + args = []any{} + } + + if err = tx.Commit(); err != nil { + return nil, err + } + + return rows, nil +} diff --git a/internal/impl/sql/multi_statement_test.go b/internal/impl/sql/multi_statement_test.go new file mode 100644 index 0000000000..9b58e6bc72 --- /dev/null +++ b/internal/impl/sql/multi_statement_test.go @@ -0,0 +1,84 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func assertSplitEquals(t *testing.T, message string, statement string, wanted []string) { + result := splitSQLStatements(statement) + assert.Equal(t, wanted, result, message) +} + +func TestSplitStatements(t *testing.T) { + assertSplitEquals(t, "no semicolon", "select null", []string{"select null"}) + + assertSplitEquals(t, "basic semicolon", "select 1; select 2", []string{"select 1;", "select 2"}) + + assertSplitEquals(t, "semicolon in single-quoted string", + "select 'singlequoted;string'; select null", + []string{"select 'singlequoted;string';", "select null"}) + + assertSplitEquals(t, "semicolon in double-quoted identifier", + "select \"doublequoted;ident\"; select null", + []string{"select \"doublequoted;ident\";", "select null"}) + + assertSplitEquals(t, "semicolon in backtick-quoted identifier", + "select `backtick;ident`; select null", + []string{"select `backtick;ident`;", "select null"}) + + assertSplitEquals(t, "semicolon in hash-comment", ` + select #hash;comment + 1; select 2 + `, []string{"select #hash;comment\n\t\t1;", "select 2"}) + + assertSplitEquals(t, "semicolon in double-dash comment", ` + select --double-dash;comment + 1; select 2 + `, []string{"select --double-dash;comment\n\t\t1;", "select 2"}) + + assertSplitEquals(t, "semicolon in double-slash comment", ` + select //double-slash;comment + 1; select 2 + `, []string{"select //double-slash;comment\n\t\t1;", "select 2"}) + + assertSplitEquals(t, "semicolon in multi-line comment", ` + select /*multi; + line;comment*/ + 1; select 2 + `, []string{"select /*multi;\n\t\tline;comment*/\n\t\t1;", "select 2"}) + + assertSplitEquals(t, "semicolon at end should be single statement", + "select null;", + []string{"select null;"}) + + assertSplitEquals(t, "comment with no newline should not fail", + "select null // comment with no newline", + []string{"select null // comment with no newline"}) + + assertSplitEquals(t, "semicolon followed by comment at end should be single statement", + "select null; // trailing comment", + []string{"select null; // trailing comment"}) + + assertSplitEquals(t, "coalesce empty statements into previous but not nonempty statements", + `select 1; // comment + ; + select 2;`, + []string{"select 1; // comment\n\t\t;", "select 2;"}) + +} diff --git a/internal/impl/sql/output_sql_raw.go b/internal/impl/sql/output_sql_raw.go index 1303d7152b..89cac3f6b7 100644 --- a/internal/impl/sql/output_sql_raw.go +++ b/internal/impl/sql/output_sql_raw.go @@ -234,7 +234,7 @@ func (s *sqlRawOutput) WriteBatch(ctx context.Context, batch service.MessageBatc } } - if _, err := s.db.ExecContext(ctx, queryStr, args...); err != nil { + if err := execMultiWithContext(s.db, ctx, queryStr, args...); err != nil { return err } } diff --git a/internal/impl/sql/processor_sql_raw.go b/internal/impl/sql/processor_sql_raw.go index 296fda5aee..3a7640fcf2 100644 --- a/internal/impl/sql/processor_sql_raw.go +++ b/internal/impl/sql/processor_sql_raw.go @@ -244,13 +244,13 @@ func (s *sqlRawProcessor) ProcessBatch(ctx context.Context, batch service.Messag } if s.onlyExec { - if _, err := s.db.ExecContext(ctx, queryStr, args...); err != nil { + if err := execMultiWithContext(s.db, ctx, queryStr, args...); err != nil { s.logger.Debugf("Failed to run query: %v", err) msg.SetError(err) continue } } else { - rows, err := s.db.QueryContext(ctx, queryStr, args...) + rows, err := queryMultiWithContext(s.db, ctx, queryStr, args...) if err != nil { s.logger.Debugf("Failed to run query: %v", err) msg.SetError(err)