| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| using System; |
| using System.Collections.Concurrent; |
| using System.Collections.Generic; |
| using System.Globalization; |
| using System.IO; |
| using System.Linq; |
| using System.Net; |
| using System.Reactive; |
| using System.Threading; |
| using System.Threading.Tasks; |
| using Org.Apache.REEF.Common.Io; |
| using Org.Apache.REEF.Common.Tasks; |
| using Org.Apache.REEF.Network.Examples.GroupCommunication; |
| using Org.Apache.REEF.Network.Group.Config; |
| using Org.Apache.REEF.Network.Group.Driver; |
| using Org.Apache.REEF.Network.Group.Driver.Impl; |
| using Org.Apache.REEF.Network.Group.Operators; |
| using Org.Apache.REEF.Network.Group.Operators.Impl; |
| using Org.Apache.REEF.Network.Group.Pipelining; |
| using Org.Apache.REEF.Network.Group.Pipelining.Impl; |
| using Org.Apache.REEF.Network.Group.Task; |
| using Org.Apache.REEF.Network.Group.Topology; |
| using Org.Apache.REEF.Network.Naming; |
| using Org.Apache.REEF.Network.NetworkService; |
| using Org.Apache.REEF.Network.Tests.NamingService; |
| using Org.Apache.REEF.Tang.Annotations; |
| using Org.Apache.REEF.Tang.Formats; |
| using Org.Apache.REEF.Tang.Implementations.Configuration; |
| using Org.Apache.REEF.Tang.Implementations.Tang; |
| using Org.Apache.REEF.Tang.Interface; |
| using Org.Apache.REEF.Tang.Util; |
| using Org.Apache.REEF.Wake.Remote; |
| using Org.Apache.REEF.Wake.Remote.Impl; |
| using Org.Apache.REEF.Wake.StreamingCodec; |
| using Org.Apache.REEF.Wake.StreamingCodec.CommonStreamingCodecs; |
| using Xunit; |
| |
| namespace Org.Apache.REEF.Network.Tests.GroupCommunication |
| { |
| public class GroupCommunicationTests |
| { |
/// <summary>
/// Exchanges string messages between two streaming network services registered as "id1" and
/// "id2" through a shared name server, and checks that each side receives what the other sent,
/// in send order.
/// </summary>
[Fact]
public void TestSender()
{
    // Upper bound on each receive, so a message that never arrives fails the test instead of
    // blocking it indefinitely (BlockingCollection.Take has no timeout).
    TimeSpan receiveTimeout = TimeSpan.FromSeconds(30);

    using (var nameServer = NameServerTests.BuildNameServer())
    {
        IPEndPoint endpoint = nameServer.LocalEndpoint;

        BlockingCollection<GeneralGroupCommunicationMessage> messages1 =
            new BlockingCollection<GeneralGroupCommunicationMessage>();
        BlockingCollection<GeneralGroupCommunicationMessage> messages2 =
            new BlockingCollection<GeneralGroupCommunicationMessage>();

        // Each handler queues the first payload of every NsMessage delivered to its service.
        var handler1 =
            Observer.Create<NsMessage<GeneralGroupCommunicationMessage>>(msg => messages1.Add(msg.Data.First()));
        var handler2 =
            Observer.Create<NsMessage<GeneralGroupCommunicationMessage>>(msg => messages2.Add(msg.Data.First()));

        var networkServiceInjector1 = BuildNetworkServiceInjector(endpoint, handler1);
        var networkServiceInjector2 = BuildNetworkServiceInjector(endpoint, handler2);

        var networkService1 =
            networkServiceInjector1.GetInstance<StreamingNetworkService<GeneralGroupCommunicationMessage>>();
        var networkService2 =
            networkServiceInjector2.GetInstance<StreamingNetworkService<GeneralGroupCommunicationMessage>>();
        networkService1.Register(new StringIdentifier("id1"));
        networkService2.Register(new StringIdentifier("id2"));

        Sender sender1 = networkServiceInjector1.GetInstance<Sender>();
        Sender sender2 = networkServiceInjector2.GetInstance<Sender>();

        sender1.Send(CreateGcmStringType("abc", "id1", "id2"));
        sender1.Send(CreateGcmStringType("def", "id1", "id2"));
        sender2.Send(CreateGcmStringType("ghi", "id2", "id1"));

        // Dequeues the next message (failing on timeout) and returns its first string payload.
        // Assert.IsAssignableFrom reports a missing or wrongly typed message as an assertion
        // failure, where an "as" cast would surface as a NullReferenceException.
        Func<BlockingCollection<GeneralGroupCommunicationMessage>, string> takeString = queue =>
        {
            GeneralGroupCommunicationMessage received;
            Assert.True(
                queue.TryTake(out received, receiveTimeout),
                "Timed out waiting for a group communication message.");
            return Assert.IsAssignableFrom<GroupCommunicationMessage<string>>(received).Data[0];
        };

        // id2 receives both messages sent by id1, in send order.
        Assert.Equal("abc", takeString(messages2));
        Assert.Equal("def", takeString(messages2));

        // id1 receives the message sent by id2.
        Assert.Equal("ghi", takeString(messages1));
    }
}
| |
/// <summary>
/// Runs ten broadcast-then-reduce rounds over a flat three-task topology. In round j the master
/// (task0) broadcasts j, both other tasks send back TriangleNumber(j), and the master's reduced
/// value must equal TriangleNumber(j) * (numTasks - 1).
/// </summary>
[Fact]
public void TestBroadcastReduceOperators()
{
    const string groupName = "group1";
    const string broadcastOperatorName = "broadcast";
    const string reduceOperatorName = "reduce";
    const string masterTaskId = "task0";
    const string driverId = "Driver Id";
    const int numTasks = 3;
    const int fanOut = 2;

    var groupCommunicationDriver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);
    ICommunicationGroupDriver commGroup = groupCommunicationDriver.DefaultGroup
        .AddBroadcast<int>(
            broadcastOperatorName,
            masterTaskId,
            TopologyTypes.Flat,
            GetDefaultDataConverterConfig())
        .AddReduce<int>(
            reduceOperatorName,
            masterTaskId,
            TopologyTypes.Flat,
            GetDefaultDataConverterConfig(),
            GetDefaultReduceFuncConfig())
        .Build();

    var commGroups = CommGroupClients(groupName, numTasks, groupCommunicationDriver, commGroup, GetDefaultCodecConfig());

    // Master task: broadcasts values and receives the reduction.
    IBroadcastSender<int> broadcastSender = commGroups[0].GetBroadcastSender<int>(broadcastOperatorName);
    IReduceReceiver<int> sumReducer = commGroups[0].GetReduceReceiver<int>(reduceOperatorName);

    IBroadcastReceiver<int> broadcastReceiver1 = commGroups[1].GetBroadcastReceiver<int>(broadcastOperatorName);
    IReduceSender<int> triangleNumberSender1 = commGroups[1].GetReduceSender<int>(reduceOperatorName);

    IBroadcastReceiver<int> broadcastReceiver2 = commGroups[2].GetBroadcastReceiver<int>(broadcastOperatorName);
    IReduceSender<int> triangleNumberSender2 = commGroups[2].GetReduceSender<int>(reduceOperatorName);

    for (int j = 1; j <= 10; j++)
    {
        broadcastSender.Send(j);

        int n1 = broadcastReceiver1.Receive();
        int n2 = broadcastReceiver2.Receive();
        Assert.Equal(j, n1);
        Assert.Equal(j, n2);

        int triangleNum1 = TriangleNumber(n1);
        triangleNumberSender1.Send(triangleNum1);
        int triangleNum2 = TriangleNumber(n2);
        triangleNumberSender2.Send(triangleNum2);

        int sum = sumReducer.Reduce();
        int expected = TriangleNumber(j) * (numTasks - 1);

        // Assert.Equal takes (expected, actual) so that failure messages label the values correctly.
        Assert.Equal(expected, sum);
    }
}
| |
/// <summary>
/// Adds a second communication group to the group driver, removes it, and verifies that removing
/// the same group again throws <see cref="ArgumentException"/>.
/// </summary>
[Fact]
public void TestRemoveCommunicationGroup()
{
    const string defaultGroupName = "group1";
    const string extraGroupName = "group2";
    const string masterTaskId = "task0";
    const string driverId = "Driver Id";
    const int numTasks = 3;
    const int fanOut = 2;

    var groupDriver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, defaultGroupName, fanOut, numTasks);

    var extraGroup = groupDriver.NewCommunicationGroup(extraGroupName, 5);
    Assert.NotNull(extraGroup);

    // The first removal must succeed; the second one targets a group that no longer exists.
    groupDriver.RemoveCommunicationGroup(extraGroupName);
    Action removeAgain = () => groupDriver.RemoveCommunicationGroup(extraGroupName);
    Assert.Throws<ArgumentException>(removeAgain);
}
| |
/// <summary>
/// Removes the default communication group from the group driver and verifies that removing it
/// a second time throws <see cref="ArgumentException"/>.
/// </summary>
[Fact]
public void TestRemoveDefaultGroup()
{
    const string defaultGroupName = "group1";
    const string masterTaskId = "task0";
    const string driverId = "Driver Id";
    const int numTasks = 3;
    const int fanOut = 2;

    var groupDriver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, defaultGroupName, fanOut, numTasks);
    Assert.NotNull(groupDriver.DefaultGroup);

    groupDriver.RemoveCommunicationGroup(defaultGroupName);

    Action removeAgain = () => groupDriver.RemoveCommunicationGroup(defaultGroupName);
    Assert.Throws<ArgumentException>(removeAgain);
}
| |
/// <summary>
/// Verifies that removing a communication group the driver does not currently hold throws
/// <see cref="ArgumentException"/>.
/// </summary>
[Fact]
public void TestRemoveNoExistGroup()
{
    const string groupName = "group1";
    const string masterTaskId = "task0";
    const string driverId = "Driver Id";
    const int numTasks = 3;
    const int fanOut = 2;

    var groupDriver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

    // NOTE(review): "group1" is also the group name bound into the driver above. This test expects
    // it to be unknown, which presumably relies on the default group only being created once
    // DefaultGroup is read (TestRemoveDefaultGroup reads it, then removes "group1" successfully) — confirm.
    Action removeMissing = () => groupDriver.RemoveCommunicationGroup(groupName);
    Assert.Throws<ArgumentException>(removeMissing);
}
| |
/// <summary>
/// Checks that broadcast and reduce operators using an <c>int[]</c> message type can be obtained
/// from the CommunicationGroupClient of every task in a three-task flat topology.
/// </summary>
[Fact]
public void TestGetBroadcastReduceOperatorsForIntArrayMessageType()
{
    const string groupName = "group1";
    const string broadcastOperatorName = "broadcast";
    const string reduceOperatorName = "reduce";
    const string masterTaskId = "task0";
    const string driverId = "Driver Id";
    const int numTasks = 3;
    const int fanOut = 2;

    // Streaming codec for int[] payloads.
    IConfiguration intArrayCodecConfig = StreamingCodecConfiguration<int[]>.Conf
        .Set(StreamingCodecConfiguration<int[]>.Codec, GenericType<IntArrayStreamingCodec>.Class)
        .Build();

    // Reduce function (ArraySumFunction) used by the reduce operator.
    IConfiguration arraySumConfig = ReduceFunctionConfiguration<int[]>.Conf
        .Set(ReduceFunctionConfiguration<int[]>.ReduceFunction, GenericType<ArraySumFunction>.Class)
        .Build();

    // Pipeline data converter for int[], plus the chunk size named parameter it is bound with.
    IConfiguration pipelineConverterModule = PipelineDataConverterConfiguration<int[]>.Conf
        .Set(PipelineDataConverterConfiguration<int[]>.DataConverter, GenericType<PipelineIntDataConverter>.Class)
        .Build();
    string chunkSizeText = GroupTestConstants.ChunkSize.ToString(CultureInfo.InvariantCulture);
    IConfiguration pipelineConverterConfig = TangFactory.GetTang()
        .NewConfigurationBuilder(pipelineConverterModule)
        .BindNamedParameter<GroupTestConfig.ChunkSize, int>(GenericType<GroupTestConfig.ChunkSize>.Class, chunkSizeText)
        .Build();

    var groupDriver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);
    ICommunicationGroupDriver groupWithOperators = groupDriver.DefaultGroup
        .AddBroadcast<int[]>(broadcastOperatorName, masterTaskId, TopologyTypes.Flat, pipelineConverterConfig)
        .AddReduce<int[]>(reduceOperatorName, masterTaskId, TopologyTypes.Flat, pipelineConverterConfig, arraySumConfig)
        .Build();

    var clients = CommGroupClients(groupName, numTasks, groupDriver, groupWithOperators, intArrayCodecConfig);

    // Master task (task0) owns the broadcast sender and the reduce receiver.
    IBroadcastSender<int[]> masterBroadcast = clients[0].GetBroadcastSender<int[]>(broadcastOperatorName);
    IReduceReceiver<int[]> masterReduce = clients[0].GetReduceReceiver<int[]>(reduceOperatorName);

    IBroadcastReceiver<int[]> task1Broadcast = clients[1].GetBroadcastReceiver<int[]>(broadcastOperatorName);
    IReduceSender<int[]> task1Reduce = clients[1].GetReduceSender<int[]>(reduceOperatorName);

    IBroadcastReceiver<int[]> task2Broadcast = clients[2].GetBroadcastReceiver<int[]>(broadcastOperatorName);
    IReduceSender<int[]> task2Reduce = clients[2].GetReduceSender<int[]>(reduceOperatorName);

    Assert.NotNull(masterBroadcast);
    Assert.NotNull(masterReduce);
    Assert.NotNull(task1Broadcast);
    Assert.NotNull(task1Reduce);
    Assert.NotNull(task2Broadcast);
    Assert.NotNull(task2Reduce);
}
| |
/// <summary>
/// Scatters the integers 1..100 from task0 to four other tasks in the explicit order task4,
/// task3, task2, task1, runs ScatterReceiveReduce on each of them, and checks that the value
/// reduced at task0 equals the sum of the scattered data.
/// </summary>
[Fact]
public void TestScatterReduceOperators()
{
    const string groupName = "group1";
    const string scatterOperatorName = "scatter";
    const string reduceOperatorName = "reduce";
    const string masterTaskId = "task0";
    const string driverId = "Driver Id";
    const int numTasks = 5;
    const int fanOut = 2;

    IGroupCommDriver groupCommDriver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

    // AddReduce configurations are listed data converter first, reduce function second, matching
    // the other AddReduce calls in this class.
    ICommunicationGroupDriver commGroup = groupCommDriver.DefaultGroup
        .AddScatter<int>(
            scatterOperatorName,
            masterTaskId,
            TopologyTypes.Flat,
            GetDefaultDataConverterConfig())
        .AddReduce<int>(
            reduceOperatorName,
            masterTaskId,
            TopologyTypes.Flat,
            GetDefaultDataConverterConfig(),
            GetDefaultReduceFuncConfig())
        .Build();

    List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, groupCommDriver, commGroup, GetDefaultCodecConfig());

    IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(scatterOperatorName);
    IReduceReceiver<int> sumReducer = commGroups[0].GetReduceReceiver<int>(reduceOperatorName);

    IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(scatterOperatorName);
    IReduceSender<int> sumSender1 = commGroups[1].GetReduceSender<int>(reduceOperatorName);

    IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(scatterOperatorName);
    IReduceSender<int> sumSender2 = commGroups[2].GetReduceSender<int>(reduceOperatorName);

    IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(scatterOperatorName);
    IReduceSender<int> sumSender3 = commGroups[3].GetReduceSender<int>(reduceOperatorName);

    IScatterReceiver<int> receiver4 = commGroups[4].GetScatterReceiver<int>(scatterOperatorName);
    IReduceSender<int> sumSender4 = commGroups[4].GetReduceSender<int>(reduceOperatorName);

    // Every operator endpoint must be injectable, including the reduce side.
    Assert.NotNull(sender);
    Assert.NotNull(sumReducer);
    Assert.NotNull(receiver1);
    Assert.NotNull(sumSender1);
    Assert.NotNull(receiver2);
    Assert.NotNull(sumSender2);
    Assert.NotNull(receiver3);
    Assert.NotNull(sumSender3);
    Assert.NotNull(receiver4);
    Assert.NotNull(sumSender4);

    List<int> data = Enumerable.Range(1, 100).ToList();
    List<string> order = new List<string> { "task4", "task3", "task2", "task1" };

    sender.Send(data, order);

    // Drive the receiving tasks in the same order the data was scattered.
    ScatterReceiveReduce(receiver4, sumSender4);
    ScatterReceiveReduce(receiver3, sumSender3);
    ScatterReceiveReduce(receiver2, sumSender2);
    ScatterReceiveReduce(receiver1, sumSender1);

    int sum = sumReducer.Reduce();

    // Assert.Equal takes (expected, actual) so that failure messages label the values correctly.
    Assert.Equal(data.Sum(), sum);
}
| |
/// <summary>
/// Broadcasts one value from the master task (task0) in a ten-task group and checks that the
/// receivers on task1 and task2 both get it.
/// </summary>
[Fact]
public void TestBroadcastOperator()
{
    const string groupName = "group1";
    const string operatorName = "broadcast";
    const string masterTaskId = "task0";
    const string driverId = "Driver Id";
    const int numTasks = 10;
    const int fanOut = 3;
    const int broadcastValue = 1337;

    IGroupCommDriver driver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);
    var groupDriver = driver.DefaultGroup
        .AddBroadcast(operatorName, masterTaskId)
        .Build();
    List<ICommunicationGroupClient> clients = CommGroupClients(groupName, numTasks, driver, groupDriver, GetDefaultCodecConfig());

    IBroadcastSender<int> sender = clients[0].GetBroadcastSender<int>(operatorName);
    IBroadcastReceiver<int> firstReceiver = clients[1].GetBroadcastReceiver<int>(operatorName);
    IBroadcastReceiver<int> secondReceiver = clients[2].GetBroadcastReceiver<int>(operatorName);

    Assert.NotNull(sender);
    Assert.NotNull(firstReceiver);
    Assert.NotNull(secondReceiver);

    sender.Send(broadcastValue);

    int firstReceived = firstReceiver.Receive();
    Assert.Equal(broadcastValue, firstReceived);
    int secondReceived = secondReceiver.Receive();
    Assert.Equal(broadcastValue, secondReceived);
}
| |
/// <summary>
/// Broadcasts one value from the master task (task0) in a ten-task group and checks that the
/// receivers on task1 and task2 both get it.
/// </summary>
[Fact]
public void TestBroadcastOperatorWithDefaultCodec()
{
    // The name server is disposed when the test ends so it does not outlive the test.
    // NOTE(review): nothing in this test reads it; it is kept only to preserve the original setup.
    using (var nameServer = NameServerTests.BuildNameServer())
    {
        string groupName = "group1";
        string operatorName = "broadcast";
        string masterTaskId = "task0";
        string driverId = "Driver Id";
        int numTasks = 10;
        int value = 1337;
        int fanOut = 3;

        IGroupCommDriver groupCommDriver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

        var commGroup = groupCommDriver.DefaultGroup
            .AddBroadcast(operatorName, masterTaskId)
            .Build();

        List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, groupCommDriver, commGroup, GetDefaultCodecConfig());

        IBroadcastSender<int> sender = commGroups[0].GetBroadcastSender<int>(operatorName);
        IBroadcastReceiver<int> receiver1 = commGroups[1].GetBroadcastReceiver<int>(operatorName);
        IBroadcastReceiver<int> receiver2 = commGroups[2].GetBroadcastReceiver<int>(operatorName);

        Assert.NotNull(sender);
        Assert.NotNull(receiver1);
        Assert.NotNull(receiver2);

        sender.Send(value);
        Assert.Equal(value, receiver1.Receive());
        Assert.Equal(value, receiver2.Receive());
    }
}
| |
/// <summary>
/// Broadcasts three successive values from task0 in a three-task group. After each broadcast,
/// both task1 and task2 must receive the value before the next one is sent.
/// </summary>
[Fact]
public void TestBroadcastOperator2()
{
    const string groupName = "group1";
    const string operatorName = "broadcast";
    const string driverId = "driverId";
    const string masterTaskId = "task0";
    const int numTasks = 3;
    const int fanOut = 2;

    IGroupCommDriver driver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);
    var groupDriver = driver.DefaultGroup
        .AddBroadcast(operatorName, masterTaskId)
        .Build();
    List<ICommunicationGroupClient> clients = CommGroupClients(groupName, numTasks, driver, groupDriver, GetDefaultCodecConfig());

    IBroadcastSender<int> sender = clients[0].GetBroadcastSender<int>(operatorName);
    IBroadcastReceiver<int> firstReceiver = clients[1].GetBroadcastReceiver<int>(operatorName);
    IBroadcastReceiver<int> secondReceiver = clients[2].GetBroadcastReceiver<int>(operatorName);

    Assert.NotNull(sender);
    Assert.NotNull(firstReceiver);
    Assert.NotNull(secondReceiver);

    // One round per value: send, then both receivers must observe that same value.
    foreach (int payload in new[] { 1337, 42, 99 })
    {
        sender.Send(payload);
        Assert.Equal(payload, firstReceiver.Receive());
        Assert.Equal(payload, secondReceiver.Receive());
    }
}
| |
/// <summary>
/// Three tasks each send one integer to a flat reduce operator rooted at task0. The reduced
/// value at the master must be 9, the total of the values sent.
/// </summary>
[Fact]
public void TestReduceOperator()
{
    const string groupName = "group1";
    const string operatorName = "reduce";
    const string driverId = "driverid";
    const string masterTaskId = "task0";
    const int numTasks = 4;
    const int fanOut = 2;

    IGroupCommDriver driver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);
    var groupDriver = driver.DefaultGroup
        .AddReduce<int>(operatorName, masterTaskId, TopologyTypes.Flat, GetDefaultDataConverterConfig(), GetDefaultReduceFuncConfig())
        .Build();
    List<ICommunicationGroupClient> clients = CommGroupClients(groupName, numTasks, driver, groupDriver, GetDefaultCodecConfig());

    IReduceReceiver<int> reducer = clients[0].GetReduceReceiver<int>(operatorName);
    IReduceSender<int> task1Sender = clients[1].GetReduceSender<int>(operatorName);
    IReduceSender<int> task2Sender = clients[2].GetReduceSender<int>(operatorName);
    IReduceSender<int> task3Sender = clients[3].GetReduceSender<int>(operatorName);

    Assert.NotNull(reducer);
    Assert.NotNull(task1Sender);
    Assert.NotNull(task2Sender);
    Assert.NotNull(task3Sender);

    // Senders are invoked out of task order: task3, then task1, then task2.
    task3Sender.Send(5);
    task1Sender.Send(1);
    task2Sender.Send(3);

    int reduced = reducer.Reduce();
    Assert.Equal(9, reduced);
}
| |
/// <summary>
/// Runs three reduce rounds over a flat topology rooted at task0. In each round task3, task1
/// and task2 send one value each, and the master's reduced value is checked against the
/// expected result for that round.
/// </summary>
[Fact]
public void TestReduceOperator2()
{
    const string groupName = "group1";
    const string operatorName = "reduce";
    const string driverId = "driverid";
    const string masterTaskId = "task0";
    const int numTasks = 4;
    const int fanOut = 2;

    IGroupCommDriver driver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);
    var groupDriver = driver.DefaultGroup
        .AddReduce<int>(operatorName, masterTaskId, TopologyTypes.Flat, GetDefaultDataConverterConfig(), GetDefaultReduceFuncConfig())
        .Build();
    List<ICommunicationGroupClient> clients = CommGroupClients(groupName, numTasks, driver, groupDriver, GetDefaultCodecConfig());

    IReduceReceiver<int> reducer = clients[0].GetReduceReceiver<int>(operatorName);
    IReduceSender<int> task1Sender = clients[1].GetReduceSender<int>(operatorName);
    IReduceSender<int> task2Sender = clients[2].GetReduceSender<int>(operatorName);
    IReduceSender<int> task3Sender = clients[3].GetReduceSender<int>(operatorName);

    Assert.NotNull(reducer);
    Assert.NotNull(task1Sender);
    Assert.NotNull(task2Sender);
    Assert.NotNull(task3Sender);

    // Each row: values sent by task3, task1, task2 (in that order), then the expected reduced result.
    int[][] rounds =
    {
        new[] { 5, 1, 3, 9 },
        new[] { 6, 2, 4, 12 },
        new[] { 9, 3, 6, 18 },
    };

    foreach (int[] round in rounds)
    {
        task3Sender.Send(round[0]);
        task1Sender.Send(round[1]);
        task2Sender.Send(round[2]);
        Assert.Equal(round[3], reducer.Reduce());
    }
}
| |
/// <summary>
/// Scatters the values 1..4 from task0 across four receiving tasks. Taken in task order, each
/// receiver must get exactly one value, matching its task index.
/// </summary>
[Fact]
public void TestScatterOperator()
{
    const string groupName = "group1";
    const string operatorName = "scatter";
    const string masterTaskId = "task0";
    const string driverId = "Driver Id";
    const int numTasks = 5;
    const int fanOut = 2;

    IGroupCommDriver driver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);
    var groupDriver = driver.DefaultGroup
        .AddScatter(operatorName, masterTaskId)
        .Build();
    List<ICommunicationGroupClient> clients = CommGroupClients(groupName, numTasks, driver, groupDriver, GetDefaultCodecConfig());

    IScatterSender<int> sender = clients[0].GetScatterSender<int>(operatorName);

    // Receivers for task1..task4, collected in task order.
    var receivers = new List<IScatterReceiver<int>>();
    for (int taskIndex = 1; taskIndex < numTasks; taskIndex++)
    {
        receivers.Add(clients[taskIndex].GetScatterReceiver<int>(operatorName));
    }

    Assert.NotNull(sender);
    foreach (IScatterReceiver<int> receiver in receivers)
    {
        Assert.NotNull(receiver);
    }

    sender.Send(new List<int> { 1, 2, 3, 4 });

    for (int i = 0; i < receivers.Count; i++)
    {
        Assert.Equal(i + 1, receivers[i].Receive().Single());
    }
}
| |
/// <summary>
/// Scatters the values 1..4 from task0 across task1..task4. Each receiver must get exactly one
/// value, equal to its task index.
/// </summary>
[Fact]
public void TestScatterOperatorWithDefaultCodec()
{
    const string groupName = "group1";
    const string operatorName = "scatter";
    const string masterTaskId = "task0";
    const string driverId = "Driver Id";
    const int numTasks = 5;
    const int fanOut = 2;

    IGroupCommDriver driver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);
    var groupDriver = driver.DefaultGroup
        .AddScatter(operatorName, masterTaskId)
        .Build();
    List<ICommunicationGroupClient> clients = CommGroupClients(groupName, numTasks, driver, groupDriver, GetDefaultCodecConfig());

    IScatterSender<int> sender = clients[0].GetScatterSender<int>(operatorName);
    IScatterReceiver<int>[] receivers =
    {
        clients[1].GetScatterReceiver<int>(operatorName),
        clients[2].GetScatterReceiver<int>(operatorName),
        clients[3].GetScatterReceiver<int>(operatorName),
        clients[4].GetScatterReceiver<int>(operatorName),
    };

    Assert.NotNull(sender);
    Assert.NotNull(receivers[0]);
    Assert.NotNull(receivers[1]);
    Assert.NotNull(receivers[2]);
    Assert.NotNull(receivers[3]);

    List<int> payload = Enumerable.Range(1, 4).ToList();
    sender.Send(payload);

    int expectedValue = 1;
    foreach (IScatterReceiver<int> receiver in receivers)
    {
        Assert.Equal(expectedValue, receiver.Receive().Single());
        expectedValue++;
    }
}
| |
/// <summary>
/// Scatters the values 1..8 from task0 across task1..task4. Taken in task order, each receiver
/// must get exactly two consecutive values (task1: 1,2; task2: 3,4; task3: 5,6; task4: 7,8).
/// </summary>
[Fact]
public void TestScatterOperator2()
{
    const string groupName = "group1";
    const string operatorName = "scatter";
    const string masterTaskId = "task0";
    const string driverId = "Driver Id";
    const int numTasks = 5;
    const int fanOut = 2;

    IGroupCommDriver groupCommDriver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

    // NOTE(review): a reduce-function configuration has no evident role for a scatter operator;
    // it is kept so the operator setup is unchanged — confirm whether it can be dropped.
    var commGroup = groupCommDriver.DefaultGroup
        .AddScatter<int>(operatorName, masterTaskId, TopologyTypes.Flat, GetDefaultDataConverterConfig(), GetDefaultReduceFuncConfig())
        .Build();

    List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, groupCommDriver, commGroup, GetDefaultCodecConfig());

    IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName);
    IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName);
    IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(operatorName);
    IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(operatorName);
    IScatterReceiver<int> receiver4 = commGroups[4].GetScatterReceiver<int>(operatorName);

    Assert.NotNull(sender);
    Assert.NotNull(receiver1);
    Assert.NotNull(receiver2);
    Assert.NotNull(receiver3);
    Assert.NotNull(receiver4);

    List<int> data = new List<int> { 1, 2, 3, 4, 5, 6, 7, 8 };

    sender.Send(data);

    // Compare whole chunks. Checking only First()/Last() would also accept a chunk with extra or
    // wrong middle elements.
    Assert.Equal(new[] { 1, 2 }, receiver1.Receive().ToArray());
    Assert.Equal(new[] { 3, 4 }, receiver2.Receive().ToArray());
    Assert.Equal(new[] { 5, 6 }, receiver3.Receive().ToArray());
    Assert.Equal(new[] { 7, 8 }, receiver4.Receive().ToArray());
}
| |
/// <summary>
/// Scatters the values 1..8 from task0 across three receiving tasks and checks the leading
/// values each receiver gets: 1,2,3 for task1; 4,5,6 for task2; 7,8 for task3.
/// </summary>
[Fact]
public void TestScatterOperator3()
{
    const string groupName = "group1";
    const string operatorName = "scatter";
    const string masterTaskId = "task0";
    const string driverId = "Driver Id";
    const int numTasks = 4;
    const int fanOut = 2;

    IGroupCommDriver driver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);
    var groupDriver = driver.DefaultGroup
        .AddScatter<int>(operatorName, masterTaskId, TopologyTypes.Flat, GetDefaultDataConverterConfig())
        .Build();
    List<ICommunicationGroupClient> clients = CommGroupClients(groupName, numTasks, driver, groupDriver, GetDefaultCodecConfig());

    IScatterSender<int> sender = clients[0].GetScatterSender<int>(operatorName);
    IScatterReceiver<int>[] receivers =
    {
        clients[1].GetScatterReceiver<int>(operatorName),
        clients[2].GetScatterReceiver<int>(operatorName),
        clients[3].GetScatterReceiver<int>(operatorName),
    };

    Assert.NotNull(sender);
    Assert.NotNull(receivers[0]);
    Assert.NotNull(receivers[1]);
    Assert.NotNull(receivers[2]);

    sender.Send(new List<int> { 1, 2, 3, 4, 5, 6, 7, 8 });

    // Expected leading values per receiver (task1..task3); only these positions are checked.
    int[][] expectedPrefixes =
    {
        new[] { 1, 2, 3 },
        new[] { 4, 5, 6 },
        new[] { 7, 8 },
    };

    for (int r = 0; r < receivers.Length; r++)
    {
        int[] received = receivers[r].Receive().ToArray();
        int[] expected = expectedPrefixes[r];
        for (int k = 0; k < expected.Length; k++)
        {
            Assert.Equal(expected[k], received[k]);
        }
    }
}
| |
/// <summary>
/// Scatters the values 1..8 from task0 using the explicit task order task3, task2, task1, and
/// checks that the chunks follow that order: 1,2,3 to task3; 4,5,6 to task2; 7,8 to task1.
/// </summary>
[Fact]
public void TestScatterOperator4()
{
    const string groupName = "group1";
    const string operatorName = "scatter";
    const string masterTaskId = "task0";
    const string driverId = "Driver Id";
    const int numTasks = 4;
    const int fanOut = 2;

    IGroupCommDriver driver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks);
    var groupDriver = driver.DefaultGroup
        .AddScatter<int>(operatorName, masterTaskId, TopologyTypes.Flat, GetDefaultDataConverterConfig())
        .Build();
    List<ICommunicationGroupClient> clients = CommGroupClients(groupName, numTasks, driver, groupDriver, GetDefaultCodecConfig());

    IScatterSender<int> sender = clients[0].GetScatterSender<int>(operatorName);
    IScatterReceiver<int> task1Receiver = clients[1].GetScatterReceiver<int>(operatorName);
    IScatterReceiver<int> task2Receiver = clients[2].GetScatterReceiver<int>(operatorName);
    IScatterReceiver<int> task3Receiver = clients[3].GetScatterReceiver<int>(operatorName);

    Assert.NotNull(sender);
    Assert.NotNull(task1Receiver);
    Assert.NotNull(task2Receiver);
    Assert.NotNull(task3Receiver);

    var payload = new List<int> { 1, 2, 3, 4, 5, 6, 7, 8 };
    var scatterOrder = new List<string> { "task3", "task2", "task1" };
    sender.Send(payload, scatterOrder);

    // Receives one chunk and checks its leading values; positions beyond 'expected' are not checked.
    Action<IScatterReceiver<int>, int[]> expectChunkStartingWith = (receiver, expected) =>
    {
        int[] received = receiver.Receive().ToArray();
        for (int k = 0; k < expected.Length; k++)
        {
            Assert.Equal(expected[k], received[k]);
        }
    };

    expectChunkStartingWith(task3Receiver, new[] { 1, 2, 3 });
    expectChunkStartingWith(task2Receiver, new[] { 4, 5, 6 });
    expectChunkStartingWith(task1Receiver, new[] { 7, 8 });
}
| |
/// <summary>
/// Builds a flat topology for a broadcast operator, injects the <c>IStreamingCodec&lt;int&gt;</c>
/// from the resulting task configuration, and checks that it round-trips the value 3 through a
/// memory stream.
/// </summary>
[Fact]
public void TestConfigurationBroadcastSpec()
{
    var broadcastSpec = new BroadcastOperatorSpec("Sender", GetDefaultCodecConfig(), GetDefaultDataConverterConfig());
    FlatTopology<int> topology = new FlatTopology<int>("Operator", "Operator", "task1", "driverid", broadcastSpec);

    topology.AddTask("task1");
    var taskConfig = topology.GetTaskConfiguration("task1");
    IInjector injector = TangFactory.GetTang().NewInjector(taskConfig);
    IStreamingCodec<int> intCodec = injector.GetInstance<IStreamingCodec<int>>();

    var buffer = new MemoryStream();
    IDataWriter output = new StreamDataWriter(buffer);
    intCodec.Write(3, output);

    // Rewind so the reader starts from the value just written.
    buffer.Position = 0;
    IDataReader input = new StreamDataReader(buffer);
    int decoded = intCodec.Read(input);

    Assert.Equal(3, decoded);
}
| |
/// <summary>
/// Builds a flat topology for a reduce operator, injects the <c>IReduceFunction&lt;int&gt;</c>
/// from the resulting task configuration, and checks that reducing 1, 2, 3, 4 yields 10.
/// </summary>
[Fact]
public void TestConfigurationReduceSpec()
{
    IConfiguration reduceConfig = Configurations.Merge(
        GetDefaultCodecConfig(),
        GetDefaultDataConverterConfig(),
        GetDefaultReduceFuncConfig());
    var reduceSpec = new ReduceOperatorSpec("task1", reduceConfig);
    FlatTopology<int> topology = new FlatTopology<int>("Operator", "Group", "task1", "driverid", reduceSpec);

    topology.AddTask("task1");
    IReduceFunction<int> reduceFunction = TangFactory.GetTang()
        .NewInjector(topology.GetTaskConfiguration("task1"))
        .GetInstance<IReduceFunction<int>>();

    int reduced = reduceFunction.Reduce(new[] { 1, 2, 3, 4 });
    Assert.Equal(10, reduced);
}
| |
/// <summary>
/// Creates a <see cref="GroupCommDriver"/> through Tang. It binds the driver id, master task id,
/// group name, fan-out and number of tasks, and uses <see cref="AvroConfigurationSerializer"/>
/// as the configuration serializer.
/// </summary>
/// <param name="driverId">Value bound to GroupCommConfigurationOptions.DriverId.</param>
/// <param name="masterTaskId">Value bound to GroupCommConfigurationOptions.MasterTaskId.</param>
/// <param name="groupName">Value bound to GroupCommConfigurationOptions.GroupName.</param>
/// <param name="fanOut">Value bound to GroupCommConfigurationOptions.FanOut.</param>
/// <param name="numTasks">Value bound to GroupCommConfigurationOptions.NumberOfTasks.</param>
/// <returns>The injected group communication driver.</returns>
public static IGroupCommDriver GetInstanceOfGroupCommDriver(string driverId, string masterTaskId, string groupName, int fanOut, int numTasks)
{
    // Integer named parameters are bound from strings. Format them with the invariant culture so
    // the bound text never depends on the current thread's culture.
    var c = TangFactory.GetTang().NewConfigurationBuilder()
        .BindStringNamedParam<GroupCommConfigurationOptions.DriverId>(driverId)
        .BindStringNamedParam<GroupCommConfigurationOptions.MasterTaskId>(masterTaskId)
        .BindStringNamedParam<GroupCommConfigurationOptions.GroupName>(groupName)
        .BindIntNamedParam<GroupCommConfigurationOptions.FanOut>(fanOut.ToString(CultureInfo.InvariantCulture))
        .BindIntNamedParam<GroupCommConfigurationOptions.NumberOfTasks>(numTasks.ToString(CultureInfo.InvariantCulture))
        .BindImplementation(GenericType<IConfigurationSerializer>.Class, GenericType<AvroConfigurationSerializer>.Class)
        .Build();

    IGroupCommDriver groupCommDriver = TangFactory.GetTang().NewInjector(c).GetInstance<GroupCommDriver>();
    return groupCommDriver;
}
| |
/// <summary>
/// Wraps IntCodec as a streaming codec for PipelineMessage&lt;int&gt;. Writes one message
/// synchronously and one asynchronously, reads them back the same way, and checks that each
/// message's Data and IsLast values survive the round trip.
/// </summary>
[Fact]
public async Task TestCodecToStreamingCodecConfiguration()
{
    var config = CodecToStreamingCodecConfiguration<int>.Conf
        .Set(CodecToStreamingCodecConfiguration<int>.Codec, GenericType<IntCodec>.Class)
        .Build();

    IStreamingCodec<PipelineMessage<int>> streamingCodec =
        TangFactory.GetTang().NewInjector(config).GetInstance<IStreamingCodec<PipelineMessage<int>>>();

    CancellationToken token = CancellationToken.None;

    const int obj = 5;
    using (var stream = new MemoryStream())
    {
        PipelineMessage<int> message = new PipelineMessage<int>(obj, true);
        IDataWriter writer = new StreamDataWriter(stream);
        streamingCodec.Write(message, writer);
        PipelineMessage<int> message1 = new PipelineMessage<int>(obj + 1, false);
        await streamingCodec.WriteAsync(message1, writer, token);

        // Rewind so both messages are read back from the beginning, in write order.
        stream.Position = 0;
        IDataReader reader = new StreamDataReader(stream);
        var res1 = streamingCodec.Read(reader);
        var res2 = await streamingCodec.ReadAsync(reader, token);

        Assert.Equal(obj, res1.Data);
        Assert.Equal(obj + 1, res2.Data);

        // Assert.True/False are the idiomatic boolean assertions (xUnit analyzer rule xUnit2004).
        Assert.True(res1.IsLast);
        Assert.False(res2.IsLast);
    }
}
| |
/// <summary>
/// Simulates evaluator-side creation of one communication group client per task.
/// Task ids "task0" .. "task{numTasks - 1}" are all registered with <paramref name="commGroupDriver"/>
/// first; only then is each task's group-communication configuration fetched from
/// <paramref name="groupCommDriver"/>, merged with the task and service configurations, and injected.
/// </summary>
/// <param name="groupName">Name of the communication group each client is retrieved for.</param>
/// <param name="numTasks">Number of tasks to create.</param>
/// <param name="groupCommDriver">Driver supplying service and per-task group-communication configuration.</param>
/// <param name="commGroupDriver">Group driver that every task is added to.</param>
/// <param name="userServiceConfig">Extra service configuration merged into the driver's service configuration.</param>
/// <returns>The communication group clients, ordered by task index.</returns>
public static List<ICommunicationGroupClient> CommGroupClients(string groupName, int numTasks, IGroupCommDriver groupCommDriver, ICommunicationGroupDriver commGroupDriver, IConfiguration userServiceConfig)
{
    IConfiguration serviceConfig = Configurations.Merge(groupCommDriver.GetServiceConfiguration(), userServiceConfig);

    // Phase 1: build each task's base configuration and register the task with the group.
    // Every task is added before any per-task configuration is requested in phase 2.
    var taskConfigs = new List<IConfiguration>();
    for (int index = 0; index < numTasks; index++)
    {
        string id = "task" + index;
        IConfiguration baseTaskConfig = TaskConfiguration.ConfigurationModule
            .Set(TaskConfiguration.Identifier, id)
            .Set(TaskConfiguration.Task, GenericType<MyTask>.Class)
            .Build();
        IConfiguration copiedTaskConfig = TangFactory.GetTang().NewConfigurationBuilder(baseTaskConfig).Build();

        commGroupDriver.AddTask(id);
        taskConfigs.Add(copiedTaskConfig);
    }

    // Phase 2: assemble each task's full configuration at the "driver" and inject it as an "evaluator" would.
    var clients = new List<ICommunicationGroupClient>();
    for (int index = 0; index < numTasks; index++)
    {
        IConfiguration gcTaskConfig = groupCommDriver.GetGroupCommTaskConfiguration("task" + index);
        var evaluatorConfig = TangFactory.GetTang()
            .NewConfigurationBuilder(Configurations.Merge(gcTaskConfig, taskConfigs[index], serviceConfig))
            .BindNamedParameter(typeof(GroupCommConfigurationOptions.Initialize), "false")
            .Build();

        var client = TangFactory.GetTang().NewInjector(evaluatorConfig).GetInstance<IGroupCommClient>();
        clients.Add(client.GetCommunicationGroup(groupName));
    }

    return clients;
}
| |
/// <summary>
/// Creates an injector for a network service with the following configuration:
/// <list type="bullet">
/// <item>the name server address and port are taken from <paramref name="nameServerEndpoint"/>;</item>
/// <item><c>NetworkServicePort</c> is bound to 0 (presumably letting the OS pick a free port — confirm);</item>
/// <item><c>NameClient</c> is used as <c>INameClient</c>;</item>
/// <item><c>StringStreamingCodec</c> is used as the string streaming codec.</item>
/// </list>
/// </summary>
/// <param name="nameServerEndpoint">Address and port of the name server.</param>
/// <param name="handler">Observer bound as the handler for incoming group-communication messages.</param>
/// <returns>The configured injector with <paramref name="handler"/> bound as a volatile instance.</returns>
public static IInjector BuildNetworkServiceInjector(
    IPEndPoint nameServerEndpoint, IObserver<NsMessage<GeneralGroupCommunicationMessage>> handler)
{
    // Numeric values are formatted with the invariant culture, consistently for both ports.
    var config = TangFactory.GetTang().NewConfigurationBuilder()
        .BindNamedParameter(typeof(NamingConfigurationOptions.NameServerAddress),
            nameServerEndpoint.Address.ToString())
        .BindNamedParameter(typeof(NamingConfigurationOptions.NameServerPort),
            nameServerEndpoint.Port.ToString(CultureInfo.InvariantCulture))
        .BindNamedParameter(typeof(NetworkServiceOptions.NetworkServicePort),
            0.ToString(CultureInfo.InvariantCulture))
        .BindImplementation(GenericType<INameClient>.Class, GenericType<NameClient>.Class)
        .Build();

    var codecConfig = StreamingCodecConfiguration<string>.Conf
        .Set(StreamingCodecConfiguration<string>.Codec, GenericType<StringStreamingCodec>.Class)
        .Build();

    config = Configurations.Merge(config, codecConfig);

    var injector = TangFactory.GetTang().NewInjector(config);
    injector.BindVolatileInstance(
        GenericType<IObserver<NsMessage<GeneralGroupCommunicationMessage>>>.Class, handler);

    return injector;
}
| |
/// <summary>
/// Wraps a string payload in a <c>GroupCommunicationMessage</c> that uses the fixed
/// identifiers "g1" and "op1" (presumably the group and operator names).
/// </summary>
/// <param name="message">Payload carried by the message.</param>
/// <param name="from">Sender id.</param>
/// <param name="to">Receiver id.</param>
/// <returns>The constructed message.</returns>
private GroupCommunicationMessage<string> CreateGcmStringType(string message, string from, string to)
{
    const string GroupName = "g1";
    const string OperatorName = "op1";

    var gcm = new GroupCommunicationMessage<string>(GroupName, OperatorName, from, to, message);
    return gcm;
}
| |
/// <summary>
/// Receives a list of integers from <paramref name="receiver"/> and forwards their sum
/// through <paramref name="sumSender"/>.
/// </summary>
/// <param name="receiver">Scatter receiver to read the integers from.</param>
/// <param name="sumSender">Reduce sender that the local sum is sent to.</param>
private static void ScatterReceiveReduce(IScatterReceiver<int> receiver, IReduceSender<int> sumSender)
{
    int localSum = Enumerable.Sum(receiver.Receive());
    sumSender.Send(localSum);
}
| |
/// <summary>
/// Returns the sum 1 + 2 + ... + <paramref name="n"/> (0 when <paramref name="n"/> is 0).
/// </summary>
/// <param name="n">Upper bound of the sum; must be non-negative.</param>
/// <returns>The <paramref name="n"/>-th triangle number.</returns>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="n"/> is negative.</exception>
/// <exception cref="OverflowException">The result does not fit in an <see cref="int"/>.</exception>
public static int TriangleNumber(int n)
{
    // Enumerable.Range validates n eagerly; the checked addition overflows exactly when Sum() would.
    int total = 0;
    foreach (int term in Enumerable.Range(1, n))
    {
        total = checked(total + term);
    }

    return total;
}
| |
/// <summary>
/// Streaming codec configuration binding <c>IntStreamingCodec</c> as the codec for <c>int</c>.
/// </summary>
/// <returns>The built configuration.</returns>
private static IConfiguration GetDefaultCodecConfig()
{
    var codecModule = StreamingCodecConfiguration<int>.Conf
        .Set(StreamingCodecConfiguration<int>.Codec, GenericType<IntStreamingCodec>.Class);
    return codecModule.Build();
}
| |
/// <summary>
/// Reduce function configuration binding <see cref="SumFunction"/> as the <c>int</c> reduce function.
/// </summary>
/// <returns>The built configuration.</returns>
private static IConfiguration GetDefaultReduceFuncConfig()
{
    var reduceModule = ReduceFunctionConfiguration<int>.Conf
        .Set(ReduceFunctionConfiguration<int>.ReduceFunction, GenericType<SumFunction>.Class);
    return reduceModule.Build();
}
| |
/// <summary>
/// Pipeline data converter configuration binding <c>DefaultPipelineDataConverter&lt;int&gt;</c>.
/// </summary>
/// <returns>The built configuration.</returns>
private static IConfiguration GetDefaultDataConverterConfig()
{
    var converterModule = PipelineDataConverterConfiguration<int>.Conf
        .Set(PipelineDataConverterConfiguration<int>.DataConverter, GenericType<DefaultPipelineDataConverter<int>>.Class);
    return converterModule.Build();
}
| } |
| |
/// <summary>
/// Reduce function that adds up all the integers it is given.
/// </summary>
public class SumFunction : IReduceFunction<int>
{
    /// <summary>
    /// Parameterless constructor marked for Tang injection.
    /// </summary>
    [Inject]
    public SumFunction()
    {
    }

