zoukankan      html  css  js  c++  java
  • photo2cartoon 中 face_seg.py 关于TensorFlow版本的问题修改

    import os
    import cv2
    import numpy as np
    import tensorflow as tf
    from tensorflow.python.platform import gfile
    
    
    # Absolute path of the directory that contains this source file.
    curPath = os.path.abspath(os.path.dirname(__file__))
    
    
    class FaceSeg:
        """Run a frozen TensorFlow segmentation graph and return a mask image.

        The graph is loaded from a ``.pb`` file into a private ``tf.Graph`` and
        executed through a ``tf.compat.v1.Session``. Call :meth:`close` when the
        instance is no longer needed to release the session.
        """

        def __init__(self, model_path=os.path.join(curPath, 'seg_model_384.pb')):
            """Create the session, load the frozen graph and look up its tensors.

            Args:
                model_path: Path to the frozen ``GraphDef`` (``.pb``) file.
                    Defaults to ``seg_model_384.pb`` next to this module.
            """
            config = tf.compat.v1.ConfigProto()
            # Let TensorFlow grow GPU memory on demand instead of reserving it
            # all up front.
            config.gpu_options.allow_growth = True
            self._graph = tf.Graph()
            self._sess = tf.compat.v1.Session(config=config, graph=self._graph)

            self.pb_file_path = model_path
            self._restore_from_pb()
            # Tensor names are fixed by the frozen graph being loaded.
            self.input_op = self._sess.graph.get_tensor_by_name('input_1:0')
            self.output_op = self._sess.graph.get_tensor_by_name('sigmoid/Sigmoid:0')

        def _restore_from_pb(self):
            """Parse ``self.pb_file_path`` and import it into ``self._graph``."""
            with self._sess.as_default(), self._graph.as_default():
                # tf.io.gfile.GFile is the public replacement for the deprecated
                # tensorflow.python.platform.gfile.FastGFile.
                with tf.io.gfile.GFile(self.pb_file_path, 'rb') as f:
                    graph_def = tf.compat.v1.GraphDef()
                    graph_def.ParseFromString(f.read())
                # name='' keeps the original node names (no import prefix) so
                # the tensor lookups in __init__ resolve.
                tf.import_graph_def(graph_def, name='')

        def input_transform(self, image):
            """Resize ``image`` to 384x384, scale by 1/255 and add a batch axis.

            Args:
                image: 3-D image array (the indexing below requires exactly
                    three axes) -- presumably H x W x C with 0-255 values;
                    confirm against callers.

            Returns:
                Array of shape ``(1, 384, 384, C)``.
            """
            image = cv2.resize(image, (384, 384), interpolation=cv2.INTER_AREA)
            image_input = (image / 255.)[np.newaxis, :, :, :]
            return image_input

        def output_transform(self, output, shape):
            """Resize the network output to ``shape`` and convert it to uint8.

            Args:
                output: Single (un-batched) network output array.
                shape: ``(height, width)`` of the desired result.

            Returns:
                ``np.uint8`` array holding ``output * 255`` resized to ``shape``.
            """
            # cv2.resize takes the target size as (width, height).
            output = cv2.resize(output, (shape[1], shape[0]))
            image_output = (output * 255).astype(np.uint8)
            return image_output

        def get_mask(self, image):
            """Run the graph on ``image`` and return a mask at the image's size.

            Args:
                image: Input image array; see :meth:`input_transform`.

            Returns:
                ``np.uint8`` mask whose height and width match ``image``.
            """
            image_input = self.input_transform(image)
            # The batch holds a single image, so take the first result.
            output = self._sess.run(self.output_op, feed_dict={self.input_op: image_input})[0]
            return self.output_transform(output, shape=image.shape[:2])

        def close(self):
            """Release the resources held by the underlying TensorFlow session."""
            self._sess.close()
  • 相关阅读:
    TCP/IP 网路基础
    三、Django之请求与响应-Part 1
    二、Django快速安装
    Linux 优化详解
    缓存的正确使用方式
    HTML从入门到放弃
    Ansible开发之路
    程序猿数据库学习指南
    MySQL错误代码大全
    Python之网络编程
  • 原文地址:https://www.cnblogs.com/hxjbc/p/12836519.html
Copyright © 2011-2022 走看看