  • AJGAN

    def train(self):
            """Train StarGAN within a single dataset."""
            # Set data loader.
            data_loader = self.celeba_loader
            data_iter = iter(data_loader)
            # Learning rate cache for decaying.
            g_lr = self.g_lr
            d_lr = self.d_lr
    
            # Start training from scratch or resume training.
            start_iters = 0 
            # Optionally resume from a checkpoint: self.resume_iters is None by default,
            # so training starts from scratch unless an iteration to resume from is set.
            if self.resume_iters:
                start_iters = self.resume_iters  # continue from the previously saved iteration
                self.restore_model(self.resume_iters, 'both')
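            # Note: restore_model is called above with an extra 'both' argument (the
            # author's modification, presumably so netG1/netD1 are reloaded too); the
            # reference StarGAN restore_model loads only G and D. A sketch appears
            # after this listing.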
            
            # Start training.
            print('Start training...')
            start_time = time.time()
            for i in range(start_iters, self.num_iters):
    
                # =================================================================================== #
                #                             1. Preprocess input data                                #
                # =================================================================================== #
    
                # Fetch real images and labels.
                try:
                    x_fixed, x_illumination, label_org = next(data_iter)
                except StopIteration:
                    # The loader is exhausted; restart it for the next pass over the data.
                    data_iter = iter(data_loader)
                    x_fixed, x_illumination, label_org = next(data_iter)
               
                         
                x_fixed = x_fixed.to(self.device)          
                x_illumination = x_illumination.to(self.device) 
                label_org = label_org.to(self.device)
     
    
                # =================================================================================== #
                #               Extra normalization network from gluon (netG1/netD1)                  #
                # =================================================================================== #
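                # netG1/netD1 form a second, paired GAN that maps the illuminated image to
                # a normalized one. backward_D and backward_G are assumed to be pix2pix-style
                # helpers (not shown in this post) computing the conditional-GAN losses for
                # this pair.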
               
                fake_out = self.netG1(x_illumination)
                # update D
                self.set_requires_grad(self.netD1, True)
                self.optimizer_D.zero_grad()
                self.backward_D(x_illumination, fake_out, x_fixed)
                self.optimizer_D.step()
                # update G
                self.set_requires_grad(self.netD1, False)
                self.optimizer_G.zero_grad()
                self.backward_G(x_illumination, fake_out, x_fixed, label_org)
                self.optimizer_G.step()
                
                                 
                # =================================================================================== #
                #                             2. Train the discriminator                              #
                # =================================================================================== #
    
                # Compute loss with real images.
                out_src, out_cls = self.D(x_illumination) # D takes a single image: the real, illuminated image
                # For a batch of 16 real images, the discriminator outputs out_src [16, 1, 2, 2],
                # used to judge real vs. fake, and out_cls [16, 5], the label estimate.
                d_loss_real = - torch.mean(out_src) # minimizing this pushes out_src toward 1 (image term)
                d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset) # label term
                #d_loss_cls = -self.classification_loss(out_cls, label_org, dataset = 'RaFD')
                # measures the gap between the true label and the label estimate
                x_fake = self.G(x_fixed, label_org) # generate a fake image
                out_src, out_cls = self.D(x_fake.detach()) # detach: no gradients flow back into G here
                d_loss_fake = torch.mean(out_src) # push fake images toward 0:
                # the more confidently they are judged fake, the smaller the loss
                # could also be used here to classify the illumination of the generated image:
                #d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)
                # Compute loss for gradient penalty.
                # Sample a random factor alpha, interpolate x_fixed and x_fake with it,
                # feed the result to the discriminator, and penalize its gradient norm.
                alpha = torch.rand(x_fixed.size(0), 1, 1, 1).to(self.device)
                # alpha is one random number per sample, e.g. tensor([[[[ 0.7610]]]])
                x_hat = (alpha * x_fixed.data + (1 - alpha) * x_fake.data).requires_grad_(True)
                # x_hat is an image-sized tensor that varies with alpha
                out_src, _ = self.D(x_hat) # discriminator output at the random interpolate
                d_loss_gp = self.gradient_penalty(out_src, x_hat)
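                # WGAN-GP: d_loss_gp approximates E[(||grad_{x_hat} D(x_hat)||_2 - 1)^2],
                # which keeps D roughly 1-Lipschitz; gradient_penalty is sketched after
                # this listing.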
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
                #print(d_loss_real, d_loss_fake, d_loss_cls, d_loss_gp)
                # e.g. (1.1113e-04) (-3.0589e-05) (13.8667) (0.9953)
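                # reset_grad (sketched after this listing) zeroes both optimizers'
                # gradients so stale generator gradients cannot leak into this update.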
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()
    
                # Logging.
                loss = {}
                loss['D/loss_real'] = d_loss_real.item()
                loss['D/loss_fake'] = d_loss_fake.item()
                loss['D/loss_cls'] = (self.lambda_cls * d_loss_cls).item()
                loss['D/loss_gp'] = (self.lambda_gp * d_loss_gp).item()
                
                # =================================================================================== #
                #                               3. Train the generator                                #
                # =================================================================================== #
                # The generator maps an image from the original domain to the target
                # domain, and maps a target-domain image back to the original domain
                # (reconstruction).
                if (i+1) % self.n_critic == 0: # update the generator once per n_critic (5) discriminator updates
                    # Original-to-target domain.
                    # Feed the real image x_fixed and the label label_org into the
                    # generator to obtain the fake image x_fake.
                    x_fake = self.G(x_fixed, label_org)
                    out_src, out_cls = self.D(x_fake)
                    g_loss_fake = - torch.mean(out_src) # adversarial loss: push fakes to be judged real (1)
                    g_loss_cls = self.classification_loss(out_cls, label_org, self.dataset) # move toward the target label
                    #g_loss_cls = -self.classification_loss(out_cls, label_org, dataset = 'RaFD')
                    # Target-to-original domain.
                    # Reconstruction using the second GAN:
                    #x_reconst = self.G(x_fake, c_org)
                    #g_loss_rec = torch.mean(torch.abs(x_fixed - x_reconst))
                    g_ground_truth = torch.mean(torch.abs(x_illumination - x_fake))
                    # Reconstruct by normalizing with netG1, then re-illuminating with G:
                    g_loss_rec = torch.mean(torch.abs(self.G(self.netG1(x_illumination), label_org) - x_illumination))
                    
                    # Backward and optimize.
                    g_loss = (g_loss_fake + 100 * g_ground_truth
                              + self.lambda_cls * g_loss_cls + self.lambda_rec * g_loss_rec)
                    #print(g_loss_fake, g_ground_truth, g_loss_cls, g_loss_rec)
                    # e.g. tensor(-0.4776) tensor(0.4306) tensor(5.2388) tensor(0.4283)
                    
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()
    
                    # Logging.
                    loss['G/loss_fake'] = g_loss_fake.item()
                    loss['G/loss_gt'] = (100 * g_ground_truth).item() # log with the same weight (100) used in g_loss
                    loss['G/loss_rec'] = (self.lambda_rec * g_loss_rec).item()
                    loss['G/loss_cls'] = g_loss_cls.item()
    
                # =================================================================================== #
                #                                 4. Miscellaneous                                    #
                # =================================================================================== #
    
                # Print out training information.
                if (i+1) % self.log_step == 0:
                    et = time.time() - start_time
                    et = str(datetime.timedelta(seconds=et))[:-7]
                    log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)
    
                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, i+1)
    
                # Translate fixed images for debugging: save sample images every 100 iterations.
                if (i+1) % self.sample_step == 0:
                    with torch.no_grad():
                        x_fake_list = [x_fixed]
                       
                        x_fake_list.append(self.G(x_fixed, label_org))
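                        # dim=3 concatenates along the width axis (NCHW), so the real
                        # image and its translation are saved side by side.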
                        x_concat = torch.cat(x_fake_list, dim=3)
                        sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                        save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                        print('Saved real and fake images into {}...'.format(sample_path))
    
                # Save model checkpoints every 100 iterations.
                if (i+1) % self.model_save_step == 0:
                    G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                    D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                    torch.save(self.G.state_dict(), G_path)
                    torch.save(self.D.state_dict(), D_path)
                    G1_path = os.path.join(self.model_save_dir, '{}-G1.ckpt'.format(i+1))
                    D1_path = os.path.join(self.model_save_dir, '{}-D1.ckpt'.format(i+1))
                    torch.save(self.netG1.state_dict(), G1_path)
                    torch.save(self.netD1.state_dict(), D1_path)
                    print('Saved model checkpoints into {}...'.format(self.model_save_dir))
    
                # Decay learning rates.
                if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                    g_lr -= (self.g_lr / float(self.num_iters_decay))
                    d_lr -= (self.d_lr / float(self.num_iters_decay))
                    self.update_lr(g_lr, d_lr)
                    print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
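
    The helper methods called in train() are not shown in the post. For context, here is a minimal sketch of them, assuming the reference StarGAN solver for classification_loss, gradient_penalty, reset_grad, denorm, and update_lr, the pix2pix template for set_requires_grad, and a guessed body for the author's modified restore_model(iters, 'both'). It also assumes import os, import torch, and import torch.nn.functional as F at module level.

    # Sketch of the solver helpers used above; these mirror the reference StarGAN
    # solver (and pix2pix for set_requires_grad) and are assumptions wherever this
    # post's modified code differs.
    def classification_loss(self, logit, target, dataset='CelebA'):
        """Binary CE for multi-label CelebA, softmax CE for single-label datasets."""
        if dataset == 'CelebA':
            return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)
        return F.cross_entropy(logit, target)

    def gradient_penalty(self, y, x):
        """WGAN-GP term: mean((||dy/dx||_2 - 1)^2) over the batch."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=weight,
                                   retain_graph=True, create_graph=True,
                                   only_inputs=True)[0]
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def reset_grad(self):
        """Zero the gradients of both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def denorm(self, x):
        """Map images from [-1, 1] back to [0, 1] for saving."""
        return ((x + 1) / 2).clamp_(0, 1)

    def update_lr(self, g_lr, d_lr):
        """Apply the decayed learning rates to both optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def set_requires_grad(self, nets, requires_grad=False):
        """Freeze or unfreeze networks (pix2pix-style helper)."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def restore_model(self, resume_iters, which='both'):
        """Reload saved weights; 'both' is assumed to mean G/D plus netG1/netD1."""
        print('Loading the trained models from step {}...'.format(resume_iters))
        load = lambda name: torch.load(
            os.path.join(self.model_save_dir, '{}-{}.ckpt'.format(resume_iters, name)),
            map_location=lambda storage, loc: storage)
        self.G.load_state_dict(load('G'))
        self.D.load_state_dict(load('D'))
        if which == 'both':
            self.netG1.load_state_dict(load('G1'))
            self.netD1.load_state_dict(load('D1'))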
                    
  • Original post: https://www.cnblogs.com/hxjbc/p/10018692.html