本教程示範如何使用向量檢索服務(DashVector),結合ModelScope上的中文CLIP多模態檢索模型,構建即時的“文本搜圖片”的多模態檢索能力。作為樣本,我們採用多模態牧歌資料集作為圖片語料庫,使用者通過輸入文本來跨模態檢索最相似的圖片。
整體流程
主要分為兩個階段:
圖片資料Embedding入庫。將牧歌資料集通過中文CLIP模型Embedding介面轉化為高維向量,然後寫入DashVector向量檢索服務。
文本Query檢索。使用對應的中文CLIP模型擷取文本的Embedding向量,然後通過DashVector檢索相似圖片。
前提準備
1. API-KEY 準備
2. 環境準備
本教程使用的是ModelScope最新的CLIP Huge模型(224解析度),該模型使用大規模中文資料進行訓練(~2億圖文對),在中文圖文檢索和映像、文本的表徵提取等情境表現優異。根據模型官網教程,我們提取出相關的環境依賴如下:
需要提前安裝 Python3.7 及以上版本,請確保相應的 python 版本
# 安裝 dashvector 用戶端
pip3 install dashvector
# 安裝 modelscope
# require modelscope>=0.3.7,目前預設已經超過,您檢查一下即可
# 按照更新鏡像的方法處理或者下面的方法
pip3 install --upgrade modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
# 需要單獨安裝decord
# pip3 install decord
# 另外,modelscope 的安裝過程會出現其他的依賴,目前的版本的依賴列舉如下
# pip3 install torch torchvision opencv-python timm librosa fairseq transformers unicodedata2 zhconv rapidfuzz
3. 資料準備
本教程使用多模態牧歌資料集的validation驗證集作為入庫的圖片資料集,可以通過調用ModelScope的資料集介面擷取。
from modelscope.msdatasets import MsDataset
dataset = MsDataset.load("muge", split="validation")
具體步驟
本教程所涉及的 your-xxx-api-key 以及 your-xxx-cluster-endpoint,均需要替換為您自己的API-KAY及CLUSTER_ENDPOINT後,代碼才能正常運行。
1. 圖片資料Embedding入庫
多模態牧歌資料集的 validation 驗證集包含 30588 張多模態情境的圖片資料資訊,這裡我們需要通過CLIP模型提取原始圖片的Embedding向量入庫,另外為了方便後續的圖片展示,我們也將原始圖片資料編碼後一起入庫。代碼執行個體如下:
import torch
from modelscope.utils.constant import Tasks
from modelscope.pipelines import pipeline
from modelscope.msdatasets import MsDataset
from dashvector import Client, Doc, DashVectorException, DashVectorCode
from PIL import Image
import base64
import io
def image2str(image):
image_byte_arr = io.BytesIO()
image.save(image_byte_arr, format='PNG')
image_bytes = image_byte_arr.getvalue()
return base64.b64encode(image_bytes).decode()
if __name__ == '__main__':
# 初始化 dashvector client
client = Client(
api_key='{your-dashvector-api-key}',
endpoint='{your-dashvector-cluster-endpoint}'
)
# 建立集合:指定集合名稱和向量維度, CLIP huge 模型產生的向量統一為 1024 維
rsp = client.create('muge_embedding', 1024)
if not rsp:
raise DashVectorException(rsp.code, reason=rsp.message)
# 批量產生圖片Embedding,並完成向量入庫
collection = client.get('muge_embedding')
pipe = pipeline(task=Tasks.multi_modal_embedding,
model='damo/multi-modal_clip-vit-huge-patch14_zh',
model_revision='v1.0.0')
ds = MsDataset.load("muge", split="validation")
BATCH_COUNT = 10
TOTAL_DATA_NUM = len(ds)
print(f"Start indexing muge validation data, total data size: {TOTAL_DATA_NUM}, batch size:{BATCH_COUNT}")
idx = 0
while idx < TOTAL_DATA_NUM:
batch_range = range(idx, idx + BATCH_COUNT) if idx + BATCH_COUNT <= TOTAL_DATA_NUM else range(idx, TOTAL_DATA_NUM)
images = [ds[i]['image'] for i in batch_range]
# 中文 CLIP 模型產生圖片 Embedding 向量
image_embeddings = pipe.forward({'img': images})['img_embedding']
image_vectors = image_embeddings.detach().cpu().numpy()
collection.insert(
[
Doc(
id=str(img_id),
vector=img_vec,
fields={'png_img': image2str(img)}
)
for img_id, img_vec, img in zip(batch_range, image_vectors, images)
]
)
idx += BATCH_COUNT
print("Finish indexing muge validation data")
上述代碼裡模型預設在 cpu 環境下運行,在 gpu 環境下會視 gpu 效能得到不同程度的效能提升
2. 文本Query檢索
完成上述圖片資料向量化入庫後,我們可以輸入文本,通過同樣的CLIP Embedding模型擷取文本向量,再通過DashVector向量檢索服務的檢索介面,快速檢索相似的圖片了,程式碼範例如下:
import torch
from modelscope.utils.constant import Tasks
from modelscope.pipelines import pipeline
from modelscope.msdatasets import MsDataset
from dashvector import Client, Doc, DashVectorException
from PIL import Image
import base64
import io
def str2image(image_str):
image_bytes = base64.b64decode(image_str)
return Image.open(io.BytesIO(image_bytes))
def multi_modal_search(input_text):
# 初始化 DashVector client
client = Client(
api_key='{your-dashvector-api-key}',
endpoint='{your-dashvector-cluster-endpoint}'
)
# 擷取上述入庫的集合
collection = client.get('muge_embedding')
# 擷取文本 query 的 Embedding 向量
pipe = pipeline(task=Tasks.multi_modal_embedding,
model='damo/multi-modal_clip-vit-huge-patch14_zh', model_revision='v1.0.0')
text_embedding = pipe.forward({'text': input_text})['text_embedding'] # 2D Tensor, [文本數, 特徵維度]
text_vector = text_embedding.detach().cpu().numpy()[0]
# DashVector 向量檢索
rsp = collection.query(text_vector, topk=3)
image_list = list()
for doc in rsp:
image_str = doc.fields['png_img']
image_list.append(str2image(image_str))
return image_list
if __name__ == '__main__':
text_query = "戴眼鏡的狗"
images = multi_modal_search(text_query)
for img in images:
# 注意:show() 函數在 Linux 伺服器上可能需要安裝必要的映像瀏覽器組件才生效
# 建議在支援 jupyter notebook 的伺服器上運行該代碼
img.show()
運行上述代碼,輸出結果如下: