Skip to content

Commit 9d74ad2

Browse files
sougoubramp
authored andcommitted
parser: improved handling of SET
This change improves the handling of constructs like SET NAMES and SET CHARSET. It allows intermixing of these with other assignments. VTGate and mysqlproxy have been correspondingly updated. Signed-off-by: Sugu Sougoumarane <[email protected]>
1 parent 90c68f6 commit 9d74ad2

File tree

7 files changed

+2153
-2063
lines changed

7 files changed

+2153
-2063
lines changed

analyzer.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -267,21 +267,18 @@ func StringIn(str string, values ...string) bool {
267267
// if the query is a SET statement. Values can be int64 or string.
268268
// Since set variable names are case insensitive, all keys are returned
269269
// as lower case.
270-
func ExtractSetValues(sql string) (keyValues map[string]interface{}, charset string, scope string, err error) {
270+
func ExtractSetValues(sql string) (keyValues map[string]interface{}, scope string, err error) {
271271
stmt, err := Parse(sql)
272272
if err != nil {
273-
return nil, "", "", err
273+
return nil, "", err
274274
}
275275
setStmt, ok := stmt.(*Set)
276276
if !ok {
277-
return nil, "", "", fmt.Errorf("ast did not yield *sqlparser.Set: %T", stmt)
277+
return nil, "", fmt.Errorf("ast did not yield *sqlparser.Set: %T", stmt)
278278
}
279279
result := make(map[string]interface{})
280280
for _, expr := range setStmt.Exprs {
281-
if !expr.Name.Qualifier.IsEmpty() {
282-
return nil, "", "", fmt.Errorf("invalid syntax: %v", String(expr.Name))
283-
}
284-
key := expr.Name.Name.Lowered()
281+
key := expr.Name.Lowered()
285282

286283
switch expr := expr.Expr.(type) {
287284
case *SQLVal:
@@ -291,19 +288,19 @@ func ExtractSetValues(sql string) (keyValues map[string]interface{}, charset str
291288
case IntVal:
292289
num, err := strconv.ParseInt(string(expr.Val), 0, 64)
293290
if err != nil {
294-
return nil, "", "", err
291+
return nil, "", err
295292
}
296293
result[key] = num
297294
default:
298-
return nil, "", "", fmt.Errorf("invalid value type: %v", String(expr))
295+
return nil, "", fmt.Errorf("invalid value type: %v", String(expr))
299296
}
300297
case *NullVal:
301298
result[key] = nil
302299
case *Default:
303300
result[key] = "default"
304301
default:
305-
return nil, "", "", fmt.Errorf("invalid syntax: %s", String(expr))
302+
return nil, "", fmt.Errorf("invalid syntax: %s", String(expr))
306303
}
307304
}
308-
return result, setStmt.Charset.Lowered(), strings.ToLower(setStmt.Scope), nil
305+
return result, strings.ToLower(setStmt.Scope), nil
309306
}

analyzer_test.go

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -360,20 +360,16 @@ func TestStringIn(t *testing.T) {
360360

361361
func TestExtractSetValues(t *testing.T) {
362362
testcases := []struct {
363-
sql string
364-
out map[string]interface{}
365-
charset string
366-
scope string
367-
err string
363+
sql string
364+
out map[string]interface{}
365+
scope string
366+
err string
368367
}{{
369368
sql: "invalid",
370369
err: "syntax error at position 8 near 'invalid'",
371370
}, {
372371
sql: "select * from t",
373372
err: "ast did not yield *sqlparser.Set: *sqlparser.Select",
374-
}, {
375-
sql: "set a.autocommit=1",
376-
err: "invalid syntax: a.autocommit",
377373
}, {
378374
sql: "set autocommit=1+1",
379375
err: "invalid syntax: 1 + 1",
@@ -393,21 +389,17 @@ func TestExtractSetValues(t *testing.T) {
393389
sql: "SET foo = 0x1234",
394390
err: "invalid value type: 0x1234",
395391
}, {
396-
sql: "SET names utf8",
397-
out: map[string]interface{}{},
398-
charset: "utf8",
392+
sql: "SET names utf8",
393+
out: map[string]interface{}{"names": "utf8"},
399394
}, {
400-
sql: "SET names ascii collation ascii_bin",
401-
out: map[string]interface{}{},
402-
charset: "ascii",
395+
sql: "SET names ascii collate ascii_bin",
396+
out: map[string]interface{}{"names": "ascii"},
403397
}, {
404-
sql: "SET charset default",
405-
out: map[string]interface{}{},
406-
charset: "default",
398+
sql: "SET charset default",
399+
out: map[string]interface{}{"charset": "default"},
407400
}, {
408-
sql: "SET character set ascii",
409-
out: map[string]interface{}{},
410-
charset: "ascii",
401+
sql: "SET character set ascii",
402+
out: map[string]interface{}{"charset": "ascii"},
411403
}, {
412404
sql: "SET SESSION wait_timeout = 3600",
413405
out: map[string]interface{}{"wait_timeout": int64(3600)},
@@ -418,7 +410,7 @@ func TestExtractSetValues(t *testing.T) {
418410
scope: "global",
419411
}}
420412
for _, tcase := range testcases {
421-
out, charset, _, err := ExtractSetValues(tcase.sql)
413+
out, _, err := ExtractSetValues(tcase.sql)
422414
if tcase.err != "" {
423415
if err == nil || err.Error() != tcase.err {
424416
t.Errorf("ExtractSetValues(%s): %v, want '%s'", tcase.sql, err, tcase.err)
@@ -429,9 +421,6 @@ func TestExtractSetValues(t *testing.T) {
429421
if !reflect.DeepEqual(out, tcase.out) {
430422
t.Errorf("ExtractSetValues(%s): %v, want '%v'", tcase.sql, out, tcase.out)
431423
}
432-
if charset != tcase.charset {
433-
t.Errorf("ExtractSetValues(%s): %v, want '%v'", tcase.sql, charset, tcase.charset)
434-
}
435424
}
436425
}
437426

ast.go

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,8 +562,7 @@ func (node *Delete) WalkSubtree(visit Visit) error {
562562
// Set represents a SET statement.
563563
type Set struct {
564564
Comments Comments
565-
Exprs UpdateExprs
566-
Charset ColIdent
565+
Exprs SetExprs
567566
Scope string
568567
}
569568

@@ -2909,6 +2908,56 @@ func (node *UpdateExpr) WalkSubtree(visit Visit) error {
29092908
)
29102909
}
29112910

2911+
// SetExprs represents a list of set expressions.
2912+
type SetExprs []*SetExpr
2913+
2914+
// Format formats the node.
2915+
func (node SetExprs) Format(buf *TrackedBuffer) {
2916+
var prefix string
2917+
for _, n := range node {
2918+
buf.Myprintf("%s%v", prefix, n)
2919+
prefix = ", "
2920+
}
2921+
}
2922+
2923+
// WalkSubtree walks the nodes of the subtree.
2924+
func (node SetExprs) WalkSubtree(visit Visit) error {
2925+
for _, n := range node {
2926+
if err := Walk(visit, n); err != nil {
2927+
return err
2928+
}
2929+
}
2930+
return nil
2931+
}
2932+
2933+
// SetExpr represents a set expression.
2934+
type SetExpr struct {
2935+
Name ColIdent
2936+
Expr Expr
2937+
}
2938+
2939+
// Format formats the node.
2940+
func (node *SetExpr) Format(buf *TrackedBuffer) {
2941+
// We don't have to backtick set variable names.
2942+
if node.Name.EqualString("charset") || node.Name.EqualString("names") {
2943+
buf.Myprintf("%s %v", node.Name.String(), node.Expr)
2944+
} else {
2945+
buf.Myprintf("%s = %v", node.Name.String(), node.Expr)
2946+
}
2947+
}
2948+
2949+
// WalkSubtree walks the nodes of the subtree.
2950+
func (node *SetExpr) WalkSubtree(visit Visit) error {
2951+
if node == nil {
2952+
return nil
2953+
}
2954+
return Walk(
2955+
visit,
2956+
node.Name,
2957+
node.Expr,
2958+
)
2959+
}
2960+
29122961
// OnDup represents an ON DUPLICATE KEY clause.
29132962
type OnDup UpdateExprs
29142963

parse_next_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ func TestParseNextEdgeCases(t *testing.T) {
120120
}, {
121121
name: "Handle ForceEOF statements",
122122
input: "set character set utf8; select 1 from a",
123-
want: []string{"set ", "select 1 from a"},
123+
want: []string{"set charset 'utf8'", "select 1 from a"},
124124
}, {
125125
name: "Semicolin inside a string",
126126
input: "set character set ';'; select 1 from a",
127-
want: []string{"set ", "select 1 from a"},
127+
want: []string{"set charset ';'", "select 1 from a"},
128128
}, {
129129
name: "Partial DDL",
130130
input: "create table a; select 1 from a",

parse_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -657,24 +657,26 @@ var (
657657
input: "set character_set_results = utf8",
658658
}, {
659659
input: "set names utf8 collate foo",
660-
output: "set ",
660+
output: "set names 'utf8'",
661661
}, {
662662
input: "set character set utf8",
663-
output: "set ",
663+
output: "set charset 'utf8'",
664664
}, {
665665
input: "set character set 'utf8'",
666-
output: "set ",
666+
output: "set charset 'utf8'",
667667
}, {
668668
input: "set character set \"utf8\"",
669-
output: "set ",
669+
output: "set charset 'utf8'",
670670
}, {
671671
input: "set charset default",
672-
output: "set ",
672+
output: "set charset default",
673673
}, {
674674
input: "set session wait_timeout = 3600",
675675
output: "set session wait_timeout = 3600",
676676
}, {
677677
input: "set /* list */ a = 3, b = 4",
678+
}, {
679+
input: "set /* mixed list */ a = 3, names 'utf8', charset 'ascii', b = 4",
678680
}, {
679681
input: "alter ignore table a add foo",
680682
output: "alter table a",

0 commit comments

Comments
 (0)