diff --git a/docs/modules/components/pages/outputs/snowflake_streaming.adoc b/docs/modules/components/pages/outputs/snowflake_streaming.adoc index 23f6a26f2c..a5a0c9f8e0 100644 --- a/docs/modules/components/pages/outputs/snowflake_streaming.adoc +++ b/docs/modules/components/pages/outputs/snowflake_streaming.adoc @@ -39,7 +39,7 @@ Common:: output: label: "" snowflake_streaming: - account: AAAAAAA-AAAAAAA # No default (required) + account: ORG-ACCOUNT # No default (required) user: "" # No default (required) role: ACCOUNTADMIN # No default (required) database: "" # No default (required) @@ -51,6 +51,17 @@ output: mapping: "" # No default (optional) init_statement: | # No default (optional) CREATE TABLE IF NOT EXISTS mytable (amount NUMBER); + schema_evolution: + enabled: false # No default (required) + new_column_type_mapping: |- + root = match this.value.type() { + this == "string" => "STRING" + this == "bytes" => "BINARY" + this == "number" => "DOUBLE" + this == "bool" => "BOOLEAN" + this == "timestamp" => "TIMESTAMP" + _ => "VARIANT" + } batching: count: 0 byte_size: 0 @@ -69,7 +80,7 @@ Advanced:: output: label: "" snowflake_streaming: - account: AAAAAAA-AAAAAAA # No default (required) + account: ORG-ACCOUNT # No default (required) user: "" # No default (required) role: ACCOUNTADMIN # No default (required) database: "" # No default (required) @@ -81,6 +92,17 @@ output: mapping: "" # No default (optional) init_statement: | # No default (optional) CREATE TABLE IF NOT EXISTS mytable (amount NUMBER); + schema_evolution: + enabled: false # No default (required) + new_column_type_mapping: |- + root = match this.value.type() { + this == "string" => "STRING" + this == "bytes" => "BINARY" + this == "number" => "DOUBLE" + this == "bool" => "BOOLEAN" + this == "timestamp" => "TIMESTAMP" + _ => "VARIANT" + } build_parallelism: 1 batching: count: 0 @@ -170,6 +192,8 @@ output: schema: "PUBLIC" table: "MYTABLE" private_key_file: "my/private/key.p8" + schema_evolution: + enabled: true ``` -- @@ -214,10 +238,7 @@ output: === `account` -Account name, which is the same as the https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#where-are-account-identifiers-used[Account Identifier^]. - However, when using an https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#using-an-account-locator-as-an-identifier[Account Locator^], - the Account Identifier is formatted as `..` and this field needs to be - populated using the `` part. +The Snowflake https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#using-an-account-locator-as-an-identifier[Account name^]. Which should be formatted as `-` where `` is the name of your Snowflake organization and `` is the unique name of your account within your organization. *Type*: `string` @@ -226,7 +247,7 @@ Account name, which is the same as the https://docs.snowflake.com/en/user-guide/ ```yml # Examples -account: AAAAAAA-AAAAAAA +account: ORG-ACCOUNT ``` === `user` @@ -336,6 +357,33 @@ init_statement: |2 ALTER TABLE t1 ADD COLUMN a2 NUMBER; ``` +=== `schema_evolution` + +Options to control schema evolution within the pipeline as new columns are added to the pipeline. + + +*Type*: `object` + + +=== `schema_evolution.enabled` + +Whether schema evolution is enabled. + + +*Type*: `bool` + + +=== `schema_evolution.new_column_type_mapping` + +The mapping function from Redpanda Connect type to column type in Snowflake. 
Overriding this can allow for customization of the datatype if there is specific information that you know about the data types in use. This mapping should result in the `root` variable being assigned a string with the data type for the new column in Snowflake. + +The input to this mapping is an object with the value and the name of the new column, for example: `{"value": 42.3, "name":"new_data_field"}" + + +*Type*: `string` + +*Default*: `"root = match this.value.type() {\n this == \"string\" =\u003e \"STRING\"\n this == \"bytes\" =\u003e \"BINARY\"\n this == \"number\" =\u003e \"DOUBLE\"\n this == \"bool\" =\u003e \"BOOLEAN\"\n this == \"timestamp\" =\u003e \"TIMESTAMP\"\n _ =\u003e \"VARIANT\"\n}"` + === `build_parallelism` The maximum amount of parallelism to use when building the output for Snowflake. The metric to watch to see if you need to change this is `snowflake_build_output_latency_ns`. diff --git a/internal/impl/snowflake/output_snowflake_streaming.go b/internal/impl/snowflake/output_snowflake_streaming.go index 9cdb34558a..df434988c3 100644 --- a/internal/impl/snowflake/output_snowflake_streaming.go +++ b/internal/impl/snowflake/output_snowflake_streaming.go @@ -11,7 +11,9 @@ package snowflake import ( "context" "crypto/rsa" + "errors" "fmt" + "regexp" "sync" "github.com/redpanda-data/benthos/v4/public/bloblang" @@ -21,20 +23,32 @@ import ( ) const ( - ssoFieldAccount = "account" - ssoFieldUser = "user" - ssoFieldRole = "role" - ssoFieldDB = "database" - ssoFieldSchema = "schema" - ssoFieldTable = "table" - ssoFieldKey = "private_key" - ssoFieldKeyFile = "private_key_file" - ssoFieldKeyPass = "private_key_pass" - ssoFieldInitStatement = "init_statement" - ssoFieldBatching = "batching" - ssoFieldChannelPrefix = "channel_prefix" - ssoFieldMapping = "mapping" - ssoFieldBuildParallelism = "build_parallelism" + ssoFieldAccount = "account" + ssoFieldUser = "user" + ssoFieldRole = "role" + ssoFieldDB = "database" + ssoFieldSchema = "schema" + ssoFieldTable = "table" + ssoFieldKey = "private_key" + ssoFieldKeyFile = "private_key_file" + ssoFieldKeyPass = "private_key_pass" + ssoFieldInitStatement = "init_statement" + ssoFieldBatching = "batching" + ssoFieldChannelPrefix = "channel_prefix" + ssoFieldMapping = "mapping" + ssoFieldBuildParallelism = "build_parallelism" + ssoFieldSchemaEvolution = "schema_evolution" + ssoFieldSchemaEvolutionEnabled = "enabled" + ssoFieldSchemaEvolutionNewColumnTypeMapping = "new_column_type_mapping" + + defaultSchemaEvolutionNewColumnMapping = `root = match this.value.type() { + this == "string" => "STRING" + this == "bytes" => "BINARY" + this == "number" => "DOUBLE" + this == "bool" => "BOOLEAN" + this == "timestamp" => "TIMESTAMP" + _ => "VARIANT" +}` ) func snowflakeStreamingOutputConfig() *service.ConfigSpec { @@ -70,11 +84,8 @@ You can monitor the output batch size using the `+"`snowflake_compressed_output_ `). Fields( service.NewStringField(ssoFieldAccount). - Description(`Account name, which is the same as the https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#where-are-account-identifiers-used[Account Identifier^]. - However, when using an https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#using-an-account-locator-as-an-identifier[Account Locator^], - the Account Identifier is formatted as `+"`..`"+` and this field needs to be - populated using the `+"``"+` part. 
-`).Example("AAAAAAA-AAAAAAA"), + Description(`The Snowflake https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#using-an-account-locator-as-an-identifier[Account name^]. Which should be formatted as `+"`-`"+` where `+"``"+` is the name of your Snowflake organization and `+"``"+` is the unique name of your account within your organization. +`).Example("ORG-ACCOUNT"), service.NewStringField(ssoFieldUser).Description("The user to run the Snowpipe Stream as. See https://docs.snowflake.com/en/user-guide/admin-user-management[Snowflake Documentation^] on how to create a user."), service.NewStringField(ssoFieldRole).Description("The role for the `user` field. The role must have the https://docs.snowflake.com/en/user-guide/data-load-snowpipe-streaming-overview#required-access-privileges[required privileges^] to call the Snowpipe Streaming APIs. See https://docs.snowflake.com/en/user-guide/admin-user-management#user-roles[Snowflake Documentation^] for more information about roles.").Example("ACCOUNTADMIN"), service.NewStringField(ssoFieldDB).Description("The Snowflake database to ingest data into."), @@ -92,6 +103,13 @@ CREATE TABLE IF NOT EXISTS mytable (amount NUMBER); ALTER TABLE t1 ALTER COLUMN c1 DROP NOT NULL; ALTER TABLE t1 ADD COLUMN a2 NUMBER; `), + service.NewObjectField(ssoFieldSchemaEvolution, + service.NewBoolField(ssoFieldSchemaEvolutionEnabled).Description("Whether schema evolution is enabled."), + service.NewBloblangField(ssoFieldSchemaEvolutionNewColumnTypeMapping).Description(` +The mapping function from Redpanda Connect type to column type in Snowflake. Overriding this can allow for customization of the datatype if there is specific information that you know about the data types in use. This mapping should result in the `+"`root`"+` variable being assigned a string with the data type for the new column in Snowflake. + +The input to this mapping is an object with the value and the name of the new column, for example: `+"`"+`{"value": 42.3, "name":"new_data_field"}`+`"`).Default(defaultSchemaEvolutionNewColumnMapping), + ).Description(`Options to control schema evolution within the pipeline as new columns are added to the pipeline.`).Optional(), service.NewIntField(ssoFieldBuildParallelism).Description("The maximum amount of parallelism to use when building the output for Snowflake. The metric to watch to see if you need to change this is `snowflake_build_output_latency_ns`.").Default(1).Advanced(), service.NewBatchPolicyField(ssoFieldBatching), service.NewOutputMaxInFlightField(), @@ -144,6 +162,8 @@ output: schema: "PUBLIC" table: "MYTABLE" private_key_file: "my/private/key.p8" + schema_evolution: + enabled: true `, ). Example( @@ -268,6 +288,17 @@ func newSnowflakeStreamer( return nil, err } } + var schemaEvolutionMapping *bloblang.Executor + if conf.Contains(ssoFieldSchemaEvolution, ssoFieldSchemaEvolutionEnabled) { + enabled, err := conf.FieldBool(ssoFieldSchemaEvolution, ssoFieldSchemaEvolutionEnabled) + if err == nil && enabled { + schemaEvolutionMapping, err = conf.FieldBloblang(ssoFieldSchemaEvolution, ssoFieldSchemaEvolutionNewColumnTypeMapping) + } + if err != nil { + return nil, err + } + } + buildParallelism, err := conf.FieldInt(ssoFieldBuildParallelism) if err != nil { return nil, err @@ -284,19 +315,14 @@ func newSnowflakeStreamer( // stream to write to a single table. 
channelPrefix = fmt.Sprintf("Redpanda_Connect_%s.%s.%s", db, schema, table) } - var initStatementsFn func(context.Context) error + var initStatementsFn func(context.Context, *streaming.SnowflakeRestClient) error if conf.Contains(ssoFieldInitStatement) { initStatements, err := conf.FieldString(ssoFieldInitStatement) if err != nil { return nil, err } - initStatementsFn = func(ctx context.Context) error { - c, err := streaming.NewRestClient(account, user, mgr.EngineVersion(), channelPrefix, rsaKey, mgr.Logger()) - if err != nil { - return err - } - defer c.Close() - _, err = c.RunSQL(ctx, streaming.RunSQLRequest{ + initStatementsFn = func(ctx context.Context, client *streaming.SnowflakeRestClient) error { + _, err = client.RunSQL(ctx, streaming.RunSQLRequest{ Statement: initStatements, // Currently we set of timeout of 30 seconds so that we don't have to handle async operations // that need polling to wait until they finish (results are made async when execution is longer @@ -313,6 +339,10 @@ func newSnowflakeStreamer( return err } } + restClient, err := streaming.NewRestClient(account, user, mgr.EngineVersion(), channelPrefix, rsaKey, mgr.Logger()) + if err != nil { + return nil, fmt.Errorf("unable to create rest API client: %w", err) + } client, err := streaming.NewSnowflakeServiceClient( context.Background(), streaming.ClientOptions{ @@ -328,40 +358,46 @@ func newSnowflakeStreamer( return nil, err } o := &snowflakeStreamerOutput{ - channelPrefix: channelPrefix, - client: client, - db: db, - schema: schema, - table: table, - mapping: mapping, - logger: mgr.Logger(), - buildTime: mgr.Metrics().NewTimer("snowflake_build_output_latency_ns"), - uploadTime: mgr.Metrics().NewTimer("snowflake_upload_latency_ns"), - convertTime: mgr.Metrics().NewTimer("snowflake_convert_latency_ns"), - serializeTime: mgr.Metrics().NewTimer("snowflake_serialize_latency_ns"), - compressedOutput: mgr.Metrics().NewCounter("snowflake_compressed_output_size_bytes"), - initStatementsFn: initStatementsFn, - buildParallelism: buildParallelism, + channelPrefix: channelPrefix, + client: client, + db: db, + schema: schema, + table: table, + role: role, + mapping: mapping, + logger: mgr.Logger(), + buildTime: mgr.Metrics().NewTimer("snowflake_build_output_latency_ns"), + uploadTime: mgr.Metrics().NewTimer("snowflake_upload_latency_ns"), + convertTime: mgr.Metrics().NewTimer("snowflake_convert_latency_ns"), + serializeTime: mgr.Metrics().NewTimer("snowflake_serialize_latency_ns"), + compressedOutput: mgr.Metrics().NewCounter("snowflake_compressed_output_size_bytes"), + initStatementsFn: initStatementsFn, + buildParallelism: buildParallelism, + schemaEvolutionMapping: schemaEvolutionMapping, + restClient: restClient, } return o, nil } type snowflakeStreamerOutput struct { - client *streaming.SnowflakeServiceClient - channelPool sync.Pool - channelCreationMu sync.Mutex - poolSize int - compressedOutput *service.MetricCounter - uploadTime *service.MetricTimer - buildTime *service.MetricTimer - convertTime *service.MetricTimer - serializeTime *service.MetricTimer - buildParallelism int - - channelPrefix, db, schema, table string - mapping *bloblang.Executor - logger *service.Logger - initStatementsFn func(context.Context) error + client *streaming.SnowflakeServiceClient + channelPool sync.Pool + channelCreationMu sync.Mutex + poolSize int + compressedOutput *service.MetricCounter + uploadTime *service.MetricTimer + buildTime *service.MetricTimer + convertTime *service.MetricTimer + serializeTime *service.MetricTimer + 
buildParallelism int + schemaEvolutionMapping *bloblang.Executor + + schemaMigrationMu sync.RWMutex + channelPrefix, db, schema, table, role string + mapping *bloblang.Executor + logger *service.Logger + initStatementsFn func(context.Context, *streaming.SnowflakeRestClient) error + restClient *streaming.SnowflakeRestClient } func (o *snowflakeStreamerOutput) openNewChannel(ctx context.Context) (*streaming.SnowflakeIngestionChannel, error) { @@ -380,19 +416,20 @@ func (o *snowflakeStreamerOutput) openNewChannel(ctx context.Context) (*streamin func (o *snowflakeStreamerOutput) openChannel(ctx context.Context, name string, id int16) (*streaming.SnowflakeIngestionChannel, error) { o.logger.Debugf("opening snowflake streaming channel: %s", name) return o.client.OpenChannel(ctx, streaming.ChannelOptions{ - ID: id, - Name: name, - DatabaseName: o.db, - SchemaName: o.schema, - TableName: o.table, - BuildParallelism: o.buildParallelism, + ID: id, + Name: name, + DatabaseName: o.db, + SchemaName: o.schema, + TableName: o.table, + BuildParallelism: o.buildParallelism, + StrictSchemaEnforcement: o.schemaEvolutionMapping != nil, }) } func (o *snowflakeStreamerOutput) Connect(ctx context.Context) error { if o.initStatementsFn != nil { - if err := o.initStatementsFn(ctx); err != nil { - return err + if err := o.initStatementsFn(ctx, o.restClient); err != nil { + return fmt.Errorf("unable to run initialization statement: %w", err) } // We've already executed our init statement, we don't need to do that anymore o.initStatementsFn = nil @@ -419,6 +456,29 @@ func (o *snowflakeStreamerOutput) WriteBatch(ctx context.Context, batch service. } batch = mapped } + var err error + // We only migrate one column at a time, so tolerate up to 10 schema + // migrations for a single batch before giving up. This protects against + // any bugs over infinitely looping. + for i := 0; i < 10; i++ { + err = o.WriteBatchInternal(ctx, batch) + if err == nil { + return nil + } + migrationErr := schemaMigrationNeededError{} + if !errors.As(err, &migrationErr) { + break + } + if err := migrationErr.migrator(ctx); err != nil { + return err + } + } + return err +} + +func (o *snowflakeStreamerOutput) WriteBatchInternal(ctx context.Context, batch service.MessageBatch) error { + o.schemaMigrationMu.RLock() + defer o.schemaMigrationMu.RUnlock() var channel *streaming.SnowflakeIngestionChannel if maybeChan := o.channelPool.Get(); maybeChan != nil { channel = maybeChan.(*streaming.SnowflakeIngestionChannel) @@ -438,6 +498,31 @@ func (o *snowflakeStreamerOutput) WriteBatch(ctx context.Context, batch service. o.convertTime.Timing(stats.ConvertTime.Nanoseconds()) o.serializeTime.Timing(stats.SerializeTime.Nanoseconds()) } else { + // Only evolve the schema if requested. + if o.schemaEvolutionMapping != nil { + nullColumnErr := streaming.NonNullColumnError{} + if errors.As(err, &nullColumnErr) { + // put the channel back so that we can reopen it along with the rest of the channels to + // pick up the new schema. + o.channelPool.Put(channel) + // Return an error so that we release our read lock and can take the write lock + // to forcibly reopen all our channels to get a new schema. 
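+				// The returned migrator closure runs after WriteBatchInternal has released the
+				// read lock: it takes the write lock, issues the ALTER TABLE statement, and then
+				// reopens every pooled channel so subsequent writes see the new table schema.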
+ return schemaMigrationNeededError{ + migrator: func(ctx context.Context) error { + return o.MigrateNotNullColumn(ctx, nullColumnErr) + }, + } + } + missingColumnErr := streaming.MissingColumnError{} + if errors.As(err, &missingColumnErr) { + o.channelPool.Put(channel) + return schemaMigrationNeededError{ + migrator: func(ctx context.Context) error { + return o.MigrateMissingColumn(ctx, missingColumnErr) + }, + } + } + } reopened, reopenErr := o.openChannel(ctx, channel.Name, channel.ID) if reopenErr == nil { o.channelPool.Put(reopened) @@ -456,6 +541,128 @@ func (o *snowflakeStreamerOutput) WriteBatch(ctx context.Context, batch service. return err } +type schemaMigrationNeededError struct { + migrator func(ctx context.Context) error +} + +func (schemaMigrationNeededError) Error() string { + return "schema migration was required and the operation needs to be retried after the migration" +} + +func (o *snowflakeStreamerOutput) MigrateMissingColumn(ctx context.Context, col streaming.MissingColumnError) error { + o.schemaMigrationMu.Lock() + defer o.schemaMigrationMu.Unlock() + msg := service.NewMessage(nil) + msg.SetStructuredMut(map[string]any{ + "name": col.RawName(), + "value": col.Value(), + }) + out, err := msg.BloblangQuery(o.schemaEvolutionMapping) + if err != nil { + return fmt.Errorf("unable to compute new column type for %s: %w", col.ColumnName(), err) + } + v, err := out.AsBytes() + if err != nil { + return fmt.Errorf("unable to extract result from new column type mapping for %s: %w", col.ColumnName(), err) + } + columnType := string(v) + if err := validateColumnType(columnType); err != nil { + return err + } + o.logger.Infof("identified new schema - attempting to alter table to add column: %s %s", col.ColumnName(), columnType) + err = o.RunSQLMigration( + ctx, + // This looks very scary and it *should*. This is prone to SQL injection attacks. The column name is + // quoted according to the rules in Snowflake's documentation. This is also why we need to + // validate the data type, so that you can't sneak an injection attack in there. + fmt.Sprintf(`ALTER TABLE IDENTIFIER(?) + ADD COLUMN IF NOT EXISTS %s %s + COMMENT 'column created by schema evolution from Redpanda Connect'`, + col.ColumnName(), + columnType, + ), + ) + if err != nil { + o.logger.Warnf("unable to add new column, this maybe due to a race with another request, error: %s", err) + } + return o.ReopenAllChannels(ctx) +} + +func (o *snowflakeStreamerOutput) MigrateNotNullColumn(ctx context.Context, col streaming.NonNullColumnError) error { + o.schemaMigrationMu.Lock() + defer o.schemaMigrationMu.Unlock() + o.logger.Infof("identified new schema - attempting to alter table to remove null constraint on column: %s", col.ColumnName()) + err := o.RunSQLMigration( + ctx, + // This looks very scary and it *should*. This is prone to SQL injection attacks. The column name here + // comes directly from the Snowflake API so it better not have a SQL injection :) + fmt.Sprintf(`ALTER TABLE IDENTIFIER(?) 
ALTER + %s DROP NOT NULL, + %s COMMENT 'column altered to be nullable by schema evolution from Redpanda Connect'`, + col.ColumnName(), + col.ColumnName(), + ), + ) + if err != nil { + o.logger.Warnf("unable to mark column %s as null, this maybe due to a race with another request, error: %s", col.ColumnName(), err) + } + return o.ReopenAllChannels(ctx) +} + +func (o *snowflakeStreamerOutput) RunSQLMigration(ctx context.Context, statement string) error { + _, err := o.restClient.RunSQL(ctx, streaming.RunSQLRequest{ + Statement: statement, + // Currently we set of timeout of 30 seconds so that we don't have to handle async operations + // that need polling to wait until they finish (results are made async when execution is longer + // than 45 seconds). + Timeout: 30, + Database: o.db, + Schema: o.schema, + Role: o.role, + Bindings: map[string]streaming.BindingValue{ + "1": {Type: "TEXT", Value: o.table}, + }, + }) + return err +} + +// ReopenAllChannels should be called while holding schemaMigrationMu so that +// all channels are actually processed +func (o *snowflakeStreamerOutput) ReopenAllChannels(ctx context.Context) error { + all := []*streaming.SnowflakeIngestionChannel{} + for { + maybeChan := o.channelPool.Get() + if maybeChan == nil { + break + } + channel := maybeChan.(*streaming.SnowflakeIngestionChannel) + reopened, reopenErr := o.openChannel(ctx, channel.Name, channel.ID) + if reopenErr == nil { + channel = reopened + } else { + o.logger.Warnf("unable to reopen channel %q schema migration: %v", channel.Name, reopenErr) + // Keep the existing channel so we don't reopen channels, but instead retry later. + } + all = append(all, channel) + } + for _, c := range all { + o.channelPool.Put(c) + } + return nil +} + func (o *snowflakeStreamerOutput) Close(ctx context.Context) error { + o.restClient.Close() return o.client.Close() } + +// This doesn't need to fully match, but be enough to prevent SQL injection as well as +// catch common errors. +var validColumnTypeRegex = regexp.MustCompile(`^\s*(?i:NUMBER|DECIMAL|NUMERIC|INT|INTEGER|BIGINT|SMALLINT|TINYINT|BYTEINT|FLOAT|FLOAT4|FLOAT8|DOUBLE|DOUBLE\s+PRECISION|REAL|VARCHAR|CHAR|CHARACTER|STRING|TEXT|BINARY|VARBINARY|BOOLEAN|DATE|DATETIME|TIME|TIMESTAMP|TIMESTAMP_LTZ|TIMESTAMP_NTZ|TIMESTAMP_TZ|VARIANT|OBJECT|ARRAY)\s*(?:\(\s*\d+\s*\)|\(\s*\d+\s*,\s*\d+\s*\))?\s*$`) + +func validateColumnType(v string) error { + if validColumnTypeRegex.MatchString(v) { + return nil + } + return fmt.Errorf("invalid Snowflake column data type: %s", v) +} diff --git a/internal/impl/snowflake/output_streaming_test.go b/internal/impl/snowflake/output_streaming_test.go new file mode 100644 index 0000000000..5a3f46c1cd --- /dev/null +++ b/internal/impl/snowflake/output_streaming_test.go @@ -0,0 +1,52 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package snowflake + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidColumnTypeRegex(t *testing.T) { + matches := []string{ + "INT", + "NUMBER", + "NUMBER ( 38, 0 )", + " NUMBER ( 38, 0 ) ", + "DOUBLE PRECISION", + "DOUBLE PRECISION", + " varchar ( 99 ) ", + " varchar ( 0 ) ", + } + for _, m := range matches { + m := m + t.Run(m, func(t *testing.T) { + require.Regexp(t, validColumnTypeRegex, m) + }) + } + nonMatches := []string{ + "VAR", + "N", + "VAR(1, 3)", + "VAR(1)", + "VARCHAR()", + "VARCHAR( )", + "GARBAGE VARCHAR(2)", + "VARCHAR(2) GARBAGE", + } + for _, m := range nonMatches { + m := m + t.Run(m, func(t *testing.T) { + require.NotRegexp(t, validColumnTypeRegex, m) + }) + } +} diff --git a/internal/impl/snowflake/streaming/compat.go b/internal/impl/snowflake/streaming/compat.go index facbb9e735..94e66a33a2 100644 --- a/internal/impl/snowflake/streaming/compat.go +++ b/internal/impl/snowflake/streaming/compat.go @@ -150,6 +150,27 @@ func normalizeColumnName(name string) string { return strings.ToUpper(strings.ReplaceAll(name, `\ `, ` `)) } +// quoteColumnName escapes an object identifier according to the +// rules in Snowflake. +// +// https://docs.snowflake.com/en/sql-reference/identifiers-syntax +func quoteColumnName(name string) string { + var quoted strings.Builder + // Default to assume we're just going to add quotes and there won't + // be any double quotes inside the string that needs escaped. + quoted.Grow(len(name) + 2) + quoted.WriteByte('"') + for _, r := range strings.ToUpper(name) { + if r == '"' { + quoted.WriteString(`""`) + } else { + quoted.WriteRune(r) + } + } + quoted.WriteByte('"') + return quoted.String() +} + // snowflakeTimestampInt computes the same result as the logic in TimestampWrapper // in the Java SDK. It converts a timestamp to the integer representation that // is used internally within Snowflake. diff --git a/internal/impl/snowflake/streaming/compat_test.go b/internal/impl/snowflake/streaming/compat_test.go index 4b4aaffa64..98b17e47a1 100644 --- a/internal/impl/snowflake/streaming/compat_test.go +++ b/internal/impl/snowflake/streaming/compat_test.go @@ -127,6 +127,17 @@ func TestColumnNormalization(t *testing.T) { require.Equal(t, `foo" bar "baz`, normalizeColumnName(`"foo"" bar ""baz"`)) } +func TestColumnQuoting(t *testing.T) { + require.Equal(t, `""`, quoteColumnName("")) + require.Equal(t, `"FOO"`, quoteColumnName("foo")) + require.Equal(t, `"""BAR"""`, quoteColumnName(`"bar"`)) + require.Equal(t, `"FOO BAR"`, quoteColumnName(`foo bar`)) + require.Equal(t, `"FOO\ BAR"`, quoteColumnName(`foo\ bar`)) + require.Equal(t, `"FOO""BAR"`, quoteColumnName(`foo"bar`)) + require.Equal(t, `"FOO""BAR1"`, quoteColumnName(`foo"bar1`)) + require.Equal(t, `""""""""""`, quoteColumnName(`""""`)) +} + func TestSnowflakeTimestamp(t *testing.T) { type TestCase struct { timestamp string diff --git a/internal/impl/snowflake/streaming/parquet.go b/internal/impl/snowflake/streaming/parquet.go index 873ac9f02c..494ba00618 100644 --- a/internal/impl/snowflake/streaming/parquet.go +++ b/internal/impl/snowflake/streaming/parquet.go @@ -13,6 +13,7 @@ package streaming import ( "bytes" "encoding/binary" + "errors" "fmt" "github.com/parquet-go/parquet-go" @@ -24,7 +25,7 @@ import ( // messageToRow converts a message into columnar form using the provided name to index mapping. 
// We have to materialize the column into a row so that we can know if a column is null - the // msg can be sparse, but the row must not be sparse. -func messageToRow(msg *service.Message, out []any, nameToPosition map[string]int) error { +func messageToRow(msg *service.Message, out []any, nameToPosition map[string]int, allowExtraProperties bool) error { v, err := msg.AsStructured() if err != nil { return fmt.Errorf("error extracting object from message: %w", err) @@ -36,8 +37,9 @@ func messageToRow(msg *service.Message, out []any, nameToPosition map[string]int for k, v := range row { idx, ok := nameToPosition[normalizeColumnName(k)] if !ok { - // TODO(schema): Unknown column, we just skip it. - // In the future we may evolve the schema based on the new data. + if !allowExtraProperties && v != nil { + return MissingColumnError{columnName: k, val: v} + } continue } out[idx] = v @@ -49,6 +51,7 @@ func constructRowGroup( batch service.MessageBatch, schema *parquet.Schema, transformers []*dataTransformer, + allowExtraProperties bool, ) ([]parquet.Row, []*statsBuffer, error) { // We write all of our data in a columnar fashion, but need to pivot that data so that we can feed it into // out parquet library (which sadly will redo the pivot - maybe we need a lower level abstraction...). @@ -76,7 +79,7 @@ func constructRowGroup( // is needed row := make([]any, rowWidth) for _, msg := range batch { - err := messageToRow(msg, row, nameToPosition) + err := messageToRow(msg, row, nameToPosition, allowExtraProperties) if err != nil { return nil, nil, err } @@ -86,7 +89,11 @@ func constructRowGroup( b := buffers[i] err = t.converter.ValidateAndConvert(s, v, b) if err != nil { - // TODO(schema): if this is a null value err then we can evolve the schema to mark it null. + if errors.Is(err, errNullValue) { + return nil, nil, NonNullColumnError{t.column.Name} + } + // There is not special typed error for a validation error, there really isn't + // anything we can do about it. return nil, nil, fmt.Errorf("invalid data for column %s: %w", t.name, err) } // reset the column as nil for the next row diff --git a/internal/impl/snowflake/streaming/parquet_test.go b/internal/impl/snowflake/streaming/parquet_test.go index 27c2581345..f4e2ecb47d 100644 --- a/internal/impl/snowflake/streaming/parquet_test.go +++ b/internal/impl/snowflake/streaming/parquet_test.go @@ -60,6 +60,7 @@ func TestWriteParquet(t *testing.T) { batch, schema, transformers, + false, ) require.NoError(t, err) b, err := writeParquetFile("latest", parquetFileData{ diff --git a/internal/impl/snowflake/streaming/rest.go b/internal/impl/snowflake/streaming/rest.go index 737d64f00f..bca68015e8 100644 --- a/internal/impl/snowflake/streaming/rest.go +++ b/internal/impl/snowflake/streaming/rest.go @@ -249,14 +249,22 @@ type ( Message string `json:"message"` Blobs []blobRegisterStatus `json:"blobs"` } + // BindingValue is a value available as a binding variable in a SQL statement. 
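+	// Bind variables are referenced positionally in the statement text (for example
+	// `IDENTIFIER(?)`), and the keys of the bindings map are "1", "2", and so on.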
+	BindingValue struct {
+		// The binding data type, generally TEXT is what you want
+		// see: https://docs.snowflake.com/en/developer-guide/sql-api/submitting-requests#using-bind-variables-in-a-statement
+		Type  string `json:"type"`
+		Value string `json:"value"`
+	}
 	// RunSQLRequest is the way to run a SQL statement
 	RunSQLRequest struct {
-		Statement string `json:"statement"`
-		Timeout   int64  `json:"timeout"`
-		Database  string `json:"database,omitempty"`
-		Schema    string `json:"schema,omitempty"`
-		Warehouse string `json:"warehouse,omitempty"`
-		Role      string `json:"role,omitempty"`
+		Statement string                  `json:"statement"`
+		Timeout   int64                   `json:"timeout"`
+		Database  string                  `json:"database,omitempty"`
+		Schema    string                  `json:"schema,omitempty"`
+		Warehouse string                  `json:"warehouse,omitempty"`
+		Role      string                  `json:"role,omitempty"`
+		Bindings  map[string]BindingValue `json:"bindings,omitempty"`
 		// https://docs.snowflake.com/en/sql-reference/parameters
 		Parameters map[string]string `json:"parameters,omitempty"`
 	}
diff --git a/internal/impl/snowflake/streaming/schema_errors.go b/internal/impl/snowflake/streaming/schema_errors.go
new file mode 100644
index 0000000000..cd594b0e8b
--- /dev/null
+++ b/internal/impl/snowflake/streaming/schema_errors.go
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2024 Redpanda Data, Inc.
+ *
+ * Licensed as a Redpanda Enterprise file under the Redpanda Community
+ * License (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md
+ */
+
+package streaming
+
+import "fmt"
+
+// SchemaMismatchError occurs when user-provided data does not match the schema
+// *and* the table can be evolved to accommodate it.
+//
+// This can be used as a mechanism to evolve the schema dynamically.
+type SchemaMismatchError interface {
+	ColumnName() string
+	Value() any
+}
+
+var _ error = NonNullColumnError{}
+var _ SchemaMismatchError = NonNullColumnError{}
+
+// NonNullColumnError occurs when a column with a NOT NULL constraint
+// receives a `NULL` value.
+type NonNullColumnError struct {
+	columnName string
+}
+
+// ColumnName returns the column name with the NOT NULL constraint
+func (e NonNullColumnError) ColumnName() string {
+	// This name comes directly from the Snowflake API so I hope this is properly quoted...
+	return e.columnName
+}
+
+// Value returns nil
+func (e NonNullColumnError) Value() any {
+	return nil
+}
+
+// Error implements the error interface
+func (e NonNullColumnError) Error() string {
+	return fmt.Sprintf("column %q has a NOT NULL constraint and received a nil value", e.columnName)
+}
+
+var _ error = MissingColumnError{}
+var _ SchemaMismatchError = MissingColumnError{}
+
+// MissingColumnError occurs when a column that is not in the table is
+// found on a record
+type MissingColumnError struct {
+	columnName string
+	val        any
+}
+
+// ColumnName returns the column name of the data that was not in the table
+//
+// NOTE this is escaped, so it's valid to use this directly in a SQL statement
+// but I wish that Snowflake would just allow `identifier` for ALTER column.
+func (e MissingColumnError) ColumnName() string {
+	return quoteColumnName(e.columnName)
+}
+
+// RawName is the unquoted name of the new column - DO NOT USE IN SQL!
+// This is the more intuitive name for users in the mapping function
+func (e MissingColumnError) RawName() string {
+	return e.columnName
+}
+
+// Value returns the value that was associated with the missing column
+func (e MissingColumnError) Value() any {
+	return e.val
+}
+
+// Error implements the error interface
+func (e MissingColumnError) Error() string {
+	return fmt.Sprintf("new data %+v with the name %q does not have an associated column", e.val, e.columnName)
+}
diff --git a/internal/impl/snowflake/streaming/streaming.go b/internal/impl/snowflake/streaming/streaming.go
index 4f22879172..5c68aa6740 100644
--- a/internal/impl/snowflake/streaming/streaming.go
+++ b/internal/impl/snowflake/streaming/streaming.go
@@ -145,6 +145,8 @@ type ChannelOptions struct {
 	TableName string
 	// The max parallelism used to build parquet files and convert message batches into rows.
 	BuildParallelism int
+	// If true, raise an error on columns in the data that are not in the table instead of ignoring them.
+	StrictSchemaEnforcement bool
 }
 
 type encryptionInfo struct {
@@ -304,7 +306,7 @@ func (c *SnowflakeIngestionChannel) constructBdecPart(batch service.MessageBatch
 		rowGroups = append(rowGroups, rowGroup{})
 		chunk := batch[i : i+end]
 		wg.Go(func() error {
-			rows, stats, err := constructRowGroup(chunk, c.schema, c.transformers)
+			rows, stats, err := constructRowGroup(chunk, c.schema, c.transformers, !c.StrictSchemaEnforcement)
 			rowGroups[j] = rowGroup{rows, stats}
 			return err
 		})
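For reference, a minimal sketch of a config using the new `schema_evolution` fields with a custom `new_column_type_mapping`. The connection values (account, user, database, key path) are placeholders, and only the `number` branch differs from the default mapping shown in the docs above:

```yml
output:
  snowflake_streaming:
    account: ORG-ACCOUNT
    user: MYUSER
    role: ACCOUNTADMIN
    database: MYDATABASE
    schema: PUBLIC
    table: MYTABLE
    private_key_file: my/private/key.p8
    schema_evolution:
      enabled: true
      # Assumption for illustration: new numeric columns become fixed-point NUMBER(38, 2)
      # rather than the default DOUBLE.
      new_column_type_mapping: |-
        root = match this.value.type() {
          this == "string" => "STRING"
          this == "bytes" => "BINARY"
          this == "number" => "NUMBER(38, 2)"
          this == "bool" => "BOOLEAN"
          this == "timestamp" => "TIMESTAMP"
          _ => "VARIANT"
        }
```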