php中文网 | cnphp.com

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
查看: 637|回复: 0

GAN用于生二维数据

[复制链接]

3138

主题

3148

帖子

1万

积分

管理员

Rank: 9Rank: 9Rank: 9

UID
1
威望
0
积分
7946
贡献
0
注册时间
2021-4-14
最后登录
2024-11-21
在线时间
763 小时
QQ
发表于 2022-5-28 08:24:25 | 显示全部楼层 |阅读模式
[mw_shl_code=python,true]import argparse
import distutils.util
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import pickle
import matplotlib.pyplot as plt
import os
from new_batch import DataSet
from tensorflow.examples.tutorials.mnist import input_data
parser = argparse.ArgumentParser()
parser.add_argument("-g", "--gpu_number", type=str, default="0")
# 1:train 0:test
parser.add_argument("-f", "--flag", type=float, default=0)
# Train Iteration
parser.add_argument("-e", "--epochs", type=int, default=30)
# Train Parameter
parser.add_argument("-b", "--batch_size", type=int, default=10)
parser.add_argument("-lr", "--learning_rate", type=float, default=0.001)
parser.add_argument("-alpha", type=float, default=0.01,help='leaky ReLU')
parser.add_argument("-noise_size", type=int, default=20,help='generator noise size')
parser.add_argument("-g_units", type=int, default=6,help='generator units')
parser.add_argument("-d_units", type=int, default=6,help='discriminator units')
parser.add_argument("-smooth", type=float, default=0.1,help='label smoothing')


# 属性给与args实例: 把parser中设置的所有"add_argument"给返回到args子类实例当中, 那么parser中增加的属性内容都会在args实例中,使用即可。
args = parser.parse_args()



def GANbalance(X_train, Y_train):
    tf.reset_default_graph()
    with tf.device('/gpu:{0}'.format(args.gpu_number)):
        gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.90)
        config = tf.compat.v1.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)

        with tf.compat.v1.Session(config=config) as sess:
            # TRAIN / TEST
            if args.flag == 0:
                samples= train(sess, 0,X_train,Y_train)
                    # show_result()
            else:
                train(sess, 1,X_train,Y_train)
    return samples


os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_number
k = 1

def inputs(real_size, noise_size):

    real_digit = tf.placeholder(tf.float32, [None, k], name='real_digit')
    real_data = tf.placeholder(tf.float32, [None, real_size], name='real_data')

    noise_data = tf.placeholder(tf.float32, [None, noise_size], name='noise_data')

    return real_data, noise_data, real_digit

#generator(real_data_digit, noise_data, g_units, np.shape(dataMat)[1])这里确实生成6个特征,没有标签
def generator(digit, noise_data, n_units, out_dim, reuse=False, alpha=0.01):
    """
    生成器
    noise_data: 生成器的输入
    n_units: 隐层单元个数
    out_dim: 生成器输出tensor的size,这里应该为7
    alpha: leaky ReLU系数
    """
    with tf.variable_scope("generator", reuse=reuse):
        concatenated_data_digit = tf.concat([digit, noise_data], 1)#真实数据标签+噪声数据组合在一起
        # hidden layer
        hidden1 = tf.layers.dense(concatenated_data_digit, n_units)
        # leaky ReLU ,和ReLU区别:ReLU是将所有的负值都设为零,相反,Leaky ReLU是给所有负值赋予一个非零斜率。
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        # dropout
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
        # logits & outputs
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)
        #print("outputs",outputs)

        return outputs

