>source

그래프에서 일부 노드를 삭제하고 .pb에 저장하려고합니다

필요한 노드 만 새로운 mod_graph_def 에 추가 할 수 있습니다  그래프,하지만 그래프에 여전히 다른 노드 입력에서 삭제 된 노드에 대한 일부 참조가 있지만 노드의 입력을 수정할 수 없습니다 :

def delete_ops_from_graph():
    with open(input_model_filepath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    nodes = []
    for node in graph_def.node:
        if 'Neg' in node.name:
            print('Drop', node.name)
        else:
            nodes.append(node)
    mod_graph_def = tf.GraphDef()
    mod_graph_def.node.extend(nodes)
    # The problem that graph still have some references to deleted node in other nodes inputs
    for node in mod_graph_def.node:
        inp_names = []
        for inp in node.input:
            if 'Neg' in inp:
                pass
            else:
                inp_names.append(inp)
        node.input = inp_names # TypeError: Can't set composite field
    with open(output_model_filepath, 'wb') as f:
        f.write(mod_graph_def.SerializeToString())


  • 답변 # 1

    def delete_ops_from_graph():
        with open(input_model_filepath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # Delete nodes
        nodes = []
        for node in graph_def.node:
            if 'Neg' in node.name:
                print('Drop', node.name)
            else:
                nodes.append(node)
        mod_graph_def = tf.GraphDef()
        mod_graph_def.node.extend(nodes)
        # Delete references to deleted nodes
        for node in mod_graph_def.node:
            inp_names = []
            for inp in node.input:
                if 'Neg' in inp:
                    pass
                else:
                    inp_names.append(inp)
            del node.input[:]
            node.input.extend(inp_names)
        with open(output_model_filepath, 'wb') as f:
            f.write(mod_graph_def.SerializeToString())
    
    

관련 자료

  • 이전 r - 여러 열을 축소하고 축소 된 다른 수준/값에서 새 변수를 생성하려면 어떻게해야합니까?
  • 다음 node.js - npm 스크립트에서 2 개의 명령 실행 (nodemon&&sass --watch)