admin 发表于 2022-5-25 22:47:40

python神经网络

import os

import torch
fromtorch import nn
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
fromtorchvision.transforms import ToTensor
fromtorch.utils.data import DataLoader,Dataset
import random
h_dim = 200

class Generator(nn.Module):
    def __init__(self):
      super(Generator, self).__init__()
      self.net = nn.Sequential(
                                 nn.Flatten(),
                                 nn.Linear(784, h_dim),
                                 nn.ReLU(),
                                 nn.Linear(h_dim, h_dim),
                                 nn.ReLU(),
                                 nn.Linear(h_dim, h_dim),
                                 nn.ReLU(),
                                 nn.Linear(h_dim, 784))
    def forward(self, x):
      batch_size = x.size(0)
      x = self.net(x)
      x = x.view(batch_size, 1, 28, 28)
      return x
class Descriminator(nn.Module):
    def __init__(self):
      super(Descriminator, self).__init__()
      self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
      )
    def forward(self, x):

      x = self.net(x)

      return x
def main():
    dataset_train = datasets.MNIST("../dataset", train=True, transform=ToTensor(), download=True)
    dataset_test = datasets.MNIST("../dataset", train=False, transform=ToTensor(), download=True)
    print(len(dataset_test))
    batch_size = 32
    datalaod_train = DataLoader(dataset_train, batch_size = batch_size, drop_last=True, shuffle=True)
    datalaod_test = DataLoader(dataset_test, batch_size = batch_size, drop_last=True, shuffle=True)
    device = torch.device("cuda")
    '''paremeter: '''
    Gen = Generator().to(device)
    Des = Descriminator().to(device)
    # if os.path.exists("./Gen.pth"):
    #   statu = torch.load("Gen.pth")
    #   Gen.load_state_dict(statu)
    # if os.path.exists("./Des.pth"):
    #   statu = torch.load("Des.pth")
    #   Des.load_state_dict(statu)
    losser_BCE_des = nn.BCELoss().to(device)
    losser_BCE_gen = nn.BCELoss().to(device)
    Optimizer_Gen = torch.optim.Adam(Gen.parameters(), lr=1e-3)
    Optimizer_Des = torch.optim.Adam(Des.parameters(), lr=1e-3)
    write= SummaryWriter("../logs")
    for epoch in range(1500):
      #train descriminater
      #打乱数据集
      mix_data_lable = []
      iter1 = iter(datalaod_test)
      for num in range(5):
            i,_ = iter1.next()
            mix_data_lable.append((i, torch.tensor([], dtype=torch.float32).transpose(0, 1)))
            z = torch.randn(batch_size, 1, 28, 28).to(device)
            xf = Gen(z).detach()
            mix_data_lable.append((xf, torch.tensor([], dtype=torch.float32).transpose(0, 1)))
      random.shuffle(mix_data_lable)
      # mix_data_lable = list(zip(*mix_data_lable))
      # data = mix_data_lable
      # lable = mix_data_lable
      for i in mix_data_lable:
            imgs, labels = i
            imgs =imgs.to(device)
            labels = labels.to(device)
            pre_des = Des(imgs)
            loss_des = losser_BCE_des(pre_des, labels)
            #三部曲
            Optimizer_Des.zero_grad()
            loss_des.backward()
            Optimizer_Des.step()
      if epoch % 10 == 0:
            write.add_images("gan_train_des", torch.reshape(imgs, (batch_size, 1, 28, 28)), global_step=epoch)
            write.add_scalar("des_loss", loss_des.item(), epoch)
      # for i in datalaod_test:
      #   #real data
      #   img, _ = i
      #   img = img.to(device)
      #   prer = Des(img)
      #   loss_1 = -prer.mean()
      #   #fake data
      #   z = torch.randn(batch_size, 1, 28, 28).to(device)
      #   xf= Gen(z).detach()
      #   pref = Des(xf)
      #   loss_2 = pref.mean()
      #   loss = loss_1 + loss_2
      #   #三部曲
      #   Optimizer_Des.zero_grad()
      #   loss.backward()
      #   Optimizer_Des.step()

      #train generator
      #制作gen用的训练集
      for i in range(50):
            z = torch.randn(1, 1, 28, 28).to(device)
            xf = Gen(z)
            fake_images = xf
            labels = torch.tensor([], dtype=torch.float32).transpose(0, 1)
            fake_images = fake_images.to(device)
            labels = labels.to(device)
            pre_gen = Des(fake_images)
            loss_gen = losser_BCE_gen(pre_gen, labels)
            #三部曲
            Optimizer_Gen.zero_grad()
            loss_gen.backward()
            Optimizer_Gen.step()
      if epoch % 10 == 0:
            write.add_images("gan_train_gen", torch.reshape(fake_images, (1, 1, 28, 28)), global_step=epoch)
            write.add_scalar("gen_loss", loss_gen.item(), epoch)
    # torch.save(Gen.state_dict(), "Gen.pth")
    # torch.save(Des.state_dict(), "Des.pth")

      # for i in datalaod_test:
      #   img, _ = i
      #   z = torch.randn(batch_size, 1, 28, 28).to(device)
      #   xf = Gen(z)
      #   pref = Des(xf)
      #
      #   #三部曲
      #   Optimizer_Gen.zero_grad()
      #   loss_gen.backward()
      #   Optimizer_Gen.step()

       #展示结果

if __name__ == "__main__":
    main()
    # 解决思路1.将gen生成的图片和所有图片混合,且生成标签 2.将混合的图片交给des,用二分类结计算loss 3.优化loss,优化des 4.优化gen,先有gen生成图片,
    #打好label,交付给des,用二分类loss做评价,反向优化gan,使得gen得到的图片能骗过des
页: [1]
查看完整版本: python神经网络