enable graph in rnn model
diff --git a/examples/rnn/train.py b/examples/rnn/train.py
index 0060d06..107f0f1 100644
--- a/examples/rnn/train.py
+++ b/examples/rnn/train.py
@@ -197,7 +197,7 @@
cuda = device.create_cuda_gpu()
model = CharRNN(data.vocab_size, hidden_size)
model.on_device(cuda)
- model.graph(True, True)
+ model.graph(True, False)
inputs, labels = None, None
diff --git a/src/core/scheduler/scheduler.cc b/src/core/scheduler/scheduler.cc
index 317a726..0989325 100644
--- a/src/core/scheduler/scheduler.cc
+++ b/src/core/scheduler/scheduler.cc
@@ -379,6 +379,45 @@
nodes_.push_back(node);
}
+void Graph::AddSyncOp(function<void(Context *)> &&op) {
+ // create new node
+ Node *node = new Node(nodes_.size(), std::move(op));
+
+ for (size_t i = 0; i < write_blocks_.size(); ++i) {
+ Block *blk = write_blocks_[i];
+ BlkInfo *blkInfo = blocks_[blk];
+ Edge *edge = nullptr;
+
+ if (blkInfo->type_ == BlockType::kEnd) {
+ blkInfo->type_ = BlockType::kInter;
+ }
+
+ Edge *write_edge = blkInfo->write_edge_;
+ if (!write_edge->dst_node_) {
+ // change the dst node of the write_edge
+ write_edge->dst_node_ = node;
+ edge = write_edge;
+ } else {
+ Node *src_node = write_edge->src_node_;
+ edge = new Edge(edges_.size(), blk, src_node, node);
+ src_node->AddOutEdge(edge);
+ edges_.push_back(edge);
+ }
+
+ node->AddInEdge(edge);
+
+ // fake edges, no need to add the graph ref
+ edge = new Edge(edges_.size(), blk, node, nullptr);
+ blkInfo->write_edge_ = edge;
+
+ node->AddOutEdge(edge);
+ edges_.push_back(edge);
+ }
+
+ // add node into nodes
+ nodes_.push_back(node);
+}
+
void Graph::Analyze() {
begin_nodes_.clear();
next_nodes_.resize(nodes_.size());
@@ -520,45 +559,6 @@
}
}
-void Graph::AddSyncOp(function<void(Context *)> &&op) {
- // create new node
- Node *node = new Node(nodes_.size(), std::move(op));
-
- for (size_t i = 0; i < write_blocks_.size(); ++i) {
- Block *blk = write_blocks_[i];
- BlkInfo *blkInfo = blocks_[blk];
- Edge *edge = nullptr;
-
- if (blkInfo->type_ == BlockType::kEnd) {
- blkInfo->type_ = BlockType::kInter;
- }
-
- Edge *write_edge = blkInfo->write_edge_;
- if (!write_edge->dst_node_) {
- // change the dst node of the write_edge
- write_edge->dst_node_ = node;
- edge = write_edge;
- } else {
- Node *src_node = write_edge->src_node_;
- edge = new Edge(edges_.size(), blk, src_node, node);
- src_node->AddOutEdge(edge);
- edges_.push_back(edge);
- }
-
- node->AddInEdge(edge);
-
- // fake edges, no need to add the graph ref
- edge = new Edge(edges_.size(), blk, node, nullptr);
- blkInfo->write_edge_ = edge;
-
- node->AddOutEdge(edge);
- edges_.push_back(edge);
- }
-
- // add node into nodes
- nodes_.push_back(node);
-}
-
/*
void CUDART_CB Graph::Callback(cudaStream_t stream, cudaError_t status,
void *data) {