图文匹配中, S ∈ [ 0 , 1 ] n × n S\in [0,1]^{n\times n} S∈[0,1]n×n 是一个相似度矩阵,即 S i j S_{ij} Sij 是第 i 幅图 I i I_i Ii 与第 j 条句子 T j T_j Tj 的相似度,而 ( I i , T i ) i = 1 n (I_i, T_i)_{i=1}^n (Ii,Ti)i=1n 是 ground-truth pair。检索文本(text retrieval)要求以 I i I_i Ii 为 quey 时 T i T_i Ti 排第几(即其 ranking);反过来检索图像(image retrieval)要求 T i T_i Ti 为 query 时 I i I_i Ii 的 ranking。
Text retrieval(逐行检索)可如此写:
import torch
# S[i][j] = sim(image_i, text_j)
S = torch.randperm(36).view(6, 6)
print("similarities:", S)
# text retrieval 求 ranking
asc_tx = S.argsort(1, descending=True)
print("asc_tx:", asc_tx)
sid = torch.arange(S.size(0)) # [n]
print("sample id:", sid)
rank_tx = torch.where(sid.unsqueeze(1) == asc_tx) # [n, 1]
print("rank_tx:", rank_tx)
# 两种判断是否为 top-1 的写法对拍,结果一致
print("top-1")
print(rank_tx[1] < 1)
print(sid == S.argmax(1))
结果:
similarities:
tensor([[ 5, 25, 28, 15, 29, 19],
[ 3, 13, 21, 1, 0, 16],
[ 9, 31, 12, 18, 32, 14],
[17, 2, 26, 4, 10, 7],
[ 8, 23, 11, 35, 34, 20],
[24, 6, 27, 30, 22, 33]])
asc_tx:
tensor([[4, 2, 1, 5, 3, 0],
[2, 5, 1, 0, 3, 4],
[4, 1, 3, 5, 2, 0],
[2, 0, 4, 5, 3, 1],
[3, 4, 1, 5, 2, 0],
[5, 3, 2, 0, 4, 1]])
sample id: tensor([0, 1, 2, 3, 4, 5])
rank_tx: (tensor([0, 1, 2, 3, 4, 5]), tensor([5, 2, 4, 4, 1, 0])) # <- 第一个向量是行序号,升序,没问题
top-1
tensor([False, False, False, False, False, True]) # 一致
tensor([False, False, False, False, False, True]) # 一致
这种写法的思路是用 torch.argsort
按行排序,然后用 torch.where
求每一行序号等于 sample ID 的位置,即为 ranking。torch.where
返回的结果 rank_tx
是两个向量,第一个是行座标,第二个是列座标,由于 text retrieval 是逐行检索,所以列座标是 ranking。从结果看,这种写法没问题。
但当用同样思路写 image retrieval(逐列检索)时,出问题了:
import torch
# S[i][j] = sim(image_i, text_j)
S = torch.randperm(36).view(6, 6)
print("similarities:", S)
sid = torch.arange(S.size(0)) # [n]
# print("sample id:", sid)
# image retrieval
asc_im = S.argsort(0, descending=True) # 排序轴换成 0
print("asc_im:", asc_im)
rank_im = torch.where(sid.unsqueeze(0) == asc_im) # [1, n]
print("rank_im:", rank_im) # 不对劲
# 两种 top-1 写法对拍不过
print("top-1")
print(rank_im[0] < 1) # 取第一个个向量,即行位置
print(sid == S.argmax(0))
结果:
similarities:
tensor([[19, 16, 1, 15, 24, 28],
[33, 21, 8, 3, 2, 34],
[14, 25, 7, 32, 17, 0],
[30, 6, 26, 11, 27, 4],
[31, 20, 29, 22, 35, 23],
[12, 13, 5, 18, 10, 9]])
asc_im:
tensor([[1, 2, 4, 2, 4, 1],
[4, 1, 3, 4, 3, 0],
[3, 4, 1, 5, 0, 4],
[0, 0, 2, 0, 2, 5],
[2, 5, 5, 3, 5, 3],
[5, 3, 0, 1, 1, 2]])
rank_im: (tensor([0, 1, 3, 3, 3, 4]), tensor([4, 1, 0, 2, 5, 3])) # <- 第二个向量是列序号,是乱序!
top-1
tensor([ True, False, False, False, False, False]) # 不一致
tensor([False, False, False, False, True, False]) # 不一致
这种 image retrieval 的写法是按照前面 text retrieval 的写法对称改过来的:
- argsort 排序轴 0 -> 1(按行 -> 按列)。这步没问题;
sid.unsqueeze(1)
->sid.unsqueeze(0)
,即换成求每列序号等于 sample ID 的位置。这步的结果就不对了,前面rank_tx
的第一个向量是升序的行序号,而rank_im
的第二个向量却是乱序的列序号!
这个现象就是题目所谓 torch.where
纵横不对称。从 rank_im
来看,torch.where
的策略是行主序搜索,即搜完一行再一行,保证其结果 rank_im
的第一个向量是非降的,rank_tx
也满足这点。
一个解决办法是:转置 argsort 结果,然后照抄逐行检索的写法:
import torch
# S[i][j] = sim(image_i, text_j)
S = torch.randperm(36).view(6, 6)
print("similarities:", S)
# image retrieval, corrected
asc_im = S.argsort(0, descending=True)
# rank_im = torch.where(sid.unsqueeze(0) == asc_im) # 出事写法
rank_im2 = torch.where(sid.unsqueeze(1) == asc_im.T) # 转置 argsort,按逐行检索写法来
print("rank_im2:", rank_im2)
print("top-1")
# print(rank_im[0] < 1)
print(rank_im2[1] < 1) # 还是用第二个向量
print(sid == S.argmax(0))
结果:
similarities:
tensor([[16, 17, 11, 10, 23, 33],
[13, 15, 27, 34, 7, 24],
[26, 29, 20, 6, 18, 31],
[ 0, 32, 14, 12, 25, 35],
[ 1, 2, 4, 9, 19, 22],
[28, 30, 3, 5, 8, 21]])
rank_im2: (tensor([0, 1, 2, 3, 4, 5]), tensor([2, 4, 1, 1, 2, 5]))
top-1
tensor([False, False, False, False, False, False]) # 一致
tensor([False, False, False, False, False, False]) # 一致