结合源码看Restormer的网络设计

news/2025/2/9 3:14:59 标签: 深度学习, 人工智能

下面是restormer的结构示意图。restormer集合了众多的技术,包括unet结构,1x1卷积和深度卷积,还有nlp中常用的layer norm,attention结构。

OverlapPatchEmbed

第一个结构是OverlapPatchEmbed。这是通过卷积把图像映射到高维空间,在这里是把原来的rgb三通道升维到48通道,同时通过padding等参数经典的(3,1,1)的设置,保持特征图的长宽和输入相同。

## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        x = self.proj(x)

        return x

接下来就是最重要的TransformerBlock。如最开头的结构图所画的,TransformerBlock是构成u-net结构的基本单元。

layer norm

第一步是layer norm。layer norm是transform中常用的归一化方法,它和batch norm的区别在于归一化的维度不同:

至于为什么transform要使用layer norm,一个最简单的解释就是transform一开始用于nlp,而文本的长度是不一的,这样没办法在所有位置计算均值和方差。但为什么vit也仍然使用layer norm,可以再继续参考资料。这里着重代码实现。

layer norm是在通道层面归一化,所以要在通道维度计算方差。所以关键是搞清楚哪个维度是c通道。为了方便,先把bchw转化为bhwc,这样均值和方差都在最后一个维度计算:

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)  # unbiased表示无偏估计
        return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias

可以看到在归一化之后还会再按照weight和bias再缩放回去,这两个都是可学习的参数:

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

而根据是否减去均值,layer norm分为WithBias_LayerNorm和BiasFree_LayerNorm。

attention(MDTA)

attention类就对应文章中的MDTA模块Multi-Dconv Head Transposed Attention,从名字上看,它是一个特殊的attention。而既然是attention,就有qkv的计算。使用1x1的pixel wise卷积再接一个3x3的depth wise 卷积:

self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
# nn.Conv2d中的groups的作用是分成几个组来计算卷积,默认是1,就是普通的卷积,如果groups分成in_channels,那就是深度分离卷积

要得到qkv,先使用1x1卷积使得通道维度上升到原来的3倍,然后再使用3x3卷积。1x1卷积和3x3卷积的输出维度都是dim*3。同时注意文章中提到两个卷积都是bias-free的。

1x1卷积当然不改变特征图尺寸,而depth-wise卷积也是(3,1,1),也不改变尺寸。

然后使用chunk函数拆分成q,k,v三个向量:

q,k,v = qkv.chunk(3, dim=1)  # 从dim=1的维度(bchw)拆分得到三个向量,

q,k,v是按照通道划分得到的,而为了减少计算量,把原来空间上的SA转为通道间计算attention,需要把所以q,k,v进一步划分,使用爱因斯坦库‌Einops(Einstein Operations)实现:

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
# head*c表示原来的通道数。在合并hw的同时把通道维度拆分为head*c

上面rearrange的同时也起到了两个作用,一个是实现了多头,把原来的通道划分为head和c;一个是把二维空间(h,w)压缩到一维。

q,k决定最终的权重系数,所以要进行归一化,注意是在(h w)维度进行归一化:

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

计算内积(矩阵乘法)前,要把k的(h w)和c维度进行交换,这样得到的内积结果就会是c行c列的方阵,就可以去和value做内积得到加权后的结果:

        attn = (q @ k.transpose(-2, -1)) * self.temperature  #temperature是可学习参数,控制点积的尺度。α is a learnable scaling parameter to control the magnitude of the dot product
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

注意MDTA其实没有改变输入数据的通道数,因为是通道扩充到3倍,然后对其实的v做的加权。

所以MDTA中的M指的是多头:

Similar to the conventional multi-head SA [17], we divide the number of channels into ‘heads’ and learn separate attention maps in parallel.

D是深度卷积,T是转置,A是attention。

每个level中的transformerBlock的attention中的heads数是相同的,不同level中不同,分别是1,2,4,8.

FeedForward(GDFN)

仍然是使用1x1卷积和depth-wise卷积,通道维度增加,然后使用chunk分成两个分支。

其中分支1经过gelu()的响应作为权重,对分支2进行加权:

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x

gelu是一个比relu更加平滑的激活函数(更容易收敛),引入了非线性(提高模型的表达能力),并且在输入接近0时输出是高斯分布的(提高神经网络的泛化能力)。GELU的设计灵感来自于随机神经网络和高斯误差函数,它试图模仿自然神经元的行为,即输入信号与噪声的交互。

形式上Gated-Dconv Feed-Forward Network (GDFN)和Multi-DConv Head Transposed Self-Attention (MDTA)是很相似的,都是layer norm之后,使用1x1和深度卷积提取特征,然后分几个通道,在通道层面做加权。

