全部产品
Search
文档中心

人工智能平台 PAI:向量召回评估

更新时间:Oct 31, 2023

向量召回评估组件计算召回的hitrate结果。hitrate作为结果好坏的评价,hitrate越高表示训练产出的向量去召回向量的结果越准确。本文为您介绍向量召回评估组件的原理和配置指导。

原理

向量召回评估组件同时支持u2i召回和i2i召回的计算。u2i召回时,拿user(用户)的向量去召回top k个items(物品),i2i召回时拿item的向量去召回top k个items。 hitrate的具体计算方法为,假设真实trigger(u2i召回时为user,i2i召回时为item)的关联item集合为M,而实际召回了top k个和trigger相似的items,若其中落在了M里的集合为N,则top k hitrate为|N| / |M|。为了进行bad case study,组件同时会输出在top k里面但是不在M里面的items以及对应的距离。组件同时支持单机和分布式运行。 具体实现流程为:

  1. 每个worker分片加载user、item的embedding表,构建KNN需要的索引。

  2. 每个worker分片按批次读真实序列表,查找embedding搜索KNN,得到top k的items。

  3. 根据真实的items序列和top k的items计算hitrate。

  4. 汇总结果并输出到ODPS表。

组件输入

item embedding表

item的embedding表,一般为GraphSAGE等训练算法的输出,表示例如下。

item id (bigint)

item embeddings (string)

23456677

0.1,0.2,0.3....

user embedding表

user的embedding表,一般为GraphSAGE等训练算法的输出,表示例如下。

user id (bigint)

user embeddings (string)

12345

0.1,0.2,0.3....

真实序列表

trigger和关联item的真实表,做为ground truth。u2i召回时trigger id为user id, i2i召回时为item id。item ids为真实的和trigger id相关的item列表,表示例如下。

trigger id (bigint)

item ids (string)

12345

23456677,2233445,6837292,...

组件输出

total_hitrate表

总的hitrate,表示例如下。

hitrate(double)

0.4

hitrate_details表

hitrate详情表,行数和真实序列表一致,表示例如下

id (bigint)

topk_ids (string)

topk_dists (string)

hitrate (double)

bad_ids (string)

bad_dists (string)

1123

2345,2367,2483,2567

0.8,0.7,0.2,0.1

0.39

2483,2567

0.2,0.1

该表的行数与真实序列表相同

  • id当u2i召回时为user_id, i2i召回时为item_id。

  • topk_ids是与trigger最相关的前k个item的id,以半角逗号(,)分割。

  • topk_dists是与topk_ids对应的距离。

  • hitrate为针对该trigger召回item的命中率。

  • bad_ids为召回但是未命中的item id。

  • bad_dists为与bad_cases对应的距离。

组件参数配置

向量召回评估组件支持界面化配置和命令方式配置,配置参数一致,参数配置指导如下。

参数

类型

参数说明

输入配置

item_emb_table

(item向量表)

string

item embedding表。

true_seq_table

(真实序列表)

string

真实序列表。u2i召回时为user和user关联的items;i2i召回时为item和item关联的items。

重要

测试召回效果时,训练embedding用T时间数据时,真实序列使用T+1时间的,否则出现穿越,hitrate偏高。

user_emb_table

(user向量表)

string (可选)

user embedding表,只在u2i召回时需要提供。

输出配置

total_hitrate

(向量召回评估值)

string

输出表,总的hitrate。

hitrate_details

(向量召回评估详情)

string

输出表,hitrate详情。

参数设置

recall_type

(召回类型)

string

召回类型,'u2i'或者'i2i'。

emb_dim

(向量表特征维度)

int

embedding表的embedding维度。

k(召回数目)

int

召回的数目。

metric

(召回相似度度量方式)

int

(可选,默认1)

召回相似度度量方式。0为L2距离,1为内积。 L2时返回距离最小的k个,内积时返回内积最大的k个。

strict(是否容错)

bool

(可选,默认False)

相似度计算有一定误差,如果需要严格结果,strict设为True,但是strict=True时,速度会比较慢。

lifecycle

int (可选,默认7)

输出表的lifecycle,单位为天。

执行调优

batch_size

int

(可选,默认1024)

一次计算的样本数量,内存不够时可以设小。

worker_count

(计算核心数)

int (可选,默认1)

运行的机器数,当输入表比较大或者单个worker运行比较慢时可以设大次数目。

worker_memory

(每个核心内存)

int

(可选,默认20000)

每个机器的内存大小,单位为M字节,默认20000 MB。

PAI命令示列

pai -name hitrate_gl_ext
		-Ditem_emb_table='item_emb_table'
    -Duser_emb_table='user_emb_table'
    -Dtrue_seq_table='true_seq_table'
    -Dhitrate_details='hitrate_details'
    -Dtotal_hitrate='total_hitrate'
    -Drecall_type='u2i'
    -Dk=5
    -Demb_dim=10
    -Dmetric=1
    -Dstrict=False
    -Dbatch_size=1024
    -Dworker_count=1
    -Dworker_memory=20000
    -Dlifecycle=7;

上述命令展示了u2i召回计算hitrate的例子,该命令指定按照内积方式计算向量相似度(距离),不要求距离计算的严格性,按批次计算,每次计算1024个true_seq_table里的内容,指定了1个worker,内存是20 GB,输出表hitrate_details和total_hitrate的生命周期是7天。