【stable-diffusion企业级教程04】EMA你走,拥抱16G显存!Xformers是未来!
1、回顾
上一讲我们成功的利用deepspeed框架,将sd模型以fp16的精度训练了起来,显存的消耗也降低到18G左右。
不过有小伙伴在下面评论,说手上只有16G的显卡,有没有办法再降低一些训练时的资源需求呢。
今天我们就介绍另外两个方法,进一步降低整体的显存需求。
2、ema
2.1 ema简单介绍
熟悉股票技术分析的同学应该很熟悉ma(移动平均),简单讲就是随着时间的发展,取前K天值的均值来做为一个指标。
那深度学习中的ema(exponential moving average),可以理解为是一种更新模型权重的方法,通过维持一个影子权重的方法,来对模型参数做“平均”,使得模型在最后的测评集中效果更好。
我个人的理解是,batch_gradient_decent可以看做是不同样本共同决定更新方向;而ema则是跨batch来决定更新幅度。
这里不展开,具体的可以看看这两文章:
【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现
理解滑动平均(exponential moving average) - wuliytTaotao - 博客园
2.2 模型选择
ok,回到我们的显存上来,这里提到ema是为了说明,利用ema进行更新的模型参数,会和正常进行更新的模型参数不一致,从而会保存两份参数。
这也提现到了模型上面,在这个链接中,提供了下面两个模型来下载,我们之前使用的是
sd-v1-4-full-ema.ckpt
这个模型。
利用我之前分享的repo中的
test_parameters.py
,可以将模型的参数名字打印出来。我们会发现,在vqa和text-encoder之外,模型保存了两份不同的unet模型。而这个,就是导致我们加载模型耗显存的原因。
那解决的方案也很简单,我们的base模型,可以选择下面链接中
sd-v1-4.ckpt
就可以了。
2.3 训练设置
那除了在加载模型时避免加载ema的权重,我们还需要在训练时避免生成ema相关的权重。通过观察代码,我们会发现,ema权重,是通过调用
on_train_batch_end
这个方法来实现的。
main_torch_deepspeed.py
# 4、Start train
device= torch.device(model_engine.local_rank)
for epoch in range(10*6*5): # 800/8 = 100 2*50/gpu/epoch, 300
for i,bs in enumerate(tqdm(trainloader,desc=f"{epoch}")):
if fp16:
bs['image'] = bs['image'].cuda().half()
else:
bs['image'] = bs['image'].cuda()
loss = model.training_step(bs, i)
model_engine.backward(loss)
model_engine.step()
model.on_train_batch_end()
========》