>

목표 : TFX->TF Lite 변환기->모바일/IoT 기기에 모델 배포

현재 Tensorflow Extended 를 배우고 있습니다. .com/tensorflow/tfx/tree/master/examples/chicago_taxi_pipeline "rel ="nofollow noreferrer ">시카고 택시 파이프 라인 예 . 파이프 라인은 (많은 어려움을 겪었음에도 불구하고) 실행되고Pusher구성 요소가Tensorflow SavedModel파일 (.pb)을 방출했습니다.

그러나 여기서 새로운 문제가 발생합니다. Tensorflow nightly/1.13.1 (둘 다 시도)과 Python 2.7.6에 의해 간단한 python 코드로SavedModel(유틸리티 테스트를위한 mnist 숫자 데이터 모델)을 생성하고 저장하고로드 할 수 있습니다 saved_model.simple_save 와 같은  그리고 saved_model.loader.load , 그러나 다음과 같이TFX Pusher에서 방출하는 모델에 적용 할 때 오류가 계속 발생합니다.

(TFX 파이프 라인에 문제가 있었습니까?)

내가 사용한 코드 :
import tensorflow as tf
with tf.Session(graph=tf.Graph()) as sess:
    tf.compat.v1.saved_model.loader.load(sess, ["serve"], "/home/tigerpaws/taxi/serving_model/taxi_simple/1553187887")#"/home/tigerpaws/saved_model_example/model")
    graph=tf.get_default_graph()

오류 :

KeyError                                  Traceback (most recent call last)
<ipython-input-11-a6978b82c3d2> in <module>()
      1 with tf.Session(graph=tf.Graph()) as sess:
----> 2     tf.compat.v1.saved_model.loader.load(sess, ["serve"], "/home/tigerpaws/taxi/serving_model/taxi_simple/1553187887")#"/home/tigerpaws/saved_model_example/model")
      3     graph=tf.get_default_graph()
/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/util/deprecation.pyc in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',
/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.pyc in load(sess, tags, export_dir, import_scope, **saver_kwargs)
    267   """
    268   loader = SavedModelLoader(export_dir)
--> 269   return loader.load(sess, tags, import_scope, **saver_kwargs)
    270 
    271 
/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.pyc in load(self, sess, tags, import_scope, **saver_kwargs)
    418     with sess.graph.as_default():
    419       saver, _ = self.load_graph(sess.graph, tags, import_scope,
--> 420                                  **saver_kwargs)
    421       self.restore_variables(sess, saver, import_scope)
    422       self.run_init_ops(sess, tags, import_scope)
/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.pyc in load_graph(self, graph, tags, import_scope, **saver_kwargs)
    348     with graph.as_default():
    349       return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
--> 350           meta_graph_def, import_scope=import_scope, **saver_kwargs)
    351 
    352   def restore_variables(self, sess, saver, import_scope=None):
/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/training/saver.pyc in _import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope, return_elements, **kwargs)
   1455           import_scope=import_scope,
   1456           return_elements=return_elements,
-> 1457           **kwargs))
   1458 
   1459   saver = _create_saver_from_imported_meta_graph(
/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/meta_graph.pyc in import_scoped_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, graph, import_scope, input_map, unbound_inputs_col_name, restore_collections_predicate, return_elements)
    804         input_map=input_map,
    805         producer_op_list=producer_op_list,
--> 806         return_elements=return_elements)
    807 
    808     # Restores all the other collections.
/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/util/deprecation.pyc in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(
/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/importer.pyc in import_graph_def(graph_def, input_map, return_elements, name, op_dict, producer_op_list)
    397   if producer_op_list is not None:
    398     # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
--> 399     _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
    400 
    401   graph = ops.get_default_graph()
/home/tigerpaws/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/importer.pyc in _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
    157     # Remove any default attr values that aren't in op_def.
    158     if node.op in producer_op_dict:
--> 159       op_def = op_dict[node.op]
    160       producer_op_def = producer_op_dict[node.op]
    161       # We make a copy of node.attr to iterate through since we may modify
KeyError: u'BucketizeWithInputBoundaries'

SavedModelGraphDef (냉동 그래프)으로 변환하려고 시도한 또 다른 시도가 있었으므로 변환기에 다시 시도해 볼 수 있습니다. 변환에는 output_node_names 가 필요합니다 나는 모른다. 코드에서 모델이 저장된 위치도 찾을 수 없습니다 (아마도 출력 노드 이름을 찾을 수 있습니다).

문제 나 다른 방법에 대한 아이디어가 있습니까? 미리 감사드립니다.

수정 :누군가 태그 생성을 도와 줄 수 있습니까? 나는 1500의 명성에 도달하지 못했지만이 질문은 실제로 tfx 에 관한 것입니다.  / tensorflow-extended

  • 답변 # 1

    혼란을 일으켜 죄송합니다. 실제로 저장된 모델 파일을 읽음으로써 문제가 발생합니다.

    SavedModel에는 BucketizeWithInputBoundaries 작업이 있습니다 op_dict 에 정의되어 있지 않습니다. .

    이것은 여전히 ​​구글의 TODO 목록에 있으며, 두 스크립트로 주석 처리되었습니다.

    여기와 여기. (Github 링크) :

    # TODO(jyzhao): BucketizeWithInputBoundaries error without this.
    
    

    지정된 스크립트를 가져온 후이 문제가 해결되었습니다.

    from tensorflow.contrib.boosted_trees.python.ops import quantile_ops  # pylint: disable=unused-import
    
    

  • 이전 python - 다차원 배열을 2D로 변환 및 후속 인덱싱
  • 다음 배치 스크립트를 사용하여 파일 이름에서 하위 문자열을 어떻게 가져올 수 있습니까?