基于Keras的图像分割:使用MobileNetV2+U-Net,在LIP数据集上做人体分割

发表时间:2020-09-29

先看一下效果,虽然算不上很好:

直接上代码:

import tensorflow as tf
import matplotlib.pyplot as plt
import  os
import time
import numpy as np
import io
import PIL
from IPython.display import clear_output
import cv2
import sys
sys.path.append("/opt/LIP/examples")
from tensorflow_examples.models.pix2pix import pix2pix

# Side length (in pixels) used by the loaders when resizing images and masks.
IMG_WIDTH = 128
# This line used to re-assign IMG_WIDTH a second time; it is now the matching
# height constant. The loaders still resize to a square of IMG_WIDTH.
IMG_HEIGHT = 128
# Root folders of the input images and label masks; each one is expected to
# contain 'train/' and 'val/' sub-folders.
IM_PATH = '/opt/LIP/images/'
MS_PATH = '/opt/LIP/masks/'
# Number of filters of the model's final layer (passed to unet_model).
OUTPUT_CHANNELS = 20
# Number of passes over the training data.
EPOCHS = 20
# Number of (image, mask) pairs per training batch.
BATCH_SIZE = 256

def display(display_list):
    """Plot the given images side by side in a single row and show the figure.

    Args:
        display_list: Sequence of image-like arrays/tensors. Titles are
            assigned by position: input image, true mask, predicted mask.
            More than three entries raises IndexError on the title lookup.
    """
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    panel_count = len(display_list)
    plt.figure(figsize=(15, 15))
    for index, picture in enumerate(display_list):
        plt.subplot(1, panel_count, index + 1)
        plt.title(titles[index])
        # array_to_img converts the array/tensor into a PIL image for imshow.
        rendered = tf.keras.preprocessing.image.array_to_img(picture)
        plt.imshow(rendered)
        plt.axis('off')
    plt.show()


def load_input(image_file):
    """Read a JPEG file and return it resized and scaled to [-1, 1].

    Args:
        image_file: Path of the JPEG file (string or scalar string tensor).

    Returns:
        A 3-channel image tensor of spatial size IMG_WIDTH x IMG_WIDTH whose
        values are mapped from the 0..255 pixel range to [-1, 1].
    """
    # Debug output of the path (or of the symbolic tensor when traced by
    # Dataset.map).
    print(image_file)
    raw_bytes = tf.io.read_file(image_file)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, [IMG_WIDTH, IMG_WIDTH])
    # Normalize the pixel values to [-1, 1].
    return (resized / 127.5) - 1
def load_mask(image_file):
    """Read a single-channel PNG label mask and resize it.

    The mask stores one class id per pixel, so it is resized with
    nearest-neighbour sampling. The default bilinear interpolation of
    tf.image.resize would average neighbouring ids and produce fractional
    values that do not correspond to any class.

    Args:
        image_file: Path of the PNG mask (string or scalar string tensor).

    Returns:
        A float32 tensor of shape [IMG_WIDTH, IMG_WIDTH, 1] whose values are
        the class ids found in the source mask.
    """
    raw_bytes = tf.io.read_file(image_file)
    mask = tf.image.decode_png(raw_bytes, channels=1)
    mask = tf.image.resize(
        mask,
        [IMG_WIDTH, IMG_WIDTH],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    # Nearest-neighbour resizing keeps the input dtype; cast so callers still
    # receive float32, as they did with the previous bilinear resize.
    return tf.cast(mask, tf.float32)
def load(image_file, mask_file):
    """Load one training example.

    Args:
        image_file: Path of the input JPEG image.
        mask_file: Path of the corresponding PNG label mask.

    Returns:
        Tuple ``(image, mask)`` produced by load_input and load_mask.
    """
    return load_input(image_file), load_mask(mask_file)

# --- Training split -------------------------------------------------------
# Folders holding the training images and their label masks.
train_image_path = IM_PATH + 'train/'
train_mask_path = MS_PATH + 'train/'

# File names are sorted so that the i-th image is paired with the i-th mask;
# this relies on both folders sorting into the same order.
train_images = sorted(os.listdir(train_image_path))
train_masks = sorted(os.listdir(train_mask_path))

# Full paths of every image / mask file.
train_ls_images = [train_image_path + file_name for file_name in train_images]
train_ls_masks = [train_mask_path + file_name for file_name in train_masks]

# Build a dataset of (image_path, mask_path) pairs, decode each pair with
# `load`, and group the result into batches.
train_images = tf.constant(train_ls_images)
train_labels = tf.constant(train_ls_masks)
train_data = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels)
).map(load, num_parallel_calls=4)
train_batched_data = train_data.batch(BATCH_SIZE)

# --- Validation split -----------------------------------------------------
# Folders holding the validation images and their label masks.
val_image_path = IM_PATH + 'val/'
val_mask_path = MS_PATH + 'val/'

# File names are sorted so that the i-th image is paired with the i-th mask;
# this relies on both folders sorting into the same order.
val_images = sorted(os.listdir(val_image_path))
val_masks = sorted(os.listdir(val_mask_path))

# Full paths of every image / mask file.
val_ls_images = [val_image_path + file_name for file_name in val_images]
val_ls_masks = [val_mask_path + file_name for file_name in val_masks]

# Build a dataset of (image_path, mask_path) pairs, decode each pair with
# `load`, and group the result into batches.
val_images = tf.constant(val_ls_images)
val_labels = tf.constant(val_ls_masks)
val_data = tf.data.Dataset.from_tensor_slices(
    (val_images, val_labels)
).map(load, num_parallel_calls=4)
val_batched_data = val_data.batch(BATCH_SIZE)


