[mw_shl_code=python,true]import argparse
import os

import numpy as np
import matplotlib.pyplot as plt
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

from new_batch import DataSet
parser = argparse.ArgumentParser()
parser.add_argument("-g", "--gpu_number", type=str, default="0")
# 0: train a new model, 1: restore the saved generator and only sample from it
parser.add_argument("-f", "--flag", type=int, default=0)
# Training iterations
parser.add_argument("-e", "--epochs", type=int, default=30)
# Training parameters
parser.add_argument("-b", "--batch_size", type=int, default=10)
parser.add_argument("-lr", "--learning_rate", type=float, default=0.001)
parser.add_argument("-alpha", type=float, default=0.01, help='leaky ReLU slope')
parser.add_argument("-noise_size", type=int, default=20, help='generator noise size')
parser.add_argument("-g_units", type=int, default=6, help='generator hidden units')
parser.add_argument("-d_units", type=int, default=6, help='discriminator hidden units')
parser.add_argument("-smooth", type=float, default=0.1, help='label smoothing')
# parse_args() collects everything registered with add_argument into the args
# namespace, so every option above is available as an attribute of args.
args = parser.parse_args()
def GANbalance(X_train, Y_train):
    """Build a fresh graph and either train the GAN or sample from a saved generator."""
    tf.reset_default_graph()
    with tf.device('/gpu:{0}'.format(args.gpu_number)):
        gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.90)
        config = tf.compat.v1.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
        with tf.compat.v1.Session(config=config) as sess:
            if args.flag == 0:
                # train from scratch and collect the generated samples
                samples = train(sess, 0, X_train, Y_train)
            else:
                # restore the saved generator and only draw samples from it
                samples = train(sess, 1, X_train, Y_train)
    return samples


os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_number
k = 1  # width of the conditioning label vector fed to generator and discriminator
def inputs(real_size, noise_size):
    # Placeholders for the real feature vectors, the generator noise and the real labels.
    real_digit = tf.placeholder(tf.float32, [None, k], name='real_digit')
    real_data = tf.placeholder(tf.float32, [None, real_size], name='real_data')
    noise_data = tf.placeholder(tf.float32, [None, noise_size], name='noise_data')
    return real_data, noise_data, real_digit
# Called as generator(real_data_digit, noise_data, g_units, np.shape(dataMat)[1]):
# it generates the feature columns only, without a label column.
def generator(digit, noise_data, n_units, out_dim, reuse=False, alpha=0.01):
    """
    Generator.
    digit: conditioning labels
    noise_data: generator input noise
    n_units: number of hidden units
    out_dim: size of the generated tensor (the number of real features)
    alpha: leaky ReLU slope
    """
    with tf.variable_scope("generator", reuse=reuse):
        # concatenate the real labels with the noise so generation is conditioned on the label
        concatenated_data_digit = tf.concat([digit, noise_data], 1)
        # hidden layer
        hidden1 = tf.layers.dense(concatenated_data_digit, n_units)
        # leaky ReLU: unlike ReLU, which zeroes all negative values,
        # leaky ReLU gives negative inputs a small non-zero slope
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        # dropout
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
        # logits & outputs
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)
        return outputs
# Called as discriminator(real_data_digit, data, d_units); the second call on the
# generated data passes reuse=True so real and fake samples share the same weights.
def discriminator(digit, data, n_units, reuse=False, alpha=0.01):
    """
    Discriminator.
    digit: conditioning labels
    data: real or generated feature vectors
    n_units: number of hidden units
    alpha: leaky ReLU slope
    """
    with tf.variable_scope("discriminator", reuse=reuse):
        # concatenate the labels with the (real or generated) data
        concatenated_data_digit = tf.concat([digit, data], 1)
        # hidden layer
        hidden1 = tf.layers.dense(concatenated_data_digit, n_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        # logits & outputs
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs
def train(sess, flag, dataMat, classLabels):
    # Placeholders sized by the feature width, the noise width and the label width.
    real_data, noise_data, real_data_digit = inputs(np.shape(dataMat)[1], args.noise_size)
    # generator
    g_outputs = generator(real_data_digit, noise_data, args.g_units, np.shape(dataMat)[1])
    # discriminator on real data
    d_logits_real, d_outputs_real = discriminator(real_data_digit, real_data, args.d_units)
    # discriminator on generated data (shared weights)
    d_logits_fake, d_outputs_fake = discriminator(real_data_digit, g_outputs, args.d_units, reuse=True)

    # Discriminator loss.
    # sigmoid_cross_entropy_with_logits measures the binary classification error,
    # tf.reduce_mean averages it over the batch, and tf.ones_like / tf.zeros_like
    # build the target labels: real samples are pushed towards 1 (with label
    # smoothing), generated samples towards 0.
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_real, labels=tf.ones_like(d_logits_real)) * (1 - args.smooth))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))
    # total discriminator loss
    d_loss = tf.add(d_loss_real, d_loss_fake)
    # Generator loss: generated samples should be scored as close to 1 as possible.
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)) * (1 - args.smooth))
    # Optimizers
    train_vars = tf.trainable_variables()
    # generator variables
    g_vars = [var for var in train_vars if var.name.startswith("generator")]
    # discriminator variables
    d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
    d_train_opt = tf.train.AdamOptimizer(args.learning_rate).minimize(d_loss, var_list=d_vars)
    g_train_opt = tf.train.AdamOptimizer(args.learning_rate).minimize(g_loss, var_list=g_vars)

    # generated samples collected during training
    samples = []
    # per-epoch losses
    losses = []
    # saver for the generator variables only
    saver = tf.train.Saver(var_list=g_vars)
    if flag == 0:
        # initialise the model parameters and the batch iterator
        ds = DataSet(dataMat, classLabels, 10)
        sess.run(tf.global_variables_initializer())
        for e in range(args.epochs):
            for batch_i in range(dataMat.shape[0] // args.batch_size):
                # batch[0] holds the features, batch[1] the labels
                batch = ds.next_batch(args.batch_size)
                digit = batch[1]
                batch_datas = batch[0].reshape((args.batch_size, 6))  # 6 features, hard-coded
                digits = digit.reshape((args.batch_size, 1))  # real label column
                # Rescale the features from [0, 1] to (-1, 1) to match the tanh output
                # of the generator; real and fake data share the discriminator weights.
                batch_datas = batch_datas * 2 - 1
                # input noise for the generator
                batch_noise = np.random.uniform(-1, 1, size=(args.batch_size, args.noise_size))
                # Run the optimizers: one discriminator step, then one generator step.
                _ = sess.run(d_train_opt,
                             feed_dict={real_data_digit: digits, real_data: batch_datas, noise_data: batch_noise})
                _ = sess.run(g_train_opt, feed_dict={real_data_digit: digits, noise_data: batch_noise})
            # Compute the losses at the end of each epoch.
            train_loss_d = sess.run(d_loss,
                                    feed_dict={real_data_digit: digits,
                                               real_data: batch_datas,
                                               noise_data: batch_noise})
            train_loss_d_real = sess.run(d_loss_real,
                                         feed_dict={real_data_digit: digits,
                                                    real_data: batch_datas,
                                                    noise_data: batch_noise})
            train_loss_d_fake = sess.run(d_loss_fake,
                                         feed_dict={real_data_digit: digits,
                                                    real_data: batch_datas,
                                                    noise_data: batch_noise})
            train_loss_g = sess.run(g_loss,
                                    feed_dict={real_data_digit: digits, noise_data: batch_noise})
            print("Epoch {}/{}...".format(e + 1, args.epochs),
                  "Discriminator Loss: {:.4f}(Real: {:.4f} + Fake: {:.4f})...".format(train_loss_d, train_loss_d_real,
                                                                                      train_loss_d_fake),
                  "Generator Loss: {:.4f}".format(train_loss_g))
            # record each loss
            losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g))
            # Draw a few samples at the end of each epoch for later inspection.
            n_sample = 10
            # uniform noise in [-1, 1), shape (n_sample, noise_size)
            sample_noise = np.random.uniform(-1, 1, size=(n_sample, args.noise_size))
            gen_samples = sess.run(generator(real_data_digit, noise_data, args.g_units, np.shape(dataMat)[1], reuse=True),
                                   feed_dict={real_data_digit: digits, noise_data: sample_noise})
            samples.append(gen_samples)
        # Flatten the per-epoch sample list into one 2-D array.
        sampless = np.array(samples)
        sampless = sampless.reshape(sampless.shape[1] * sampless.shape[0], sampless.shape[2])
        # Save a checkpoint: a binary .ckpt file holding the generator variables.
        saver.save(sess, './checkpoints/generator.ckpt')
        # Save the generated samples to a CSV file for later use.
        # TODO: also record the losses in a table.
        np.savetxt("train_samples.csv", sampless, delimiter=",")
        losses = np.array(losses)
        plt.plot(losses.T[0], label='Discriminator Total Loss')
        plt.plot(losses.T[1], label='Discriminator Real Loss')
        plt.plot(losses.T[2], label='Discriminator Fake Loss')
        plt.plot(losses.T[3], label='Generator')
        plt.title("Training Losses")
        plt.legend()
        plt.show()
    else:
        # Restore the saved generator and only draw samples from it.
        saver.restore(sess, './checkpoints/generator.ckpt')
        sample_noise = np.random.uniform(-1, 1, size=(25, args.noise_size))
        # Build the conditioning labels used for generation.
        digits = np.zeros((25, k))
        for i in range(0, 25):
            j = np.random.randint(0, 9, 1)
            digits[j] = 1
        print(digits)
        gen_samples = sess.run(generator(real_data_digit, noise_data, args.g_units, np.shape(dataMat)[1], reuse=True),
                               feed_dict={real_data_digit: digits, noise_data: sample_noise})
        print("gen_samples", gen_samples)
        sampless = gen_samples
    return sampless
[/mw_shl_code]
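A minimal usage sketch follows. It assumes the script above is saved as balance_gan.py, that new_batch.DataSet is importable and yields (features, labels) batches as used in train(), and that the features form a NumPy array with 6 columns; the file name and the toy data below are illustrative assumptions, not part of the original script.

[mw_shl_code=python,true]# Minimal usage sketch -- the module name and the toy data are assumptions.
import numpy as np
from balance_gan import GANbalance  # assumed file name for the script above

# Toy minority-class data: 50 rows with 6 features, all labelled 1.
X_minority = np.random.rand(50, 6).astype(np.float32)
y_minority = np.ones(50, dtype=np.float32)

# With --flag at its default of 0 this trains the GAN and returns the
# samples collected after each epoch (roughly epochs * 10 rows of 6 features).
synthetic = GANbalance(X_minority, y_minority)
print(synthetic.shape)[/mw_shl_code]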