-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtokenizer.go
38 lines (32 loc) · 1014 Bytes
/
tokenizer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
package tokenizer
import (
"github.com/sugarme/tokenizer"
)
// Tokenizer wraps the third-party *tokenizer.Tokenizer by embedding it, so the
// embedded type's methods are promoted onto this type alongside the
// batch-encoding helpers declared in this file.
type Tokenizer struct {
// Embedded underlying tokenizer; the methods in this file call it for
// Encode and GetPadding.
*tokenizer.Tokenizer
}
// EncodeBatchTexts encodes a batch of plain strings.
//
// Each element of texts is wrapped with tokenizer.NewInputSequence and
// tokenizer.NewSingleEncodeInput, and the resulting inputs are handed, in the
// same order, to EncodeBatchSerially. addSpecialTokens is forwarded unchanged.
// The returned encodings and error are exactly what EncodeBatchSerially
// returns.
func (t *Tokenizer) EncodeBatchTexts(texts []string, addSpecialTokens bool) ([]tokenizer.Encoding, error) {
	// The final length is known up front, so allocate once instead of letting
	// append grow the slice repeatedly.
	inputs := make([]tokenizer.EncodeInput, 0, len(texts))
	for _, text := range texts {
		inputs = append(inputs, tokenizer.NewSingleEncodeInput(tokenizer.NewInputSequence(text)))
	}
	return t.EncodeBatchSerially(inputs, addSpecialTokens)
}
// EncodeBatchSerially encodes the given inputs one after another, in order,
// by calling Encode on the embedded tokenizer for each of them.
//
// The first error returned by Encode aborts the batch: the result is nil and
// that error is returned as-is. Once every input has been encoded, the
// embedded tokenizer's padding configuration is consulted; when it is non-nil
// the collected encodings are passed through tokenizer.PadEncodings before
// being returned.
func (t *Tokenizer) EncodeBatchSerially(inputs []tokenizer.EncodeInput, addSpecialTokens bool) ([]tokenizer.Encoding, error) {
	var batch []tokenizer.Encoding
	for i := range inputs {
		enc, err := t.Tokenizer.Encode(inputs[i], addSpecialTokens)
		if err != nil {
			return nil, err
		}
		batch = append(batch, *enc)
	}

	// Padding is optional; apply it only when a configuration is set.
	if pad := t.Tokenizer.GetPadding(); pad != nil {
		batch = tokenizer.PadEncodings(batch, *pad)
	}
	return batch, nil
}