Dcgan网络实例,DCGAN

发表时间:2020-10-15

关于keras包的导入会出现报错问题tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable _AnonymousVar42 from Container: localhost.
关于这个问题可以尝试两种方法:
1.降低keras包的版本到2.0.0以下。
2.采用如下导入tensorflow包的格式如下代码所示。本人仅尝试了第二种方法,可以运行。

from __future__ import print_function, division

from tensorflow.keras.datasets import mnist

from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
from tensorflow.python.keras.layers.convolutional import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys

import numpy as np

class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build and compile the generator
        self.generator = self.build_generator()
        self.generator.compile(loss='binary_crossentropy', optimizer=optimizer)

        # The generator takes noise as input and generated imgs
        z = Input(shape=(100,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The valid takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The combined model  (stacked generator and discriminator) takes
        # noise as input => generates images => determines validity
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):

        noise_shape = (100,)

        model = Sequential()

        model.add(Dense(256, input_shape=noise_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=noise_shape)
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        img_shape = (self.img_rows, self.img_cols, self.channels)

        model = Sequential()

        model.add(Flatten(input_shape=img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, save_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        half_batch = int(batch_size / 2)

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of images
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (half_batch, 100))

            # Generate a half batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, 100))

            # The generator wants the discriminator to label the generated samples
            # as valid (ones)
            valid_y = np.array([1] * batch_size)

            # Train the generator
            g_loss = self.combined.train_on_batch(noise, valid_y)

            # Plot the progress
            if epoch % 1000 == 0:
                print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, 100))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=30000, batch_size=32, save_interval=200)

Keras参考
1.1. 关于Keras模型
Keras有两种类型的模型,顺序模型(Sequential)和泛型模型(Model),两类模型有一些方法是相同的:
model.summary():打印出模型概况
model.get_config():返回包含模型配置信息的Python字典。模型也可以从它的config信息中重构回去

泛型模型接口
Keras的泛型模型为Model,即广义的拥有输入和输出的模型。

常用Model属性
model.layers:组成模型图的各个层
model.inputs:模型的输入张量列表
model.outputs:模型的输出张量列表
Model模型方法

compile
compile(self, optimizer, loss, metrics=[], loss_weights=None, sample_weight_mode=None)
本函数编译模型以供训练,参数有:

optimizer:优化器,为预定义优化器名或优化器对象,参考优化器
loss:目标函数,为预定义损失函数名或一个目标函数,参考目标函数
metrics:列表,包含评估模型在训练和测试时的性能的指标,典型用法是metrics=[‘accuracy’]如果要在多输出模型中为不同的输出指定不同的指标,可像该参数传递一个字典,例如metrics={‘ouput_a’: ‘accuracy’}

predict
predict(self, x, batch_size=32, verbose=0)
本函数按batch获得输入数据对应的输出,其参数有:

函数的返回值是预测值的numpy array

train_on_batch
train_on_batch(self, x, y, class_weight=None, sample_weight=None)
本函数在一个batch的数据上进行一次参数更新

函数返回训练误差的标量值或标量值的list,与evaluate的情形相同。

1.2. 关于优化器
优化器optimizers,优化器是编译Keras模型必要的两个参数之一,可以在调用model.compile()之前初始化一个优化器对象,然后传入该函数(如上所示),也可以在调用model.compile()时传递一个预定义优化器名。在后者情形下,优化器的参数将使用默认值。

Adam
keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
Adam 优化器,默认参数遵循原论文中提供的值,参数:

lr: float >= 0. 学习率。
beta_1: float, 0 < beta < 1. 通常接近于 1。
beta_2: float, 0 < beta < 1. 通常接近于 1。
epsilon: float >= 0. 模糊因子. 若为 None, 默认为 K.epsilon()。
decay: float >= 0. 每次参数更新后学习率衰减值。
amsgrad: boolean. 是否应用此算法的 AMSGrad 变种,来自论文 “On the Convergence of Adam and Beyond”。
1.3. 其他
UpSampling2D
keras.layers.UpSampling2D(size=(2, 2), data_format=None, interpolation=‘nearest’)
2D 输入的上采样层,沿着数据的行和列分别重复 size[0] 和 size[1] 次,参数:

size: 整数,或 2 个整数的元组。 行和列的上采样因子。
data_format: 字符串, channels_last (默认) 或 channels_first 之一, 表示输入中维度的顺序。channels_last 对应输入尺寸为 (batch, height, width, channels), channels_first 对应输入尺寸为 (batch, channels, height, width)。 它默认为从 Keras 配置文件 ~/.keras/keras.json 中 找到的 image_data_format 值。 如果你从未设置它,将使用 “channels_last”。
interpolation: 字符串,nearest 或 bilinear 之一。 注意 CNTK 暂不支持 bilinear upscaling, 以及对于 Theano,只可以使用 size=(2, 2)。
输入尺寸
如果 data_format 为 “channels_last”, 输入 4D 张量,尺寸为 (batch, rows, cols, channels)。
如果 data_format 为 “channels_first”, 输入 4D 张量,尺寸为 (batch, channels, rows, cols)。

输出尺寸
如果 data_format 为 “channels_last”, 输出 4D 张量,尺寸为 (batch, upsampled_rows, upsampled_cols, channels)。
如果 data_format 为 “channels_first”, 输出 4D 张量,尺寸为 (batch, channels, upsampled_rows, upsampled_cols)。

文章来源互联网,如有侵权,请联系管理员删除。邮箱:417803890@qq.com / QQ:417803890

微配音

Python Free

邮箱:417803890@qq.com
QQ:417803890

皖ICP备19001818号
© 2019 copyright www.pythonf.cn - All rights reserved

微信扫一扫关注公众号:

联系方式

Python Free