Merge pull request #787 from XJDKC/keep-order
Fix the graph operation when a tensor is written by multiple independent ops
diff --git a/examples/rnn/imdb_train.py b/examples/rnn/imdb_train.py
index 669c5ad..4952639 100644
--- a/examples/rnn/imdb_train.py
+++ b/examples/rnn/imdb_train.py
@@ -108,9 +108,7 @@
num_layers=args.num_layers)
m.set_opt(opt.SGD(args.lr, 0.9))
-m.compile([tx], is_train=True, use_graph=True, sequential=True)
-# dry run
-out, loss = m(tx, ty)
+m.compile([tx], is_train=True, use_graph=True, sequential=False)
# training
m.train()
diff --git a/src/core/scheduler/scheduler.cc b/src/core/scheduler/scheduler.cc
index d752ec8..0172ee8 100644
--- a/src/core/scheduler/scheduler.cc
+++ b/src/core/scheduler/scheduler.cc
@@ -483,6 +483,25 @@
if (blkInfo->type_ == BlockType::kInput) {
blkInfo->type_ = BlockType::kParam;
}
+
+ Edge *write_edge = blkInfo->write_edge_;
+ if (write_edge) {
+ if (!write_edge->dst_node_) {
+ write_edge->dst_node_ = node;
+ node->AddInEdge(write_edge);
+ } else {
+ Node *lastNode = write_edge->src_node_;
+ auto outEdges = lastNode->out_edges();
+ for (auto outEdge : outEdges) {
+ if (outEdge->blk_ == blk && outEdge->dst_node_ != node) {
+ Edge *edge =
+ new Edge(edges_.size(), blk, outEdge->dst_node_, node);
+ outEdge->dst_node_->AddOutEdge(edge);
+ node->AddInEdge(edge);
+ }
+ }
+ }
+ }
}
// create new edge for new block
diff --git a/test/singa/test_scheduler.cc b/test/singa/test_scheduler.cc
index 2a2b516..c94f8f7 100644
--- a/test/singa/test_scheduler.cc
+++ b/test/singa/test_scheduler.cc
@@ -512,6 +512,47 @@
}
}
+TEST_F(TestGraph, MultipleIndependentOps) {
+ for (auto &it : devices) {
+ GOUT << "Test graph on device [" << it.first << "]" << std::endl;
+
+ auto dev = it.second;
+ Graph graph(dev.get());
+
+ auto &nodes = graph.nodes();
+
+ Tensor workspace(Shape{1}, dev);
+ Tensor b1(Shape{1}, dev);
+ Tensor b2(Shape{1}, dev);
+ Tensor b3(Shape{1}, dev);
+ Tensor b4(Shape{1}, dev);
+
+ // emulate clean up workspace, use the rnn design as reference
+ auto clean1 = [workspace](Context *ctx) mutable {};
+ auto clean2 = [workspace](Context *ctx) mutable {};
+ auto clean3 = [workspace](Context *ctx) mutable {};
+ auto clean4 = [workspace](Context *ctx) mutable {};
+
+ // emulate usage of workspace, use the rnn design as reference
+ auto op1 = [workspace, b1](Context *ctx) mutable {};
+ auto op2 = [workspace, b2](Context *ctx) mutable {};
+  auto op3 = [workspace, b3](Context *ctx) mutable {};
+  auto op4 = [workspace, b4](Context *ctx) mutable {};
+
+ graph.AddOperation(clean1, {}, {workspace.block()});
+ graph.AddOperation(op1, {b1.block()}, {workspace.block(), b1.block()});
+ graph.AddOperation(clean2, {}, {workspace.block()});
+ graph.AddOperation(op2, {b2.block()}, {workspace.block(), b2.block()});
+ graph.AddOperation(clean3, {}, {workspace.block()});
+ graph.AddOperation(op3, {b3.block()}, {workspace.block(), b3.block()});
+ graph.AddOperation(clean4, {}, {workspace.block()});
+ graph.AddOperation(op4, {b4.block()}, {workspace.block(), b4.block()});
+
+ EXPECT_EQ(8u, nodes.size());
+ graph.RunGraph();
+ }
+}
+
TEST_F(TestGraph, RunInSerial) {
for (auto &it : devices) {
GOUT << "Test graph on device [" << it.first << "]" << std::endl;