威望0
积分7976
贡献0
在线时间763 小时
UID1
注册时间2021-4-14
最后登录2024-11-24
管理员
- UID
- 1
- 威望
- 0
- 积分
- 7976
- 贡献
- 0
- 注册时间
- 2021-4-14
- 最后登录
- 2024-11-24
- 在线时间
- 763 小时
|
import os
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader,Dataset
import random
h_dim = 200  # width of every hidden layer in both the generator and discriminator MLPs
class Generator(nn.Module):
    """MLP generator for a simple MNIST GAN.

    Maps a noise tensor of shape (B, 1, 28, 28) through a 3-hidden-layer
    fully-connected network and reshapes the 784-dim output back into a
    fake image of shape (B, 1, 28, 28).

    Args:
        hidden_dim: width of each hidden layer. Defaults to 200, matching
            the previous hard-coded module-level ``h_dim``, so existing
            ``Generator()`` callers are unaffected.
    """

    def __init__(self, hidden_dim=200):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),                       # (B, 1, 28, 28) -> (B, 784)
            nn.Linear(784, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 784),
        )

    def forward(self, x):
        """Return generated images of shape (B, 1, 28, 28) for noise ``x``."""
        batch_size = x.size(0)
        out = self.net(x)
        return out.view(batch_size, 1, 28, 28)
class Descriminator(nn.Module):
    """MLP discriminator for a simple MNIST GAN.

    Flattens a (B, 1, 28, 28) image batch, passes it through a
    3-hidden-layer fully-connected network, and outputs a per-image
    real/fake probability of shape (B, 1) in [0, 1] (final Sigmoid).

    Args:
        hidden_dim: width of each hidden layer. Defaults to 200, matching
            the previous hard-coded module-level ``h_dim``, so existing
            ``Descriminator()`` callers are unaffected.

    Note (review): the class name is a misspelling of "Discriminator";
    kept as-is because callers reference it by this name.
    """

    def __init__(self, hidden_dim=200):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),                       # (B, 1, 28, 28) -> (B, 784)
            nn.Linear(784, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),                       # squash logit to a probability
        )

    def forward(self, x):
        """Return real/fake probabilities of shape (B, 1) for images ``x``."""
        return self.net(x)
def main():
    """Train a simple MLP GAN on MNIST and log progress to TensorBoard.

    Each epoch: (1) the discriminator is trained on a shuffled mix of five
    real batches (label 1) and one generated batch (label 0) with BCE loss;
    (2) the generator is trained for 50 steps to make the discriminator
    score its output as real.

    NOTE(review): the discriminator's real batches are drawn from the MNIST
    *test* split (``datalaod_test``) — presumably to keep iterations cheap,
    but confirm; normally the train split would be used.
    """
    dataset_train = datasets.MNIST("../dataset", train=True, transform=ToTensor(), download=True)
    dataset_test = datasets.MNIST("../dataset", train=False, transform=ToTensor(), download=True)
    print(len(dataset_test))
    batch_size = 32
    datalaod_train = DataLoader(dataset_train, batch_size=batch_size, drop_last=True, shuffle=True)
    datalaod_test = DataLoader(dataset_test, batch_size=batch_size, drop_last=True, shuffle=True)
    # Fall back to CPU when CUDA is unavailable instead of crashing at runtime.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Noise / image tensors have shape (32, 1, 28, 28).
    Gen = Generator().to(device)
    Des = Descriminator().to(device)
    losser_BCE_des = nn.BCELoss().to(device)
    losser_BCE_gen = nn.BCELoss().to(device)
    Optimizer_Gen = torch.optim.Adam(Gen.parameters(), lr=1e-3)
    Optimizer_Des = torch.optim.Adam(Des.parameters(), lr=1e-3)
    write = SummaryWriter("../logs")
    for epoch in range(1500):
        # ---- Train the discriminator ----
        # Build a shuffled pool of (images, labels): 5 real batches labelled 1
        # and one generated batch labelled 0.
        mix_data_lable = []
        iter1 = iter(datalaod_test)
        for num in range(5):
            # BUG FIX: `iter1.next()` is Python 2 syntax and raises
            # AttributeError on Python 3; use the builtin next().
            imgs_real, _ = next(iter1)
            mix_data_lable.append((imgs_real, torch.ones(batch_size, 1)))
        z = torch.randn(batch_size, 1, 28, 28).to(device)
        # detach(): discriminator updates must not backprop into the generator.
        xf = Gen(z).detach()
        mix_data_lable.append((xf, torch.zeros(batch_size, 1)))
        random.shuffle(mix_data_lable)
        for imgs, labels in mix_data_lable:
            imgs = imgs.to(device)
            labels = labels.to(device)
            pre_des = Des(imgs)
            loss_des = losser_BCE_des(pre_des, labels)
            # Standard step: clear grads, backprop, update.
            Optimizer_Des.zero_grad()
            loss_des.backward()
            Optimizer_Des.step()
        if epoch % 10 == 0:
            # Log the last batch seen and its loss for this epoch.
            write.add_images("gan_train_des", torch.reshape(imgs, (batch_size, 1, 28, 28)), global_step=epoch)
            write.add_scalar("des_loss", loss_des.item(), epoch)
        # ---- Train the generator ----
        for step in range(50):
            z = torch.randn(1, 1, 28, 28).to(device)
            fake_images = Gen(z)
            # The generator wants its fakes judged as real (label 1).
            labels = torch.ones(1, 1).to(device)
            pre_gen = Des(fake_images)
            loss_gen = losser_BCE_gen(pre_gen, labels)
            Optimizer_Gen.zero_grad()
            loss_gen.backward()
            Optimizer_Gen.step()
        if epoch % 10 == 0:
            write.add_images("gan_train_gen", torch.reshape(fake_images, (1, 1, 28, 28)), global_step=epoch)
            write.add_scalar("gen_loss", loss_gen.item(), epoch)
# Script entry point: run the full GAN training loop when executed directly.
if __name__ == "__main__":
    main()
# Approach: 1. mix the generator's output with real images and attach labels;
# 2. feed the mixed batch to the discriminator and compute a binary-classification
# loss; 3. optimize that loss to update the discriminator; 4. train the generator:
# generate images, label them as real, hand them to the discriminator, evaluate
# with the binary-classification loss, and backpropagate so the generator's
# images learn to fool the discriminator.
|