    /// <summary>
    /// Returns the sum of <paramref name="elements"/> (0 for an empty sequence).
    /// </summary>
    /// <param name="elements">Values to add.</param>
    /// <returns>The total.</returns>
    /// <exception cref="OverflowException">The running total leaves the <see cref="int"/> range.</exception>
    public int Reduce(IEnumerable<int> elements)
    {
        // checked keeps the same overflow behaviour as Enumerable.Sum.
        return elements.Aggregate(0, (total, next) => checked(total + next));
    }
}
| |
/// <summary>
/// Stub task. Its <see cref="Call"/> is not implemented; in this file it is only bound as
/// <c>TaskConfiguration.Task</c> when building task configurations.
/// </summary>
public class MyTask : ITask
{
    /// <summary>
    /// Does nothing: the task holds no resources. Dispose must never throw, because it is
    /// commonly invoked from cleanup paths (using blocks, finally clauses).
    /// </summary>
    public void Dispose()
    {
    }

    /// <summary>
    /// Not implemented; this stub is not meant to be executed.
    /// </summary>
    /// <param name="memento">Ignored.</param>
    /// <returns>Never returns.</returns>
    /// <exception cref="NotImplementedException">Always.</exception>
    public byte[] Call(byte[] memento)
    {
        throw new NotImplementedException();
    }
}
| |
/// <summary>
/// Reduce function that adds integer arrays element-wise.
/// </summary>
class ArraySumFunction : IReduceFunction<int[]>
{
    [Inject]
    private ArraySumFunction()
    {
    }

