admin 发表于 2022-5-28 08:24:25

GAN用于生二维数据

import argparse
import distutils.util
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import pickle
import matplotlib.pyplot as plt
import os
from new_batch import DataSet
from tensorflow.examples.tutorials.mnist import input_data
parser = argparse.ArgumentParser()
parser.add_argument("-g", "--gpu_number", type=str, default="0")
# 1:train 0:test
parser.add_argument("-f", "--flag", type=float, default=0)
# Train Iteration
parser.add_argument("-e", "--epochs", type=int, default=30)
# Train Parameter
parser.add_argument("-b", "--batch_size", type=int, default=10)
parser.add_argument("-lr", "--learning_rate", type=float, default=0.001)
parser.add_argument("-alpha", type=float, default=0.01,help='leaky ReLU')
parser.add_argument("-noise_size", type=int, default=20,help='generator noise size')
parser.add_argument("-g_units", type=int, default=6,help='generator units')
parser.add_argument("-d_units", type=int, default=6,help='discriminator units')
parser.add_argument("-smooth", type=float, default=0.1,help='label smoothing')


# 属性给与args实例: 把parser中设置的所有"add_argument"给返回到args子类实例当中, 那么parser中增加的属性内容都会在args实例中,使用即可。
args = parser.parse_args()



def GANbalance(X_train, Y_train):
    tf.reset_default_graph()
    with tf.device('/gpu:{0}'.format(args.gpu_number)):
      gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.90)
      config = tf.compat.v1.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)

      with tf.compat.v1.Session(config=config) as sess:
            # TRAIN / TEST
            if args.flag == 0:
                samples= train(sess, 0,X_train,Y_train)
                  # show_result()
            else:
                train(sess, 1,X_train,Y_train)
    return samples


os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_number
k = 1

def inputs(real_size, noise_size):

    real_digit = tf.placeholder(tf.float32, , name='real_digit')
    real_data = tf.placeholder(tf.float32, , name='real_data')

    noise_data = tf.placeholder(tf.float32, , name='noise_data')

    return real_data, noise_data, real_digit

#generator(real_data_digit, noise_data, g_units, np.shape(dataMat))这里确实生成6个特征,没有标签
def generator(digit, noise_data, n_units, out_dim, reuse=False, alpha=0.01):
    """
    生成器
    noise_data: 生成器的输入
    n_units: 隐层单元个数
    out_dim: 生成器输出tensor的size,这里应该为7
    alpha: leaky ReLU系数
    """
    with tf.variable_scope("generator", reuse=reuse):
      concatenated_data_digit = tf.concat(, 1)#真实数据标签+噪声数据组合在一起
      # hidden layer
      hidden1 = tf.layers.dense(concatenated_data_digit, n_units)
      # leaky ReLU ,和ReLU区别:ReLU是将所有的负值都设为零,相反,Leaky ReLU是给所有负值赋予一个非零斜率。
      hidden1 = tf.maximum(alpha * hidden1, hidden1)
      # dropout
      hidden1 = tf.layers.dropout(hidden1, rate=0.2)
      # logits & outputs
      logits = tf.layers.dense(hidden1, out_dim)
      outputs = tf.tanh(logits)
      #print("outputs",outputs)

      return outputs

#discriminator(real_data_digit, g_outputs, d_units, reuse=True)
def discriminator(digit, data, n_units, reuse=False, alpha=0.01):
    """
    判别器
    n_units: 隐层结点数量
    alpha: Leaky ReLU系数
    """
    with tf.variable_scope("discriminator", reuse=reuse):
      concatenated_data_digit = tf.concat(, 1)#真实数据标签和生成器生成的数据组合
      # hidden layer
      hidden1 = tf.layers.dense(concatenated_data_digit, n_units)
      hidden1 = tf.maximum(alpha * hidden1, hidden1)
      # logits & outputs
      logits = tf.layers.dense(hidden1, 1)
      #print("logits------",logits)
      outputs = tf.sigmoid(logits)
      return logits, outputs


