基本结构

对比学习属于自监督学习,不需要人为地去标注标签,就能让模型学习到数据的特征表示。

其目标是,在潜空间中,相似的样本应该距离较近,而不相似的样本应该距离较远。

其主要结构如下

graph LR
A[C的数据增强C']-->En{Encoder}
B[给定样本C]-->En
C[毫无关系的样本D]-->En
En-->Aen[C'对应表示]
En-->Ben[C对应表示]
En-->Cen[D对应表示]
Aen-->L1{loss1}
Ben-->L1
Cen-->L2{loss2}
Ben-->L2

我们希望样本C与其数据增强C’的对应表示距离较近(减小loss1);
我们希望样本C与毫无关系的样本D的对应表示距离较远(增大loss2);

这就是对比学习的基本原理。

这看起来好像是一个新知识,但其已经被应用在了GraphSAGE模型的训练中了。参考往期:https://blog.57u.tech/2025/09/17/GraphSAGE-note/

CLIP模型

论文: https://arxiv.org/abs/2103.00020

CLIP模型通过在网上获取带字幕的图片作为训练数据(文字-图片对),来学习文字和图片的特征表示。

因此需要两个Encoder,一个用于编码图片,一个用于编码文字。而对于相似度计算,可以使用余弦相似度,避免模型学到直接将模长拉长 similarity=xyxy\text{similarity} = \frac{\mathbf{x}\cdot\mathbf{y}}{|\mathbf{x}||\mathbf{y}|}


左图中描述的就是相似度矩阵。共有NN个文字样本与NN个图片样本,共同组成了一个N×NN\times N的矩阵。
只有对角线位置上的结果匹配的,需要增大这些值;而其余位置上是不匹配的值,需要减小这些值。

对于右图,讲的是如何对图片进行分类。只需要用文本去描述不同的类别,嵌入到潜空间。然后将图片嵌入到潜空间,计算相似度,取相似度最高的类别作为预测结果。

但是这篇论文还提出了一个创新点,就是这两个encoder的输出不一定需要相同(文中说的他们用的预训练模型,降本增效),因为可以在loss函数前添加一个线性层,将不同的维度映射到相同的维度。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# extract feature representations of each modality
I_f = image_encoder(I) # [n, d_i]
T_f = text_encoder(T) # [n, d_t]

# joint multimodal embedding [n, d_e] 把表示映射到相同的维度
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)

# 通过广播机制得到 [n, n],并应用温度系数(控制平滑程度,值越大,softmax时约尖锐)
logits = np.dot(I_e, T_e.T) * np.exp(t)

# symmetric loss function
labels = np.arange(n) #[0,1,...,n-1]
loss_i = cross_entropy_loss(logits, labels, axis=0) #对于图像,是竖向的,需要对列计算损失
loss_t = cross_entropy_loss(logits, labels, axis=1)# 对于文字,是横向的,需要对行计算损失
loss = (loss_i + loss_t) / 2

参考资料:

SimCLR

论文:https://arxiv.org/abs/2002.05709

这篇论文主要提出了3个创新点:

  1. 数据增强的组成起着关键作用
  2. 不直接在嵌入层计算loss,而是再经过一个非线性变换后计算loss
  3. 相较于监督学习,对比学习需要更大的batch size

数据增强的组成起着关键作用
数据增强需要多样性。不过图中也能看出,其实多种数据增强的组合是对顺序不敏感的(其几乎是一个对称矩阵)。重要的是各种数据增强的方式。

第二个就是我觉得最重要的点,即不直接去计算representation的loss,而是先经过一个非线性变换,再计算loss。

graph LR

i[image]--encoder-->representation--nonlinear--> 投影特征向量

对于这个非线性变换,其也有可学习参数,让encoder的生成的representation经过这个非线性变换后,更能适应loss函数。因为我们不能假定最好的representation就完全能够使用余弦相似度去衡量差异。

参考资料: