Elasticsearch provides a standard REST API as well as clients written in languages such as Java, Python, and Go.
Based on the open-source SIFT1M dataset (http://corpus-texmex.irisa.fr/) and the Python Elasticsearch Client, this section walks through a code example that creates a vector index, imports vector data, and queries vector data, showing how to implement vector retrieval with the client.
Prerequisites
The Python dependency packages are installed on the client. If they are not, install them with the following commands:
pip install numpy
pip install elasticsearch==7.6.0
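Optionally, verify that the client can reach the cluster before running the example. The following minimal check assumes the placeholder address is replaced with your cluster endpoint:
from elasticsearch import Elasticsearch
es = Elasticsearch('http://xxx.xxx.xxx.xxx:9200/')
# ping() returns True when the cluster is reachable; info() returns basic cluster metadata
print(es.ping())
print(es.info())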
Code example
import numpy as np
import time
import json
from concurrent.futures import ThreadPoolExecutor, wait
from elasticsearch import Elasticsearch
from elasticsearch import helpers
endpoint = 'http://xxx.xxx.xxx.xxx:9200/'
# Build the Elasticsearch client object
es = Elasticsearch(endpoint)
# Index mapping
index_mapping = '''
{
    "settings": {
        "index": {
            "vector": "true"
        }
    },
    "mappings": {
        "properties": {
            "my_vector": {
                "type": "vector",
                "dimension": 128,
                "indexing": true,
                "algorithm": "GRAPH",
                "metric": "euclidean"
            }
        }
    }
}
'''
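# Notes on the mapping above (exact semantics depend on the vector search plugin
# bundled with your Elasticsearch distribution):
#   "index.vector": "true"  -- enables vector indexing for this index
#   "dimension": 128        -- matches the 128-dimensional SIFT descriptors
#   "indexing": true        -- builds an ANN index for the field at write time
#   "algorithm": "GRAPH"    -- selects a graph-based ANN index
#   "metric": "euclidean"   -- ranks neighbors by Euclidean (L2) distance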
# Create the index
def create_index(index_name, mapping):
    res = es.indices.create(index=index_name, ignore=400, body=mapping)
    print(res)
# Delete the index
def delete_index(index_name):
    res = es.indices.delete(index=index_name)
    print(res)
# Refresh the index
def refresh_index(index_name):
    res = es.indices.refresh(index=index_name)
    print(res)
# Force merge the index segments
def merge_index(index_name, seg_cnt=1):
    start = time.time()
    es.indices.forcemerge(index=index_name, max_num_segments=seg_cnt, request_timeout=36000)
    print(f"Force merge finished in {time.time() - start} s")
# Load vector data
def load_vectors(file_name):
    fv = np.fromfile(file_name, dtype=np.float32)
    dim = fv.view(np.int32)[0]
    vectors = fv.reshape(-1, 1 + dim)[:, 1:]
    return vectors
# Load ground-truth data
def load_gts(file_name):
    fv = np.fromfile(file_name, dtype=np.int32)
    dim = fv.view(np.int32)[0]
    gts = fv.reshape(-1, 1 + dim)[:, 1:]
    return gts
# Split a list into chunks of the given size
def partition(ls, size):
    return [ls[i:i + size] for i in range(0, len(ls), size)]
# Write vector data
def write_index(index_name, vec_file):
    pool = ThreadPoolExecutor(max_workers=8)
    tasks = []
    vectors = load_vectors(vec_file)
    bulk_size = 1000
    partitions = partition(vectors, bulk_size)
    start = time.time()
    start_id = 0
    for vecs in partitions:
        tasks.append(pool.submit(write_bulk, index_name, vecs, start_id))
        start_id += len(vecs)
    wait(tasks)
    print(f"Bulk write finished in {time.time() - start} s")
def write_bulk(index_name, vecs, start_id):
    actions = [
        {
            "_index": index_name,
            "my_vector": vecs[j].tolist(),
            "_id": str(j + start_id)
        }
        for j in range(len(vecs))
    ]
    helpers.bulk(es, actions, request_timeout=3600)
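# Note: helpers.bulk treats keys beginning with an underscore ("_index", "_id") as
# action metadata and the remaining keys (here "my_vector") as the document source;
# the default operation type is "index".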
# Query the index
def search_index(index_name, query_file, gt_file, k):
    print("Start query! Index name: " + index_name)
    queries = load_vectors(query_file)
    gt = load_gts(gt_file)
    took = 0
    precision = []
    for idx, query in enumerate(queries):
        hits = set()
        query_json = {
            "size": k,
            "_source": False,
            "query": {
                "vector": {
                    "my_vector": {
                        "vector": query.tolist(),
                        "topk": k
                    }
                }
            }
        }
        res = es.search(index=index_name, body=json.dumps(query_json))
        for hit in res['hits']['hits']:
            hits.add(int(hit['_id']))
        precision.append(len(hits.intersection(set(gt[idx, :k]))) / k)
        took += res['took']
    print("precision: " + str(sum(precision) / len(precision)))
    print(f"Search finished in {took / 1000:.2f} s; average took per query: {took / len(queries):.2f} ms")
if __name__ == "__main__":
    vec_file = r"./data/sift/sift_base.fvecs"
    qry_file = r"./data/sift/sift_query.fvecs"
    gt_file = r"./data/sift/sift_groundtruth.ivecs"
    index = "test"
    create_index(index, index_mapping)
    write_index(index, vec_file)
    merge_index(index)
    refresh_index(index)
    search_index(index, qry_file, gt_file, 10)
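If you want to remove the test data after verifying the results, you can call the delete_index helper defined above with the index name, for example:
delete_index("test")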