def train(sess, flag,dataMat,classLabels):
    # tf.reset_default_graph()
    real_data, noise_data, real_data_digit = inputs(np.shape(dataMat), args.noise_size)#分别表示特征数、噪声数7、标签数
    # generator
    g_outputs = generator(real_data_digit, noise_data, args.g_units, np.shape(dataMat))
    # discriminator
    d_logits_real, d_outputs_real = discriminator(real_data_digit, real_data, args.d_units)
    #真实数据和生成器生成的数据的判别
    d_logits_fake, d_outputs_fake = discriminator(real_data_digit, g_outputs, args.d_units, reuse=True)

    # Loss
    # discriminator的loss
    # 识别真实图片
    #tf.reduce_mean计算张量某一维度的平均值
    # 真实图片往1方向优化,sigmoid_cross_entropy_with_logits和sigmoid_softmax_entropy_with_logits一样,只是二分类
    #目的衡量分类任务中的概率误差
    #tf.ones_like创建一个都是1的张量
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                                         labels=tf.ones_like(d_logits_real)) * (
                                             1 - args.smooth))
    # 识别生成的图片
    # 生成图片往0方向优化
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                         labels=tf.zeros_like(d_logits_fake)))
    # 总体loss
    d_loss = tf.add(d_loss_real, d_loss_fake)
    # generator的loss
    # 生成图片尽量往1方向优化
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                  labels=tf.ones_like(d_logits_fake)) * (1 - args.smooth))

    # Optimizer

    train_vars = tf.trainable_variables()

    # generator中的tensor
    g_vars =
    # discriminator中的tensor
    d_vars =

    # optimizer
    d_train_opt = tf.train.AdamOptimizer(args.learning_rate).minimize(d_loss, var_list=d_vars)
    g_train_opt = tf.train.AdamOptimizer(args.learning_rate).minimize(g_loss, var_list=g_vars)
    # 存储测试样例
    samples = []
    # 存储loss
    losses = []
    # 保存生成器变量
    saver = tf.train.Saver(var_list=g_vars)
    if flag == 0:
      #初始化模型参数
      ds=DataSet(dataMat, classLabels, 10)
      sess.run(tf.global_variables_initializer())
      for e in range(args.epochs):
            for batch_i in range(dataMat.shape // args.batch_size):
                batch = ds.next_batch(args.batch_size)#batch0是特征,batch1是标签
                # print("batch0\n",np.array(batch).shape)#结果为3
                # print("batch0\n", np.array(batch).shape)#结果为7
                digit = batch#,标签,这里要修改
                # print("batch\n",batch)#结果为
                batch_datas = batch.reshape((args.batch_size, 6))
                digits=digit.reshape((args.batch_size, 1))#真实数据标签集
                # print("digits\n",digits)
                # 对图像像素进行scale,这是因为tanh输出的结果介于(-1,1),real和fake图片共享discriminator的参数
                # 把图片灰度0~1变成 -1 到1的值, 以适应generator输出的结果(-1,1)
                batch_datas = batch_datas * 2 - 1

                # generator的输入噪声
                batch_noise = np.random.uniform(-1, 1, size=(args.batch_size, args.noise_size))

                # Run optimizers
                _ = sess.run(d_train_opt,#不断优化辨别器
                           feed_dict={real_data_digit: digits, real_data: batch_datas, noise_data: batch_noise})
                _ = sess.run(g_train_opt, feed_dict={real_data_digit: digits, noise_data: batch_noise})

            # 每一轮结束计算loss
            train_loss_d = sess.run(d_loss,
                                    feed_dict={real_data_digit: digits,
                                             real_data: batch_datas,
                                             noise_data: batch_noise})
            # real img loss
            train_loss_d_real = sess.run(d_loss_real,
                                       feed_dict={real_data_digit: digits,
                                                    real_data: batch_datas,
                                                    noise_data: batch_noise})
            # fake img loss
            train_loss_d_fake = sess.run(d_loss_fake,
                                       feed_dict={real_data_digit: digits,
                                                    real_data: batch_datas,
                                                    noise_data: batch_noise})
            # generator loss
            train_loss_g = sess.run(g_loss,
                                    feed_dict={real_data_digit: digits, noise_data: batch_noise})

            print("Epoch {}/{}...".format(e + 1, args.epochs),
                  "Discriminator Loss: {:.4f}(Real: {:.4f} + Fake: {:.4f})...".format(train_loss_d, train_loss_d_real,
                                                                                    train_loss_d_fake),
                  "Generator Loss: {:.4f}".format(train_loss_g))
            # 记录各类loss值
            losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g))

            # 抽取样本后期进行观察
            n_sample = 10
            #-1采样下界;1采样上届;size为int型或者tuple型,n_sample行noise_size列
            sample_noise = np.random.uniform(-1, 1, size=(n_sample, args.noise_size))
            #print("sample_noise",sample_noise)
            gen_samples = sess.run(generator(real_data_digit, noise_data, args.g_units, np.shape(dataMat), reuse=True),
                                 feed_dict={real_data_digit: digits, noise_data: sample_noise})
            samples.append(gen_samples)
            #print("samples",samples)
            #print("type(samples)",type(samples))list类型
            sampless=np.array(samples)
            # print("sampless", sampless)
            # print("sampless.shape\n",sampless.shape)
            sampless=sampless.reshape(sampless.shape * sampless.shape,sampless.shape)
            # print("sampless.reshape\n", sampless.shape)
            # 存储checkpoints,这是一个二进制文件,它保存了权重、偏置项、梯度以及其他所有的变量的取值,扩展名为.ckpt
            saver.save(sess, './checkpoints/generator.ckpt')
      #print("loss",losses)回头生成表格记录一下
      # 将sample的生成数据记录下来待修改,将数据存储在表格中
      np.savetxt("train_samples.csv", sampless, delimiter=",")
      # write_file = open('train_samples.csv', 'wb')
      # for i in range(len(sampless)):
      #   for j in range(7):
      #         # np.savetxt('train_samples.csv',samples+'\n',fmt ='%d')
      #         write_file.write(sampless)
      #   write_file.write('\n')
      # write_file.close()

      # with open('train_samples.pkl', 'wb') as f:
      #   pickle.dump(samples, f)

      # with open('train_loss.txt', 'wb') as f:
      #   pickle.dump(losses, f)

      losses = np.array(losses)
      plt.plot(losses.T, label='Discriminator Total Loss')
      plt.plot(losses.T, label='Discriminator Real Loss')
      plt.plot(losses.T, label='Discriminator Fake Loss')
      plt.plot(losses.T, label='Generator')
      plt.title("Training Losses")
      plt.legend()
      plt.show()
    else:
      saver.restore(sess, './checkpoints/generator.ckpt')
      sample_noise = np.random.uniform(-1, 1, size=(25, args.noise_size))

      # 生成标签用户生成图片
      digits = np.zeros((25, k))
      for i in range(0, 25):
            j = np.random.randint(0, 9, 1)
            digits = 1

      print(digits)
      gen_samples = sess.run(generator(real_data_digit, noise_data, args.g_units, np.shape(dataMat), reuse=True),
                               feed_dict={real_data_digit: digits, noise_data: sample_noise})

      print("gen_samples",gen_samples)

    return sampless









页: [1]
查看完整版本: GAN用于生二维数据