Skip to content

Commit

Permalink
Implement enum8 bounds check with bitset
Browse files Browse the repository at this point in the history
  • Loading branch information
mdonkers committed Jan 9, 2025
1 parent 92a353c commit bb51f19
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 26 deletions.
7 changes: 2 additions & 5 deletions lib/column/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@ import (
"bytes"
"errors"
"math"
"slices"
"strconv"

"github.com/ClickHouse/ch-go/proto"
"golang.org/x/exp/maps"
)

func Enum(chType Type, name string) (Interface, error) {
Expand All @@ -48,10 +46,9 @@ func Enum(chType Type, name string) (Interface, error) {
v := int8(indexes[i])
enum.iv[values[i]] = proto.Enum8(v)
enum.vi[proto.Enum8(v)] = values[i]

enum.enumValuesBitset[uint8(v)>>6] |= 1 << (v & 63)
}
enum.minEnum = int8(slices.Min(maps.Keys(enum.vi)))
enum.maxEnum = int8(slices.Max(maps.Keys(enum.vi)))
enum.continuous = (enum.maxEnum-enum.minEnum)+1 == int8(len(enum.vi))
return &enum, nil
}
enum := Enum16{
Expand Down
32 changes: 11 additions & 21 deletions lib/column/enum8.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ type Enum8 struct {
name string
col proto.ColEnum8

continuous bool
minEnum int8
maxEnum int8
// Bitset of the enum values that have been specified by the user.
// Used when appending rows, to validate that the enum value is valid.
enumValuesBitset [4]uint64
}

func (col *Enum8) Reset() {
Expand Down Expand Up @@ -183,39 +183,29 @@ func (col *Enum8) Append(v any) (nulls []uint8, err error) {
func (col *Enum8) AppendRow(elem any) error {
switch elem := elem.(type) {
case int8:
if col.continuous && elem >= col.minEnum && elem <= col.maxEnum {
col.col.Append(proto.Enum8(elem))
} else {
return col.AppendRow(int(elem))
}
return col.AppendRow(int(elem))
case *int8:
if col.continuous && *elem >= col.minEnum && *elem <= col.maxEnum {
col.col.Append(proto.Enum8(*elem))
} else {
return col.AppendRow(int(*elem))
}
return col.AppendRow(int(*elem))
case int:
v := proto.Enum8(elem)
_, ok := col.vi[v]
if !ok {
// Check if the enum value is defined
if col.enumValuesBitset[uint8(elem)>>6]&(1<<(elem&63)) == 0 {
return &Error{
Err: fmt.Errorf("unknown element %v", elem),
ColumnType: string(col.chType),
}
}
col.col.Append(v)
col.col.Append(proto.Enum8(elem))
case *int:
switch {
case elem != nil:
v := proto.Enum8(*elem)
_, ok := col.vi[v]
if !ok {
// Check if the enum value is defined
if col.enumValuesBitset[uint8(*elem)>>6]&(1<<(*elem&63)) == 0 {
return &Error{
Err: fmt.Errorf("unknown element %v", *elem),
ColumnType: string(col.chType),
}
}
col.col.Append(v)
col.col.Append(proto.Enum8(*elem))
default:
col.col.Append(0)
}
Expand Down
50 changes: 50 additions & 0 deletions lib/column/enum_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package column

import (
"slices"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -155,3 +156,52 @@ func TestExtractEnumNamedValues(t *testing.T) {
})
}
}

func TestEnumValuesBoundsChecks(t *testing.T) {
tests := []struct {
name string
enumType string
validEnums []int
}{
{
name: "Simple enum range",
enumType: "Enum8('-2'=-2,'-1'=-1,'0'=0,'1'=1,'2'=2)",
validEnums: createValidEnumsRange(-2, 2),
},
{
name: "Full enum range",
enumType: "Enum8('-128'=-128,'-127'=-127,'-126'=-126,'-125'=-125,'-124'=-124,'-123'=-123,'-122'=-122,'-121'=-121,'-120'=-120,'-119'=-119,'-118'=-118,'-117'=-117,'-116'=-116,'-115'=-115,'-114'=-114,'-113'=-113,'-112'=-112,'-111'=-111,'-110'=-110,'-109'=-109,'-108'=-108,'-107'=-107,'-106'=-106,'-105'=-105,'-104'=-104,'-103'=-103,'-102'=-102,'-101'=-101,'-100'=-100,'-99'=-99,'-98'=-98,'-97'=-97,'-96'=-96,'-95'=-95,'-94'=-94,'-93'=-93,'-92'=-92,'-91'=-91,'-90'=-90,'-89'=-89,'-88'=-88,'-87'=-87,'-86'=-86,'-85'=-85,'-84'=-84,'-83'=-83,'-82'=-82,'-81'=-81,'-80'=-80,'-79'=-79,'-78'=-78,'-77'=-77,'-76'=-76,'-75'=-75,'-74'=-74,'-73'=-73,'-72'=-72,'-71'=-71,'-70'=-70,'-69'=-69,'-68'=-68,'-67'=-67,'-66'=-66,'-65'=-65,'-64'=-64,'-63'=-63,'-62'=-62,'-61'=-61,'-60'=-60,'-59'=-59,'-58'=-58,'-57'=-57,'-56'=-56,'-55'=-55,'-54'=-54,'-53'=-53,'-52'=-52,'-51'=-51,'-50'=-50,'-49'=-49,'-48'=-48,'-47'=-47,'-46'=-46,'-45'=-45,'-44'=-44,'-43'=-43,'-42'=-42,'-41'=-41,'-40'=-40,'-39'=-39,'-38'=-38,'-37'=-37,'-36'=-36,'-35'=-35,'-34'=-34,'-33'=-33,'-32'=-32,'-31'=-31,'-30'=-30,'-29'=-29,'-28'=-28,'-27'=-27,'-26'=-26,'-25'=-25,'-24'=-24,'-23'=-23,'-22'=-22,'-21'=-21,'-20'=-20,'-19'=-19,'-18'=-18,'-17'=-17,'-16'=-16,'-15'=-15,'-14'=-14,'-13'=-13,'-12'=-12,'-11'=-11,'-10'=-10,'-9'=-9,'-8'=-8,'-7'=-7,'-6'=-6,'-5'=-5,'-4'=-4,'-3'=-3,'-2'=-2,'-1'=-1,'0'=0,'1'=1,'2'=2,'3'=3,'4'=4,'5'=5,'6'=6,'7'=7,'8'=8,'9'=9,'10'=10,'11'=11,'12'=12,'13'=13,'14'=14,'15'=15,'16'=16,'17'=17,'18'=18,'19'=19,'20'=20,'21'=21,'22'=22,'23'=23,'24'=24,'25'=25,'26'=26,'27'=27,'28'=28,'29'=29,'30'=30,'31'=31,'32'=32,'33'=33,'34'=34,'35'=35,'36'=36,'37'=37,'38'=38,'39'=39,'40'=40,'41'=41,'42'=42,'43'=43,'44'=44,'45'=45,'46'=46,'47'=47,'48'=48,'49'=49,'50'=50,'51'=51,'52'=52,'53'=53,'54'=54,'55'=55,'56'=56,'57'=57,'58'=58,'59'=59,'60'=60,'61'=61,'62'=62,'63'=63,'64'=64,'65'=65,'66'=66,'67'=67,'68'=68,'69'=69,'70'=70,'71'=71,'72'=72,'73'=73,'74'=74,'75'=75,'76'=76,'77'=77,'78'=78,'79'=79,'80'=80,'81'=81,'82'=82,'83'=83,'84'=84,'85
'=85,'86'=86,'87'=87,'88'=88,'89'=89,'90'=90,'91'=91,'92'=92,'93'=93,'94'=94,'95'=95,'96'=96,'97'=97,'98'=98,'99'=99,'100'=100,'101'=101,'102'=102,'103'=103,'104'=104,'105'=105,'106'=106,'107'=107,'108'=108,'109'=109,'110'=110,'111'=111,'112'=112,'113'=113,'114'=114,'115'=115,'116'=116,'117'=117,'118'=118,'119'=119,'120'=120,'121'=121,'122'=122,'123'=123,'124'=124,'125'=125,'126'=126,'127'=127)",
validEnums: createValidEnumsRange(-128, 127),
},
{
name: "Enum range with gaps",
enumType: "Enum8('-10'=-10,'-5'=-5,'0'=0,'1'=1,'5'=5,'10'=10)",
validEnums: []int{-10, -5, 0, 1, 5, 10},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e, err := Enum(Type(tt.enumType), tt.name)
assert.NoError(t, err)

// Try appending the full enum8 range. If the value is in the validEnums slice it should not error
for i := -128; i < 128; i++ {
valid := e.AppendRow(i)

if slices.Contains(tt.validEnums, i) {
assert.NoError(t, valid)
} else {
assert.Error(t, valid)
}
}
})
}
}

// createValidEnumsRange returns every integer from min to max, inclusive, in
// ascending order. It returns nil when max is less than min.
func createValidEnumsRange(min, max int) []int {
	if max < min {
		// An inverted range would give make a negative capacity, which
		// panics at runtime; treat it as an empty range instead.
		return nil
	}
	resultRange := make([]int, 0, max-min+1)
	for i := min; i <= max; i++ {
		resultRange = append(resultRange, i)
	}
	return resultRange
}

0 comments on commit bb51f19

Please sign in to comment.