Skip to content

Commit

Permalink
Merge pull request #3089 from redpanda-data/pgcdc
Browse files Browse the repository at this point in the history
pgcdc: fix null value handling
  • Loading branch information
rockwotj authored Dec 20, 2024
2 parents 013da5f + 6c959c6 commit 2e31344
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 27 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ All notable changes to this project will be documented in this file.
### Fixed

- The `code` and `file` fields on the `javascript` processor docs no longer erroneously mention interpolation support. (@mihaitodor)
- The `postgres_cdc` input now correctly handles `null` values. (@rockwotj)

## 4.44.0 - 2024-12-13

Expand Down
24 changes: 13 additions & 11 deletions internal/impl/postgresql/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,9 @@ func TestIntegrationPgCDCForPgOutputStreamComplexTypesPlugin(t *testing.T) {
);`)
require.NoError(t, err)

_, err = db.Exec(`INSERT INTO complex_types_example (json_data) VALUES ('{"nested":null}'::jsonb);`)
require.NoError(t, err)

databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1])
template := fmt.Sprintf(`
pg_stream:
Expand All @@ -557,7 +560,7 @@ file:
`, tmpDir)

streamOutBuilder := service.NewStreamBuilder()
require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`))
require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: TRACE`))
require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf))
require.NoError(t, streamOutBuilder.AddInputYAML(template))

Expand Down Expand Up @@ -585,29 +588,28 @@ file:
require.Eventually(t, func() bool {
outBatchMut.Lock()
defer outBatchMut.Unlock()
return len(outBatches) == 1
return len(outBatches) == 2
}, time.Second*25, time.Millisecond*100)

messageWithComplexTypes := outBatches[0]

// Produce a change to a non-complex column to trigger replication and receive the updated row,
// so that the complex types can be checked again after passing through replication, ensuring consistency.
_, err = db.Exec("UPDATE complex_types_example SET id = 2 WHERE id = 1")
_, err = db.Exec("UPDATE complex_types_example SET id = 3 WHERE id = 1")
require.NoError(t, err)
_, err = db.Exec("UPDATE complex_types_example SET id = 4 WHERE id = 2")
require.NoError(t, err)

assert.Eventually(t, func() bool {
outBatchMut.Lock()
defer outBatchMut.Unlock()
return len(outBatches) == 2
return len(outBatches) == 4
}, time.Second*25, time.Millisecond*100)

// replacing update with insert to remove replication messages type differences
// so we will be checking only the data
lastMessage := outBatches[len(outBatches)-1]
lastMessage = strings.Replace(lastMessage, "update", "insert", 1)
messageWithComplexTypes = strings.Replace(messageWithComplexTypes, "\"table_snapshot_progress\":0,", "", 1)

require.Equal(t, messageWithComplexTypes, strings.Replace(lastMessage, ":2", ":1", 1))
require.JSONEq(t, `{"id":1, "int_array":[1, 2, 3, 4, 5], "ip_addr":"192.168.1.1/32", "json_data":{"name":"test", "value":42}, "location": "(45.5,-122.6)", "search_text":"'brown':3 'dog':9 'fox':4 'jump':5 'lazi':8 'quick':2", "tags":["tag1", "tag2", "tag3"], "time_range": "[2024-01-01 00:00:00,2024-12-31 00:00:00)", "uuid_col":"a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"}`, outBatches[0])
require.JSONEq(t, `{"id":2, "int_array":null, "ip_addr":null, "json_data":{"nested":null}, "location":null, "search_text":null, "tags":null, "time_range":null, "uuid_col":null}`, outBatches[1])
require.JSONEq(t, `{"id":3, "int_array":[1, 2, 3, 4, 5], "ip_addr":"192.168.1.1/32", "json_data":{"name":"test", "value":42}, "location": "(45.5,-122.6)", "search_text":"'brown':3 'dog':9 'fox':4 'jump':5 'lazi':8 'quick':2", "tags":["tag1", "tag2", "tag3"], "time_range": "[2024-01-01 00:00:00,2024-12-31 00:00:00)", "uuid_col":"a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"}`, outBatches[2])
require.JSONEq(t, `{"id":4, "int_array":null, "ip_addr":null, "json_data":{"nested":null}, "location":null, "search_text":null, "tags":null, "time_range":null, "uuid_col":null}`, outBatches[3])

require.NoError(t, streamOut.StopWithin(time.Second*10))
}
Expand Down
2 changes: 1 addition & 1 deletion internal/impl/postgresql/pglogicalstream/logical_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ func (s *Stream) processSnapshot() error {
col := columnNames[i]
var val any
if val, err = getter(scanArgs[i]); err != nil {
return err
return fmt.Errorf("unable to decode column %s: %w", col, err)
}
data[col] = val
normalized := sanitize.QuotePostgresIdentifier(col)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ func decodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeM
return nil, fmt.Errorf("unable to decode column data: %w", err)
}
values[colName] = val
default:
return nil, fmt.Errorf("unable to decode column data, unknown data type: %d", col.DataType)
}
}
message.Data = values
Expand Down Expand Up @@ -149,7 +151,10 @@ func decodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeM
return message, nil
}

func decodeTextColumnData(mi *pgtype.Map, data []byte, dataType uint32) (interface{}, error) {
func decodeTextColumnData(mi *pgtype.Map, data []byte, dataType uint32) (any, error) {
if data == nil {
return nil, nil
}
if dt, ok := mi.TypeForOID(dataType); ok {
val, err := dt.Codec.DecodeValue(mi, dataType, pgtype.TextFormatCode, data)
if err != nil {
Expand Down Expand Up @@ -177,6 +182,5 @@ func decodeTextColumnData(mi *pgtype.Map, data []byte, dataType uint32) (interfa

return val, err
}

return string(data), nil
}
66 changes: 53 additions & 13 deletions internal/impl/postgresql/pglogicalstream/snapshotter.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,17 +152,39 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) (
switch v.DatabaseTypeName() {
case "VARCHAR", "TEXT", "UUID", "TIMESTAMP":
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v any) (any, error) { return v.(*sql.NullString).String, nil }
valueGetters[i] = func(v any) (any, error) {
str := v.(*sql.NullString)
if !str.Valid {
return nil, nil
}
return str.String, nil
}
case "BOOL":
scanArgs[i] = new(sql.NullBool)
valueGetters[i] = func(v any) (any, error) { return v.(*sql.NullBool).Bool, nil }
valueGetters[i] = func(v any) (any, error) {
val := v.(*sql.NullBool)
if !val.Valid {
return nil, nil
}
return val.Bool, nil
}
case "INT4":
scanArgs[i] = new(sql.NullInt64)
valueGetters[i] = func(v any) (any, error) { return v.(*sql.NullInt64).Int64, nil }
valueGetters[i] = func(v any) (any, error) {
val := v.(*sql.NullInt64)
if !val.Valid {
return nil, nil
}
return val.Int64, nil
}
case "JSONB":
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v any) (any, error) {
payload := v.(*sql.NullString).String
str := v.(*sql.NullString)
if !str.Valid {
return nil, nil
}
payload := str.String
if payload == "" {
return payload, nil
}
Expand All @@ -177,8 +199,11 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) (
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v any) (any, error) {
inet := pgtype.Inet{}
val := v.(*sql.NullString).String
if err := inet.Scan(val); err != nil {
val := v.(*sql.NullString)
if !val.Valid {
return nil, nil
}
if err := inet.Scan(val.String); err != nil {
return nil, err
}

Expand All @@ -188,8 +213,11 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) (
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v any) (any, error) {
newArray := pgtype.Tsrange{}
val := v.(*sql.NullString).String
if err := newArray.Scan(val); err != nil {
val := v.(*sql.NullString)
if !val.Valid {
return nil, nil
}
if err := newArray.Scan(val.String); err != nil {
return nil, err
}

Expand All @@ -200,8 +228,11 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) (
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v any) (any, error) {
newArray := pgtype.Int4Array{}
val := v.(*sql.NullString).String
if err := newArray.Scan(val); err != nil {
val := v.(*sql.NullString)
if !val.Valid {
return nil, nil
}
if err := newArray.Scan(val.String); err != nil {
return nil, err
}

Expand All @@ -211,16 +242,25 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) (
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v any) (any, error) {
newArray := pgtype.TextArray{}
val := v.(*sql.NullString).String
if err := newArray.Scan(val); err != nil {
val := v.(*sql.NullString)
if !val.Valid {
return nil, nil
}
if err := newArray.Scan(val.String); err != nil {
return nil, err
}

return newArray.Elements, nil
}
default:
scanArgs[i] = new(sql.NullString)
valueGetters[i] = func(v any) (any, error) { return v.(*sql.NullString).String, nil }
valueGetters[i] = func(v any) (any, error) {
val := v.(*sql.NullString)
if !val.Valid {
return nil, nil
}
return val.String, nil
}
}
}

Expand Down

0 comments on commit 2e31344

Please sign in to comment.