    /// <summary>
    /// Returns a new array whose i-th entry is the sum of the i-th entries of all input arrays.
    /// The input arrays are not modified.
    /// </summary>
    /// <param name="elements">Arrays to add; all must have the same length.</param>
    /// <returns>The element-wise sum, or <c>null</c> when <paramref name="elements"/> is empty.</returns>
    /// <exception cref="ArgumentException">The arrays do not all have the same length.</exception>
    public int[] Reduce(IEnumerable<int[]> elements)
    {
        int[] result = null;

        foreach (var element in elements)
        {
            if (result == null)
            {
                // Copy the first array so the accumulation below never mutates the caller's data.
                result = (int[])element.Clone();
                continue;
            }

            if (element.Length != result.Length)
            {
                throw new ArgumentException("integer arrays are of different sizes", "elements");
            }

            for (int i = 0; i < result.Length; i++)
            {
                result[i] += element[i];
            }
        }

        return result;
    }
}
| |
/// <summary>
/// Pipeline data converter that splits an <c>int[]</c> into chunks of at most
/// <c>GroupTestConfig.ChunkSize</c> elements and concatenates chunks back into one array.
/// </summary>
class PipelineIntDataConverter : IPipelineDataConverter<int[]>
{
    private readonly int _chunkSize;

    /// <param name="chunkSize">Maximum number of elements per chunk; must be positive.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="chunkSize"/> is zero or negative.</exception>
    [Inject]
    private PipelineIntDataConverter([Parameter(typeof(GroupTestConfig.ChunkSize))] int chunkSize)
    {
        // Fail fast: a non-positive chunk size would otherwise surface later as a
        // DivideByZeroException or OverflowException inside PipelineMessage.
        if (chunkSize <= 0)
        {
            throw new ArgumentOutOfRangeException("chunkSize", chunkSize, "Chunk size must be positive.");
        }

        _chunkSize = chunkSize;
    }

