Skip to content

Commit 4645aca

Browse files
authored
Merge pull request #42 from oracle-samples/multi_primary_key_fix
Fix multi primary key tests.
2 parents baadd6e + 602cc4e commit 4645aca

File tree

3 files changed

+57
-12
lines changed

3 files changed

+57
-12
lines changed

oracle/clause_builder.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,16 @@ func OnConflictClauseBuilder(c clause.Clause, builder clause.Builder) {
325325
missingColumns = append(missingColumns, conflictCol.Name)
326326
}
327327
}
328+
328329
if len(missingColumns) > 0 {
330+
// primary keys with auto increment will always be missing from create values columns
331+
for _, missingCol := range missingColumns {
332+
field := stmt.Schema.LookUpField(missingCol)
333+
if field != nil && field.PrimaryKey && field.AutoIncrement {
334+
return
335+
}
336+
}
337+
329338
var selectedColumns []string
330339
for col := range selectedColumnSet {
331340
selectedColumns = append(selectedColumns, col)
@@ -335,6 +344,34 @@ func OnConflictClauseBuilder(c clause.Clause, builder clause.Builder) {
335344
return
336345
}
337346

347+
// exclude primary key, default value columns from merge update clause
348+
if len(onConflict.DoUpdates) > 0 {
349+
hasPrimaryKey := false
350+
351+
for _, assignment := range onConflict.DoUpdates {
352+
field := stmt.Schema.LookUpField(assignment.Column.Name)
353+
if field != nil && field.PrimaryKey {
354+
hasPrimaryKey = true
355+
break
356+
}
357+
}
358+
359+
if hasPrimaryKey {
360+
onConflict.DoUpdates = nil
361+
columns := make([]string, 0, len(values.Columns)-1)
362+
for _, col := range values.Columns {
363+
field := stmt.Schema.LookUpField(col.Name)
364+
365+
if field != nil && !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
366+
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
367+
columns = append(columns, col.Name)
368+
}
369+
370+
}
371+
onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
372+
}
373+
}
374+
338375
// Build MERGE statement
339376
buildMergeInClause(stmt, onConflict, values, conflictColumns)
340377
}

oracle/create.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,11 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
267267
valuesColumnMap[strings.ToUpper(column.Name)] = true
268268
}
269269

270+
// Filter conflict columns to remove non unique columns
270271
var filteredConflictColumns []clause.Column
271272
for _, conflictCol := range conflictColumns {
272-
if valuesColumnMap[strings.ToUpper(conflictCol.Name)] {
273+
field := stmt.Schema.LookUpField(conflictCol.Name)
274+
if valuesColumnMap[strings.ToUpper(conflictCol.Name)] && (field.Unique || field.AutoIncrement) {
273275
filteredConflictColumns = append(filteredConflictColumns, conflictCol)
274276
}
275277
}
@@ -336,6 +338,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
336338

337339
// Build ON clause using conflict columns
338340
plsqlBuilder.WriteString(" ON (")
341+
339342
for idx, conflictCol := range conflictColumns {
340343
if idx > 0 {
341344
plsqlBuilder.WriteString(" AND ")
@@ -425,7 +428,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
425428
}
426429
plsqlBuilder.WriteString(" WHEN MATCHED THEN UPDATE SET t.")
427430
writeQuotedIdentifier(&plsqlBuilder, noopCol)
428-
plsqlBuilder.WriteString(" = t.")
431+
plsqlBuilder.WriteString(" = s.")
429432
writeQuotedIdentifier(&plsqlBuilder, noopCol)
430433
plsqlBuilder.WriteString("\n")
431434
}

tests/multi_primary_keys_test.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,17 @@ func compareTags(tags []Tag, contents []string) bool {
7575
}
7676

7777
func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
78-
t.Skip()
7978
if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" {
8079
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
8180
}
8281

83-
if name := DB.Dialector.Name(); name == "postgres" {
82+
if name := DB.Dialector.Name(); name == "postgres" || name == "oracle" {
8483
stmt := gorm.Statement{DB: DB}
8584
stmt.Parse(&Blog{})
8685
stmt.Schema.LookUpField("ID").Unique = true
8786
stmt.Parse(&Tag{})
8887
stmt.Schema.LookUpField("ID").Unique = true
89-
// postgers only allow unique constraint matching given keys
88+
// postgers and oracle only allow unique constraint matching given keys
9089
}
9190

9291
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags")
@@ -300,7 +299,6 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
300299
}
301300

302301
func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
303-
t.Skip()
304302
if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" {
305303
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
306304
}
@@ -309,6 +307,15 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
309307
t.Skip("skip postgres due to it only allow unique constraint matching given keys")
310308
}
311309

310+
if name := DB.Dialector.Name(); name == "oracle" {
311+
stmt := gorm.Statement{DB: DB}
312+
stmt.Parse(&Blog{})
313+
stmt.Schema.LookUpField("ID").Unique = true
314+
stmt.Parse(&Tag{})
315+
stmt.Schema.LookUpField("ID").Unique = true
316+
// oracle only allow unique constraint matching given keys
317+
}
318+
312319
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags")
313320
if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil {
314321
t.Fatalf("Failed to auto migrate, got error: %v", err)
@@ -326,7 +333,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
326333
DB.Save(&blog)
327334

328335
blog2 := Blog{
329-
ID: blog.ID,
336+
ID: 2,
330337
Locale: "EN",
331338
}
332339
DB.Create(&blog2)
@@ -358,7 +365,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
358365
}
359366

360367
var blog1 Blog
361-
DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID)
368+
DB.Preload("LocaleTags").Find(&blog1, "\"locale\" = ? AND \"id\" = ?", "ZH", blog.ID)
362369
if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) {
363370
t.Fatalf("Preload many2many relations")
364371
}
@@ -388,7 +395,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
388395
}
389396

390397
var blog11 Blog
391-
DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale)
398+
DB.Preload("LocaleTags").First(&blog11, "\"id\" = ? AND \"locale\" = ?", blog.ID, blog.Locale)
392399
if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) {
393400
t.Fatalf("CN Blog's tags should not be changed after EN Blog Replace")
394401
}
@@ -399,7 +406,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
399406
}
400407

401408
var blog21 Blog
402-
DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale)
409+
DB.Preload("LocaleTags").First(&blog21, "\"id\" = ? AND \"locale\" = ?", blog2.ID, blog2.Locale)
403410
if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) {
404411
t.Fatalf("EN Blog's tags should be changed after Replace")
405412
}
@@ -454,8 +461,6 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
454461
}
455462

456463
func TestCompositePrimaryKeysAssociations(t *testing.T) {
457-
t.Skip()
458-
459464
type Label struct {
460465
BookID *uint `gorm:"primarykey"`
461466
Name string `gorm:"primarykey"`

0 commit comments

Comments
 (0)