Phrase auto-completion: take "as soon as possible" as an example.
Once we have recognized "as soon as", we can be almost certain that the user is about to type "possible".
Design a word-frequency model for this feature,
built from (prefix, next word) pairs.
For example, the sentence "as soon as possible" above
produces the following tuples: (as, soon, 1), (as soon, as, 1), (as soon as, possible, 1).
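That is, this one sentence contributes the following statistics (a prefix always starts at the first word of the sentence):
when the prefix is "as", the next word is "soon": 1 expectation point;
when the prefix is "as soon", the next word is "as": 1 expectation point;
when the prefix is "as soon as", the next word is "possible": 1 expectation point.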
Given many sentences, we naturally obtain many expectation points; points for the same next word under the same prefix add up.
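Now you are given n sentences from which to build the statistics,
then m sentences as queries,
and finally an integer k: for each query sentence, treated as a prefix, return the k next words with the highest counts (the code below breaks ties alphabetically).
Return m results, each containing at most k words.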
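Approach: a prefix tree (trie) with one node per word. The path from the root to a node spells a prefix, and each node counts how many sentences passed through it. The children of the node reached by a query are therefore exactly the candidate next words, together with their counts. I originally wanted to write this in Rust but couldn't manage it, so I used Go instead.
The code is as follows: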
package main

import (
	"fmt"
	"sort"
	"strings"
)

func main() {
	sentences := []string{"i think you are good", "i think you are fine", "i think you are good man"}
	k := 2
	ai := NewAI(sentences, k)
	// Top-2 next words after the prefix "i think you are"
	for _, ans := range ai.suggest("i think you are") {
		fmt.Println(ans)
	}
	fmt.Println("=====")
	// Add three more sentences; "fucking" now follows the prefix 3 times
	ai.fill("i think you are fucking good")
	ai.fill("i think you are fucking great")
	ai.fill("i think you are fucking genius")
	for _, ans := range ai.suggest("i think you are") {
		fmt.Println(ans)
	}
	fmt.Println("=====")
}
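// TrieNode is one word in the trie. The path from the root to a node spells
// a prefix; times counts how many filled sentences went through that path.
type TrieNode struct {
	word      string
	times     int
	nextNodes map[string]*TrieNode // child lookup by next word
	nextRanks []*TrieNode          // the same children, as a slice for sorting
}

// NewTrieNode creates a node for word w. times starts at 1 because a node
// is only created the first time w is seen after its prefix.
func NewTrieNode(w string) *TrieNode {
	return &TrieNode{
		word:      w,
		times:     1,
		nextNodes: make(map[string]*TrieNode),
		nextRanks: make([]*TrieNode, 0),
	}
}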
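// AI holds the trie root and how many suggestions to return per query.
type AI struct {
	root *TrieNode
	topk int
}

// NewAI builds the statistics from the initial sentences.
func NewAI(sentences []string, k int) *AI {
	ai := &AI{root: NewTrieNode(""), topk: k}
	for _, sentence := range sentences {
		ai.fill(sentence)
	}
	return ai
}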
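// fill walks the sentence word by word from the root, creating missing
// nodes and adding one expectation point to every node on the path.
func (ai *AI) fill(sentence string) {
	cur := ai.root
	// strings.Fields (rather than Split on " ") ignores repeated spaces,
	// so no empty-string "words" end up in the trie
	for _, word := range strings.Fields(sentence) {
		next, ok := cur.nextNodes[word]
		if !ok {
			next = NewTrieNode(word)
			cur.nextNodes[word] = next
			cur.nextRanks = append(cur.nextRanks, next)
		} else {
			next.times++
		}
		cur = next
	}
}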
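// suggest returns up to topk words that most often followed the given
// prefix, ordered by count (descending) and then alphabetically.
// An empty prefix returns the most common first words.
func (ai *AI) suggest(sentence string) []string {
	ans := make([]string, 0)
	cur := ai.root
	for _, word := range strings.Fields(sentence) {
		next, ok := cur.nextNodes[word]
		if !ok {
			return ans // prefix never seen: no suggestions
		}
		cur = next
	}
	// Sorts the children in place on every call
	sort.Slice(cur.nextRanks, func(i, j int) bool {
		a, b := cur.nextRanks[i], cur.nextRanks[j]
		if a.times != b.times {
			return a.times > b.times
		}
		return a.word < b.word
	})
	for _, n := range cur.nextRanks {
		// check before appending so that topk <= 0 yields no suggestions
		if len(ans) >= ai.topk {
			break
		}
		ans = append(ans, n.word)
	}
	return ans
}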
The output is as follows:
good
fine
=====
fucking
good
=====