Graph.h 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 class Node {public : int id () const { return id_; } int const_id () const { return cost_id_; } const string& name () const ; void set_name (string name) ; const string& type_string () const ; const NodeDef& def () const ; const OpDef& op_def () const ; private : friend class Graph ; Node (); int id_; int const_id_; NodeClass class_; EdgeSet in_edges_; EdgeSet out_edges_; Graph* graph_; };
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 class Edge {public : Node* src () const { return src_; } Node* dst () const { return dst_; } int id () const { return id_; } int src_output () const { return src_output_; } int dst_input () const { return dst_input_; } bool IsControlEdge () const ; private : Edge () {} friend class Graph ; Node* src_; Node* dst_; int id_; int src_output_; int dst_input_; };
控制依赖边的`src_output`/`dst_input`均为`Graph::kControlSlot`(-1),意味着控制依赖边不承载任何数据。
计算图的普通边承载Tensor,并使用`TensorId`标识。`TensorId`由二元组`node_name:src_output`唯一确定,其中`node_name`为边的前驱节点的名称;`src_output`缺省为0,即`node_name`与`node_name:0`等价。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 class Graph {public : explicit Graph (const OpRegistryInterface* ops) ; explicit Graph (const FunctionLibraryDefinition& flib_def) ; ~Graph (); static const int kControlSlot; void set_versions (const VersionDef& versions) ; Node* AddNode (NodeDef node_def, Status* status) ; Node* CopyNode (const Node* node) ; void RemoveNode (Node* node) ; const Edge* AddEdge (Node* source, int x, Node* dest, int y) ; const Edge* AddControlEdge (Node* source, Node* dest, bool allow_deuplicates=false ) ; void RemoveEdge (const Edge* edge) ; void RemoveControlEdge (const Edge* edge) ; enum {kSourceId = 0 , kSinkId = 1 }; Node* FindNodeId (int id) const { return nodes_[id]; } Node* source_node () const { return FindNodeId (kSourceId); } Node* sink_node () const { return FindNodeId (kSinkId); } private : FunctionLibraryDefinition ops_; const std::unique_ptr<VersionDef> versions_; core::Arena arena_; vector<Node*> nodes_; int64 num_nodes_ = 0 ; vector<Edge*> edges_; int num_edges_ = 0 ; }; Graph::Graph (const OpRegistryInterface* ops) : ops_ (ops, FunctionDefLibrary ()), versions_ (new VersionDef), arena_ (8 << 10 ) { device_names.push_back ("" ); NodeDef def; def.set_name ("_SOURCE" ); def.set_op ("NoOp" ); Node* source = AddNode (def, &status); def.set_name ("_SINK" ); Node* sink = AddNode (def, &status); AddControlEdge (source, sink); }
Graph是一个DAG,按照拓扑排序运行,若存在多个入度为0的节点,则并行运行。初始状态下,图中已有起始节点Source(id为`kSourceId`,即0)和终止节点Sink(id为`kSinkId`,即1),因此普通节点的id必大于1。
Source和Sink之间有一个控制依赖边,保证计算图的执行始于Source,止于Sink。