    /// <summary>
    /// Splits <paramref name="message"/> into consecutive chunks of at most the configured size.
    /// Only the final chunk is flagged as last. An empty array yields a single empty chunk
    /// flagged as last, so that a consumer waiting for the last chunk still receives one.
    /// </summary>
    /// <param name="message">Array to split.</param>
    /// <returns>The chunks in order.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="message"/> is null.</exception>
    public List<PipelineMessage<int[]>> PipelineMessage(int[] message)
    {
        if (message == null)
        {
            throw new ArgumentNullException("message");
        }

        var messageList = new List<PipelineMessage<int[]>>();

        if (message.Length == 0)
        {
            messageList.Add(new PipelineMessage<int[]>(new int[0], true));
            return messageList;
        }

        for (int start = 0; start < message.Length; start += _chunkSize)
        {
            int length = Math.Min(_chunkSize, message.Length - start);
            int[] data = new int[length];

            // Element-based copy; avoids the byte-offset arithmetic of Buffer.BlockCopy.
            Array.Copy(message, start, data, 0, length);

            // The chunk that reaches the end of the input is the last one.
            bool isLast = start + length == message.Length;
            messageList.Add(new PipelineMessage<int[]>(data, isLast));
        }

        return messageList;
    }

    /// <summary>
    /// Concatenates the data of all chunks, in list order, into a single array.
    /// </summary>
    /// <param name="pipelineMessage">Chunks to join.</param>
    /// <returns>The reassembled array.</returns>
    public int[] FullMessage(List<PipelineMessage<int[]>> pipelineMessage)
    {
        int size = pipelineMessage.Sum(x => x.Data.Length);
        int[] data = new int[size];
        int offset = 0;

        foreach (var message in pipelineMessage)
        {
            Array.Copy(message.Data, 0, data, offset, message.Data.Length);
            offset += message.Data.Length;
        }

        return data;
    }
}
| } |