Skip to content

Commit c4b3a90

Browse files
zr-hebobramp
authored andcommitted
add split sql method
Signed-off-by: zr-hebo <[email protected]>
1 parent df78d5a commit c4b3a90

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

ast.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,37 @@ func SplitStatement(blob string) (string, string, error) {
114114
return blob, "", nil
115115
}
116116

117+
// SplitStatementToPieces split raw sql statement that may have multi sql pieces to sql pieces
118+
// returns the sql pieces blob contains; or error if sql cannot be parsed
119+
func SplitStatementToPieces(blob string) (pieces []string, err error) {
120+
pieces = make([]string, 0, 16)
121+
tokenizer := NewStringTokenizer(blob)
122+
123+
tkn := 0
124+
var stmt string
125+
stmtBegin := 0
126+
for {
127+
tkn, _ = tokenizer.Scan()
128+
if tkn == ';' {
129+
stmt = blob[stmtBegin : tokenizer.Position-2]
130+
pieces = append(pieces, stmt)
131+
stmtBegin = tokenizer.Position - 1
132+
133+
} else if tkn == 0 || tkn == eofChar {
134+
blobTail := tokenizer.Position - 2
135+
136+
if stmtBegin < blobTail {
137+
stmt = blob[stmtBegin : blobTail+1]
138+
pieces = append(pieces, stmt)
139+
}
140+
break
141+
}
142+
}
143+
144+
err = tokenizer.LastError
145+
return
146+
}
147+
117148
// SQLNode defines the interface for all nodes
118149
// generated by the parser.
119150
type SQLNode interface {

ast_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"bytes"
2121
"encoding/json"
2222
"reflect"
23+
"strings"
2324
"testing"
2425
"unsafe"
2526

@@ -441,3 +442,48 @@ func TestColumns_FindColumn(t *testing.T) {
441442
}
442443
}
443444
}
445+
446+
func TestSplitStatementToPieces(t *testing.T) {
447+
testcases := []struct {
448+
input string
449+
output string
450+
}{{
451+
input: "select * from table",
452+
}, {
453+
input: "select * from table1; select * from table2;",
454+
output: "select * from table1; select * from table2",
455+
}, {
456+
input: "select * from /* comment ; */ table;",
457+
output: "select * from /* comment ; */ table",
458+
}, {
459+
input: "select * from table where semi = ';';",
460+
output: "select * from table where semi = ';'",
461+
}, {
462+
input: "select * from table1;--comment;\nselect * from table2;",
463+
output: "select * from table1;--comment;\nselect * from table2",
464+
}, {
465+
input: "CREATE TABLE `total_data` (`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'id', " +
466+
"`region` varchar(32) NOT NULL COMMENT 'region name, like zh; th; kepler'," +
467+
"`data_size` bigint NOT NULL DEFAULT '0' COMMENT 'data size;'," +
468+
"`createtime` datetime NOT NULL DEFAULT NOW() COMMENT 'create time;'," +
469+
"`comment` varchar(100) NOT NULL DEFAULT '' COMMENT 'comment'," +
470+
"PRIMARY KEY (`id`))",
471+
}}
472+
473+
for _, tcase := range testcases {
474+
if tcase.output == "" {
475+
tcase.output = tcase.input
476+
}
477+
478+
stmtPieces, err := SplitStatementToPieces(tcase.input)
479+
if err != nil {
480+
t.Errorf("input: %s, err: %v", tcase.input, err)
481+
continue
482+
}
483+
484+
out := strings.Join(stmtPieces, ";")
485+
if out != tcase.output {
486+
t.Errorf("out: %s, want %s", out, tcase.output)
487+
}
488+
}
489+
}

0 commit comments

Comments
 (0)