php中文网 | cnphp.com

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
查看: 574|回复: 0

python神经网络

[复制链接]

3150

主题

3160

帖子

1万

积分

管理员

Rank: 9Rank: 9Rank: 9

UID
1
威望
0
积分
7976
贡献
0
注册时间
2021-4-14
最后登录
2024-11-24
在线时间
763 小时
QQ
发表于 2022-5-25 22:47:40 | 显示全部楼层 |阅读模式
[mw_shl_code=python,true]import os

import torch
from  torch import nn
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from  torchvision.transforms import ToTensor
from  torch.utils.data import DataLoader,Dataset
import random
h_dim = 200

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
                                 nn.Flatten(),
                                 nn.Linear(784, h_dim),
                                 nn.ReLU(),
                                 nn.Linear(h_dim, h_dim),
                                 nn.ReLU(),
                                 nn.Linear(h_dim, h_dim),
                                 nn.ReLU(),
                                 nn.Linear(h_dim, 784))
    def forward(self, x):
        batch_size = x.size(0)
        x = self.net(x)
        x = x.view(batch_size, 1, 28, 28)
        return x
class Descriminator(nn.Module):
    def __init__(self):
        super(Descriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )
    def forward(self, x):

        x = self.net(x)

        return x
def main():
    dataset_train = datasets.MNIST("../dataset", train=True, transform=ToTensor(), download=True)
    dataset_test = datasets.MNIST("../dataset", train=False, transform=ToTensor(), download=True)
    print(len(dataset_test))
    batch_size = 32
    datalaod_train = DataLoader(dataset_train, batch_size = batch_size, drop_last=True, shuffle=True)
    datalaod_test = DataLoader(dataset_test, batch_size = batch_size, drop_last=True, shuffle=True)
    device = torch.device("cuda")
    '''paremeter: [32, 1, 28, 28]'''
    Gen = Generator().to(device)
    Des = Descriminator().to(device)
    # if os.path.exists("./Gen.pth"):
    #     statu = torch.load("Gen.pth")
    #     Gen.load_state_dict(statu)
    # if os.path.exists("./Des.pth"):
    #     statu = torch.load("Des.pth")
    #     Des.load_state_dict(statu)
    losser_BCE_des = nn.BCELoss().to(device)
    losser_BCE_gen = nn.BCELoss().to(device)
    Optimizer_Gen = torch.optim.Adam(Gen.parameters(), lr=1e-3)
    Optimizer_Des = torch.optim.Adam(Des.parameters(), lr=1e-3)
    write  = SummaryWriter("../logs")
    for epoch in range(1500):
        #train descriminater
        #打乱数据集
        mix_data_lable = []
        iter1 = iter(datalaod_test)
        for num in range(5):
            i,_ = iter1.next()
            mix_data_lable.append((i, torch.tensor([[1 for k in range(batch_size)]], dtype=torch.float32).transpose(0, 1)))
            z = torch.randn(batch_size, 1, 28, 28).to(device)
            xf = Gen(z).detach()
            mix_data_lable.append((xf, torch.tensor([[0 for k in range(batch_size)]], dtype=torch.float32).transpose(0, 1)))
        random.shuffle(mix_data_lable)
        # mix_data_lable = list(zip(*mix_data_lable))
        # data = mix_data_lable[0]
        # lable = mix_data_lable[1]
        for i in mix_data_lable:
            imgs, labels = i
            imgs =imgs.to(device)
            labels = labels.to(device)
            pre_des = Des(imgs)
            loss_des = losser_BCE_des(pre_des, labels)
            #三部曲
            Optimizer_Des.zero_grad()
            loss_des.backward()
            Optimizer_Des.step()
        if epoch % 10 == 0:
            write.add_images("gan_train_des", torch.reshape(imgs, (batch_size, 1, 28, 28)), global_step=epoch)
            write.add_scalar("des_loss", loss_des.item(), epoch)
        # for i in datalaod_test:
        #     #real data
        #     img, _ = i
        #     img = img.to(device)
        #     prer = Des(img)
        #     loss_1 = -prer.mean()
        #     #fake data
        #     z = torch.randn(batch_size, 1, 28, 28).to(device)
        #     xf= Gen(z).detach()
        #     pref = Des(xf)
        #     loss_2 = pref.mean()
        #     loss = loss_1 + loss_2
        #     #三部曲
        #     Optimizer_Des.zero_grad()
        #     loss.backward()
        #     Optimizer_Des.step()

        #train generator
        #制作gen用的训练集
        for i in range(50):
            z = torch.randn(1, 1, 28, 28).to(device)
            xf = Gen(z)
            fake_images = xf
            labels = torch.tensor([[1 for k in range(1)]], dtype=torch.float32).transpose(0, 1)
            fake_images = fake_images.to(device)
            labels = labels.to(device)
            pre_gen = Des(fake_images)
            loss_gen = losser_BCE_gen(pre_gen, labels)
            #三部曲
            Optimizer_Gen.zero_grad()
            loss_gen.backward()
            Optimizer_Gen.step()
        if epoch % 10 == 0:
            write.add_images("gan_train_gen", torch.reshape(fake_images, (1, 1, 28, 28)), global_step=epoch)
            write.add_scalar("gen_loss", loss_gen.item(), epoch)
    # torch.save(Gen.state_dict(), "Gen.pth")
    # torch.save(Des.state_dict(), "Des.pth")

        # for i in datalaod_test:
        #     img, _ = i
        #     z = torch.randn(batch_size, 1, 28, 28).to(device)
        #     xf = Gen(z)
        #     pref = Des(xf)
        #
        #     #三部曲
        #     Optimizer_Gen.zero_grad()
        #     loss_gen.backward()
        #     Optimizer_Gen.step()

       #展示结果

if __name__ == "__main__":
    main()
    # 解决思路1.将gen生成的图片和所有图片混合,且生成标签 2.将混合的图片交给des,用二分类结计算loss 3.优化loss,优化des 4.优化gen,先有gen生成图片,
    #打好label,交付给des,用二分类loss做评价,反向优化gan,使得gen得到的图片能骗过des[/mw_shl_code]

回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|php中文网 | cnphp.com ( 赣ICP备2021002321号-2 )

GMT+8, 2024-11-24 13:15 , Processed in 0.250459 second(s), 35 queries , Gzip On.

Powered by Discuz! X3.4 Licensed

Copyright © 2001-2020, Tencent Cloud.

申明:本站所有资源皆搜集自网络,相关版权归版权持有人所有,如有侵权,请电邮(fiorkn@foxmail.com)告之,本站会尽快删除。

快速回复 返回顶部 返回列表