#discriminator(real_data_digit, g_outputs, d_units, reuse=True)
def discriminator(digit, data, n_units, reuse=False, alpha=0.01):
    """
    判别器
    n_units: 隐层结点数量
    alpha: Leaky ReLU系数
    """
    with tf.variable_scope("discriminator", reuse=reuse):
        concatenated_data_digit = tf.concat([digit, data], 1)#真实数据标签和生成器生成的数据组合
        # hidden layer
        hidden1 = tf.layers.dense(concatenated_data_digit, n_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        # logits & outputs
        logits = tf.layers.dense(hidden1, 1)
        #print("logits------",logits)
        outputs = tf.sigmoid(logits)
        return logits, outputs


def train(sess, flag,dataMat,classLabels):
    # tf.reset_default_graph()
    real_data, noise_data, real_data_digit = inputs(np.shape(dataMat)[1], args.noise_size)#分别表示特征数、噪声数7、标签数
    # generator
    g_outputs = generator(real_data_digit, noise_data, args.g_units, np.shape(dataMat)[1])
    # discriminator
    d_logits_real, d_outputs_real = discriminator(real_data_digit, real_data, args.d_units)
    #真实数据和生成器生成的数据的判别
    d_logits_fake, d_outputs_fake = discriminator(real_data_digit, g_outputs, args.d_units, reuse=True)

    # Loss
    # discriminator的loss
    # 识别真实图片
    #tf.reduce_mean计算张量某一维度的平均值
    # 真实图片往1方向优化,sigmoid_cross_entropy_with_logits和sigmoid_softmax_entropy_with_logits一样,只是二分类
    #目的衡量分类任务中的概率误差
    #tf.ones_like创建一个都是1的张量
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                                         labels=tf.ones_like(d_logits_real)) * (
                                             1 - args.smooth))
    # 识别生成的图片
    # 生成图片往0方向优化
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                         labels=tf.zeros_like(d_logits_fake)))
    # 总体loss
    d_loss = tf.add(d_loss_real, d_loss_fake)
    # generator的loss
    # 生成图片尽量往1方向优化
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                    labels=tf.ones_like(d_logits_fake)) * (1 - args.smooth))

    # Optimizer

    train_vars = tf.trainable_variables()

    # generator中的tensor
    g_vars = [var for var in train_vars if var.name.startswith("generator")]
    # discriminator中的tensor
    d_vars = [var for var in train_vars if var.name.startswith("discriminator")]

    # optimizer
    d_train_opt = tf.train.AdamOptimizer(args.learning_rate).minimize(d_loss, var_list=d_vars)
    g_train_opt = tf.train.AdamOptimizer(args.learning_rate).minimize(g_loss, var_list=g_vars)
    # 存储测试样例
    samples = []
    # 存储loss
    losses = []
    # 保存生成器变量
    saver = tf.train.Saver(var_list=g_vars)
    if flag == 0:
        #初始化模型参数
        ds=DataSet(dataMat, classLabels, 10)
        sess.run(tf.global_variables_initializer())
        for e in range(args.epochs):
            for batch_i in range(dataMat.shape[0] // args.batch_size):
                batch = ds.next_batch(args.batch_size)#batch0是特征,batch1是标签
                # print("batch0\n",np.array(batch[0]).shape[0])#结果为3
                # print("batch0\n", np.array(batch[0]).shape[1])#结果为7
                digit = batch[1]#,标签,这里要修改
                # print("batch[1]\n",batch[1])#结果为[1.1.1]
                batch_datas = batch[0].reshape((args.batch_size, 6))
                digits=digit.reshape((args.batch_size, 1))#真实数据标签集
                # print("digits\n",digits)
                # 对图像像素进行scale,这是因为tanh输出的结果介于(-1,1),real和fake图片共享discriminator的参数
                # 把图片灰度0~1变成 -1 到1的值, 以适应generator输出的结果(-1,1)
                batch_datas = batch_datas * 2 - 1

                # generator的输入噪声
                batch_noise = np.random.uniform(-1, 1, size=(args.batch_size, args.noise_size))

                # Run optimizers
                _ = sess.run(d_train_opt,#不断优化辨别器
                             feed_dict={real_data_digit: digits, real_data: batch_datas, noise_data: batch_noise})
                _ = sess.run(g_train_opt, feed_dict={real_data_digit: digits, noise_data: batch_noise})

            # 每一轮结束计算loss
            train_loss_d = sess.run(d_loss,
                                    feed_dict={real_data_digit: digits,
                                               real_data: batch_datas,
                                               noise_data: batch_noise})
            # real img loss
            train_loss_d_real = sess.run(d_loss_real,
                                         feed_dict={real_data_digit: digits,
                                                    real_data: batch_datas,
                                                    noise_data: batch_noise})
            # fake img loss
            train_loss_d_fake = sess.run(d_loss_fake,
                                         feed_dict={real_data_digit: digits,
                                                    real_data: batch_datas,
                                                    noise_data: batch_noise})
            # generator loss
            train_loss_g = sess.run(g_loss,
                                    feed_dict={real_data_digit: digits, noise_data: batch_noise})

            print("Epoch {}/{}...".format(e + 1, args.epochs),
                  "Discriminator Loss: {:.4f}(Real: {:.4f} + Fake: {:.4f})...".format(train_loss_d, train_loss_d_real,
                                                                                      train_loss_d_fake),
                  "Generator Loss: {:.4f}".format(train_loss_g))
            # 记录各类loss值
            losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g))

            # 抽取样本后期进行观察
            n_sample = 10
            #-1采样下界;1采样上届;size为int型或者tuple型,n_sample行noise_size列
            sample_noise = np.random.uniform(-1, 1, size=(n_sample, args.noise_size))
            #print("sample_noise",sample_noise)
            gen_samples = sess.run(generator(real_data_digit, noise_data, args.g_units, np.shape(dataMat)[1], reuse=True),
                                   feed_dict={real_data_digit: digits, noise_data: sample_noise})
            samples.append(gen_samples)
            #print("samples",samples)
            #print("type(samples)",type(samples))  list类型
            sampless=np.array(samples)
            # print("sampless", sampless)
            # print("sampless.shape\n",sampless.shape)
            sampless=sampless.reshape(sampless.shape[1] * sampless.shape[0],sampless.shape[2])
            # print("sampless.reshape\n", sampless.shape)
            # 存储checkpoints,这是一个二进制文件,它保存了权重、偏置项、梯度以及其他所有的变量的取值,扩展名为.ckpt
            saver.save(sess, './checkpoints/generator.ckpt')
        #print("loss",losses)回头生成表格记录一下
        # 将sample的生成数据记录下来  待修改,将数据存储在表格中
        np.savetxt("train_samples.csv", sampless, delimiter=",")
        # write_file = open('train_samples.csv', 'wb')
        # for i in range(len(sampless)):
        #     for j in range(7):
        #         # np.savetxt('train_samples.csv',samples+'\n',fmt ='%d')
        #         write_file.write(sampless[j])
        #     write_file.write('\n')
        # write_file.close()

        # with open('train_samples.pkl', 'wb') as f:
        #     pickle.dump(samples, f)

        # with open('train_loss.txt', 'wb') as f:
        #     pickle.dump(losses, f)

        losses = np.array(losses)
        plt.plot(losses.T[0], label='Discriminator Total Loss')
        plt.plot(losses.T[1], label='Discriminator Real Loss')
        plt.plot(losses.T[2], label='Discriminator Fake Loss')
        plt.plot(losses.T[3], label='Generator')
        plt.title("Training Losses")
        plt.legend()
        plt.show()
    else:
        saver.restore(sess, './checkpoints/generator.ckpt')
        sample_noise = np.random.uniform(-1, 1, size=(25, args.noise_size))

        # 生成标签用户生成图片
        digits = np.zeros((25, k))
        for i in range(0, 25):
            j = np.random.randint(0, 9, 1)
            digits[j] = 1

        print(digits)
        gen_samples = sess.run(generator(real_data_digit, noise_data, args.g_units, np.shape(dataMat)[1], reuse=True),
                               feed_dict={real_data_digit: digits, noise_data: sample_noise})

        print("gen_samples",gen_samples)

    return sampless









[/mw_shl_code]

回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|php中文网 | cnphp.com ( 赣ICP备2021002321号-2 )

GMT+8, 2024-11-22 02:51 , Processed in 0.287604 second(s), 35 queries , Gzip On.

Powered by Discuz! X3.4 Licensed

Copyright © 2001-2020, Tencent Cloud.

申明:本站所有资源皆搜集自网络,相关版权归版权持有人所有,如有侵权,请电邮(fiorkn@foxmail.com)告之,本站会尽快删除。

快速回复 返回顶部 返回列表