区别是GDFN更关注的是信息流,完成不同level之间的信息互补:

Overall, the GDFN controls the information flow through the respective hierarchical levels in our pipeline, thereby allowing each level to focus on the fine details complimentary to the other levels. That is, GDFN offers a distinct role compared to MDTA (focused on enriching features with contextual information)

还有一点区别是GDFN中在project_in增加通道的工程中引入了一个超参数ffn_expansion_factor,控制了通道增加的程度。默认是2.66:

hidden_channels = int(channels * expansion_factor) #会直接截断小数部分,返回不大于原数的最大整数
self.project_in = nn.Conv2d(channels, hidden_channels * 2, kernel_size=1, bias=False)

TransformerBlock

有了MDTA和GDFN,很容易搭出TransformerBlock。不过准确地说,前面的attention和FFN需要有残差的结构才是MDTA和GDFN。所以最终的TransformerBlock:

 class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))

        return x

restormer

最终的网络结构就可以使用上面的子结构搭建出来。

第一个当然是overlapPatchEMbed,使用一层卷积把提取的特征分布在48个通道中。

然后是4 level的主结构,每个level由TransformerBlock构成,只不过数目不同,从上到下分别是4,6,6,8个TransformerBlock。

网络结构:Restormer/basicsr/models/archs/restormer_arch.py at main · swz30/Restormer · GitHub


http://www.niftyadmin.cn/n/5845478.html

相关文章

feign Api接口中注解问题:not annotated with HTTP method type (ex. GET, POST)

Bug Description 在调用Feign api时,出现如下异常: java.lang.IllegalStateException: Method PayFeignSentinelApi#getPayByOrderNo(String) not annotated with HTTPReproduciton Steps 1.启动nacos-pay-provider服务,并启动nacos-pay-c…

Java 的 CopyOnWriteArrayList 和 Collections.synchronizedList 有什么区别?分别有什么优缺点?

参考答案拆解 1. 核心概念对比 特性CopyOnWriteArrayListCollections.synchronizedList实现机制写时复制(Copy-On-Write)方法级同步(synchronized块)锁粒度写操作使用ReentrantLock,读操作无锁所有操作使用对象级锁&…

Elasticsearch:向量搜索的快速介绍

作者:来自 Elastic Valentin Crettaz 本文是三篇系列文章中的第一篇,将深入探讨向量搜索(也称为语义搜索)的复杂性,以及它在 Elasticsearch 中的实现方式。 本文是三篇系列文章中的第一篇,将深入探讨向量搜…

Gitee AI上线:开启免费DeepSeek模型新时代

一、引言 在当今数字化浪潮汹涌澎湃的时代,人工智能(AI)已成为推动各行业变革与发展的核心驱动力。从智能语音助手到图像识别技术,从自动驾驶汽车到金融风险预测,AI的应用无处不在,深刻地改变着我们的生活和…

2025蓝桥杯JAVA编程题练习Day3

1.黛玉泡茶【算法赛】 问题描述 话说林黛玉闲来无事,打算在潇湘馆摆个茶局,邀上宝钗、探春她们一起品茗赏花。黛玉素来讲究,用的茶杯也各有不同,大的小的,高的矮的,煞是好看。这不,她从柜子里…

青少年编程与数学 02-008 Pyhon语言编程基础 25课题、文件操作

青少年编程与数学 02-008 Pyhon语言编程基础 25课题、文件操作 一、文件操作二、文本文件读取文本文件写入文本文件追加文本到文件修改文本文件复制文本文件文件编码错误处理 三、JSON文件读取JSON文件写入JSON文件修改JSON文件处理大型JSON文件错误处理 四、练习1. 将JSON文件…

绿虫光伏仿真设计软件基于Unity3D引擎的革命性突破

绿虫光伏仿真设计软件凭借其技术突破与功能创新,正在重塑光伏电站设计领域的行业范式。以下从技术架构、功能创新及行业价值三个维度深度解析其核心竞争力: 一、颠覆性技术架构 1、游戏引擎赋能工业软件 采用Unity3D引擎构建底层架构,实现影…

第7章《VTK与OPenGL集成》

VTK 本身基于 OpenGL 进行渲染,但如果想要在 VTK 场景中结合 OpenGL 进行底层渲染(如自定义 Shader、直接绘制 OpenGL 图元等),可以通过 VTK 的 OpenGL 接口 实现。这一部分主要讲解 VTK 如何与 OpenGL 交互,包括 使用 OpenGL 直接绘制图形、自定义着色器(Shader)、Fram…