# Pre-trained MobileNetV2 without its classification head, used as encoder.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=[128, 128, 3],
    include_top=False,
)

# Encoder layers whose activations are tapped; the trailing comments give
# the spatial size of each activation for a 128x128 input.
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
layers = [base_model.get_layer(layer_name).output for layer_name in layer_names]

# Feature-extraction model: one input, one output per tapped layer.
# Its weights are frozen, so only the decoder is trained.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)
down_stack.trainable = False

# Decoder: four upsample blocks with kernel size 3, taking the feature map
# from 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64.
up_stack = [
    pix2pix.upsample(filter_count, 3)
    for filter_count in (512, 256, 128, 64)
]

def unet_model(output_channels):
    """Assemble the U-Net style model from `down_stack` and `up_stack`.

    Args:
        output_channels: Number of filters of the final transposed
            convolution, i.e. the channel count of the model output.

    Returns:
        A tf.keras.Model taking 128x128x3 inputs; the final layer has no
        activation.
    """
    inputs = tf.keras.layers.Input(shape=[128, 128, 3])

    # Encoder: the last output is the bottleneck, the others are skips.
    *skip_features, x = down_stack(inputs)

    # Decoder: upsample, then concatenate the skip connection of matching
    # resolution (skips are consumed from deepest to shallowest).
    for upsample_block, skip in zip(up_stack, reversed(skip_features)):
        x = upsample_block(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    # Final transposed convolution: 64x64 -> 128x128.
    x = tf.keras.layers.Conv2DTranspose(
        output_channels, 3, strides=2, padding='same'
    )(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

# Build the segmentation network with one output channel per class.
model = unet_model(OUTPUT_CHANNELS)

# The last layer of the network applies no activation, so the loss is
# configured to work on raw logits; labels are integer class ids per pixel.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Render a diagram of the architecture, including tensor shapes.
tf.keras.utils.plot_model(model, show_shapes=True)


# Keep the first (image, mask) pair of the unbatched training set as a fixed
# sample for visual checks, then show it.
for image, mask in train_data.take(1):
    sample_image = image
    sample_mask = mask
display(display_list=[sample_image, sample_mask])

def create_mask(pred_mask):
    """Turn batched per-class scores into a label mask for the first item.

    Args:
        pred_mask: Batched model output whose last axis holds the per-class
            scores.

    Returns:
        The argmax over the last axis, with a trailing channel axis added,
        for the first element of the batch.
    """
    class_ids = tf.argmax(pred_mask, axis=-1)
    with_channel = class_ids[..., tf.newaxis]
    return with_channel[0]

def show_predictions(dataset=None, num=1):
    """Display model predictions next to the input image and true mask.

    Args:
        dataset: Optional batched dataset of ``(image, mask)`` pairs. When
            given, the first element of each of the first `num` batches is
            shown. When None, the module-level sample image/mask is used.
        num: Number of batches to take from `dataset`.
    """
    # Compare against None explicitly instead of relying on the truthiness
    # of a dataset object.
    if dataset is None:
        # The sample is a single image; add a batch axis before predicting.
        sample_batch = sample_image[tf.newaxis, ...]
        display([sample_image, sample_mask,
                 create_mask(model.predict(sample_batch))])
        return

    for image, mask in dataset.take(num):
        pred_mask = model.predict(image)
        display([image[0], mask[0], create_mask(pred_mask)])

show_predictions()


class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that redraws the sample prediction after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Clear the notebook output and show the current sample prediction.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            logs: Metrics dictionary supplied by Keras (unused).
        """
        clear_output(wait=True)
        show_predictions()
        finished_epoch = epoch + 1
        print('\nSample Prediction after epoch {}\n'.format(finished_epoch))

# Train on the batched training data and show a sample prediction after
# every epoch. Validation is currently disabled; to enable it, also pass
# validation_data=val_batched_data.
model_history = model.fit(
    train_batched_data,
    epochs=EPOCHS,
    callbacks=[DisplayCallback()],
)

说明:

执行这个代码之前需要的准备工作:

(1)下载数据集:本次使用的是LIP数据集。为了方便读取,我调整了下载后数据的存放位置:图片放在 /opt/LIP/images/train/ 和 /opt/LIP/images/val/,对应的掩码放在 /opt/LIP/masks/train/ 和 /opt/LIP/masks/val/

(2)下载 tensorflow_examples 的文件放置在指定的路径(我是放置在/opt/LIP/examples)

(3)一定要使用 TensorFlow 2.3 版本,更低的版本很容易出错,特别是低于 2.0 的版本

最终保存了模型文件,大小在 60 MB 左右(上面的代码没有包含保存模型的步骤,可在训练结束后调用 model.save(...) 保存)

本次代码主要参考: https://tensorflow.google.cn/tutorials/images/segmentation

准确率:

效果不是很好,但是也还可以使用了,毕竟也就训练了20轮

对于MobileNetV2的预训练模型文件,如果代码自动下载很慢,可以提前下载好放在这个路径下面:/root/.keras/models/

可以看一下:

因为自动下载的文件就是保存到 /root/.keras/models/ 这个路径

文章来源互联网,如有侵权,请联系管理员删除。邮箱:417803890@qq.com / QQ:417803890

微配音

Python Free

邮箱:417803890@qq.com
QQ:417803890

皖ICP备19001818号
© 2019 copyright www.pythonf.cn - All rights reserved

微信扫一扫关注公众号:

联系方式

Python Free