blob: e0e75f86126727c27a10d934581c9a236cc14e06 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using Org.Apache.REEF.Common.Io;
using Org.Apache.REEF.Common.Services;
using Org.Apache.REEF.Common.Tasks;
using Org.Apache.REEF.Driver;
using Org.Apache.REEF.Driver.Bridge;
using Org.Apache.REEF.Driver.Context;
using Org.Apache.REEF.Driver.Evaluator;
using Org.Apache.REEF.Examples.MachineLearning.KMeans.codecs;
using Org.Apache.REEF.Network.Group.Config;
using Org.Apache.REEF.Network.Group.Driver;
using Org.Apache.REEF.Network.Group.Driver.Impl;
using Org.Apache.REEF.Network.Group.Pipelining.Impl;
using Org.Apache.REEF.Network.NetworkService;
using Org.Apache.REEF.Network.NetworkService.Codec;
using Org.Apache.REEF.Tang.Annotations;
using Org.Apache.REEF.Tang.Implementations.Configuration;
using Org.Apache.REEF.Tang.Implementations.Tang;
using Org.Apache.REEF.Tang.Interface;
using Org.Apache.REEF.Tang.Util;
using Org.Apache.REEF.Utilities.Logging;
using Org.Apache.REEF.Network.Group.Topology;
namespace Org.Apache.REEF.Examples.MachineLearning.KMeans
{
/// <summary>
/// Driver-side event handlers for the KMeans example.
/// Reacts to driver start (requests evaluators), evaluator allocation
/// (submits a context plus the group communication and data cache services)
/// and context activation (queues the master or a slave task).
/// </summary>
public class KMeansDriverHandlers :
    IObserver<IAllocatedEvaluator>,
    IObserver<IActiveContext>,
    IObserver<IDriverStarted>
{
    // Resources requested for every evaluator.
    private const int EvaluatorCores = 1;
    private const int EvaluatorMegabytes = 2048;

    private static readonly Logger _Logger = Logger.GetLogger(typeof(KMeansDriverHandlers));

    // Guards _partitionIndex so that concurrently allocated evaluators get distinct partition numbers.
    private readonly object _lockObj = new object();
    private readonly string _executionDirectory;

    // TODO: we may want to make this injectable
    private readonly int _clustersNumber = 3;

    // numPartitions + 1 (see the constructor).
    private readonly int _totalEvaluators;

    // Next partition number to hand out to a non-master evaluator.
    private int _partitionIndex = 0;

    private readonly IGroupCommDriver _groupCommDriver;
    private readonly ICommunicationGroupDriver _commGroup;
    private readonly TaskStarter _groupCommTaskStarter;
    private readonly IConfiguration _centroidCodecConf;
    private readonly IConfiguration _controlMessageCodecConf;
    private readonly IConfiguration _processedResultsCodecConf;
    private readonly IEvaluatorRequestor _evaluatorRequestor;

    /// <summary>
    /// Prepares the execution directory and input data, builds the codec
    /// configurations and sets up the communication group with two broadcast
    /// operators and one reduce operator rooted at the master task.
    /// </summary>
    /// <param name="numPartitions">Number of data partitions; the job uses numPartitions + 1 evaluators.</param>
    /// <param name="groupCommDriver">Group communication driver used to build the communication group.</param>
    /// <param name="evaluatorRequestor">Requestor used to ask for evaluators once the driver has started.</param>
    /// <param name="arguments">Command line arguments; exactly one must start with "DataFile" and have the form "DataFile:name".</param>
    [Inject]
    private KMeansDriverHandlers(
        [Parameter(typeof(NumPartitions))] int numPartitions,
        GroupCommDriver groupCommDriver,
        IEvaluatorRequestor evaluatorRequestor,
        CommandLineArguments arguments)
    {
        // A short random suffix gives every run its own execution directory.
        _executionDirectory = Path.Combine(Directory.GetCurrentDirectory(), Constants.KMeansExecutionBaseDirectory, Guid.NewGuid().ToString("N").Substring(0, 4));

        // Single() throws if there is no "DataFile" argument or more than one.
        string dataFile = arguments.Arguments.Single(a => a.StartsWith("DataFile", StringComparison.Ordinal)).Split(':')[1];
        DataVector.ShuffleDataAndGetInitialCentriods(
            Path.Combine(Directory.GetCurrentDirectory(), "reef", "global", dataFile),
            numPartitions,
            _clustersNumber,
            _executionDirectory);

        _totalEvaluators = numPartitions + 1;
        _groupCommDriver = groupCommDriver;
        _evaluatorRequestor = evaluatorRequestor;

        _centroidCodecConf = CodecToStreamingCodecConfiguration<Centroids>.Conf
            .Set(CodecConfiguration<Centroids>.Codec, GenericType<CentroidsCodec>.Class)
            .Build();

        IConfiguration dataConverterConfig1 = PipelineDataConverterConfiguration<Centroids>.Conf
            .Set(PipelineDataConverterConfiguration<Centroids>.DataConverter, GenericType<DefaultPipelineDataConverter<Centroids>>.Class)
            .Build();

        _controlMessageCodecConf = CodecToStreamingCodecConfiguration<ControlMessage>.Conf
            .Set(CodecConfiguration<ControlMessage>.Codec, GenericType<ControlMessageCodec>.Class)
            .Build();

        IConfiguration dataConverterConfig2 = PipelineDataConverterConfiguration<ControlMessage>.Conf
            .Set(PipelineDataConverterConfiguration<ControlMessage>.DataConverter, GenericType<DefaultPipelineDataConverter<ControlMessage>>.Class)
            .Build();

        _processedResultsCodecConf = CodecToStreamingCodecConfiguration<ProcessedResults>.Conf
            .Set(CodecConfiguration<ProcessedResults>.Codec, GenericType<ProcessedResultsCodec>.Class)
            .Build();

        IConfiguration reduceFunctionConfig = ReduceFunctionConfiguration<ProcessedResults>.Conf
            .Set(ReduceFunctionConfiguration<ProcessedResults>.ReduceFunction, GenericType<KMeansMasterTask.AggregateMeans>.Class)
            .Build();

        IConfiguration dataConverterConfig3 = PipelineDataConverterConfiguration<ProcessedResults>.Conf
            .Set(PipelineDataConverterConfiguration<ProcessedResults>.DataConverter, GenericType<DefaultPipelineDataConverter<ProcessedResults>>.Class)
            .Build();

        // All three operators use a flat topology with the master task as root.
        _commGroup = _groupCommDriver.DefaultGroup
            .AddBroadcast<Centroids>(Constants.CentroidsBroadcastOperatorName, Constants.MasterTaskId, TopologyTypes.Flat, dataConverterConfig1)
            .AddBroadcast<ControlMessage>(Constants.ControlMessageBroadcastOperatorName, Constants.MasterTaskId, TopologyTypes.Flat, dataConverterConfig2)
            .AddReduce<ProcessedResults>(Constants.MeansReduceOperatorName, Constants.MasterTaskId, TopologyTypes.Flat, reduceFunctionConfig, dataConverterConfig3)
            .Build();

        _groupCommTaskStarter = new TaskStarter(_groupCommDriver, _totalEvaluators);
    }

    /// <summary>
    /// Submits a context and the required services to a newly allocated evaluator.
    /// The master context gets partition number -1; every other evaluator gets
    /// the next unused partition number, starting at 0.
    /// </summary>
    /// <param name="allocatedEvaluator">The evaluator that has just been allocated.</param>
    public void OnNext(IAllocatedEvaluator allocatedEvaluator)
    {
        IConfiguration contextConfiguration = _groupCommDriver.GetContextConfiguration();

        int partitionNum;
        if (_groupCommDriver.IsMasterContextConfiguration(contextConfiguration))
        {
            partitionNum = -1;
        }
        else
        {
            lock (_lockObj)
            {
                partitionNum = _partitionIndex;
                _partitionIndex++;
            }
        }

        IConfiguration gcServiceConfiguration = _groupCommDriver.GetServiceConfiguration();
        gcServiceConfiguration = Configurations.Merge(gcServiceConfiguration, _centroidCodecConf, _controlMessageCodecConf, _processedResultsCodecConf);

        IConfiguration commonServiceConfiguration = TangFactory.GetTang().NewConfigurationBuilder(gcServiceConfiguration)
            .BindNamedParameter<DataPartitionCache.PartitionIndex, int>(GenericType<DataPartitionCache.PartitionIndex>.Class, partitionNum.ToString(CultureInfo.InvariantCulture))
            .BindNamedParameter<KMeansConfiguratioinOptions.ExecutionDirectory, string>(GenericType<KMeansConfiguratioinOptions.ExecutionDirectory>.Class, _executionDirectory)
            .BindNamedParameter<KMeansConfiguratioinOptions.TotalNumEvaluators, int>(GenericType<KMeansConfiguratioinOptions.TotalNumEvaluators>.Class, _totalEvaluators.ToString(CultureInfo.InvariantCulture))
            .BindNamedParameter<KMeansConfiguratioinOptions.K, int>(GenericType<KMeansConfiguratioinOptions.K>.Class, _clustersNumber.ToString(CultureInfo.InvariantCulture))
            .Build();

        IConfiguration dataCacheServiceConfiguration = ServiceConfiguration.ConfigurationModule
            .Set(ServiceConfiguration.Services, GenericType<DataPartitionCache>.Class)
            .Build();

        allocatedEvaluator.SubmitContextAndService(contextConfiguration, Configurations.Merge(commonServiceConfiguration, dataCacheServiceConfiguration));
    }

    /// <summary>
    /// Builds the task configuration for an active context (master or slave),
    /// registers the task id with the communication group and queues the task
    /// with the task starter.
    /// </summary>
    /// <param name="activeContext">The context that has just become active.</param>
    public void OnNext(IActiveContext activeContext)
    {
        string taskId;
        IConfiguration taskConfiguration;

        if (_groupCommDriver.IsMasterTaskContext(activeContext))
        {
            // Configure Master Task
            taskId = Constants.MasterTaskId;
            taskConfiguration = TaskConfiguration.ConfigurationModule
                .Set(TaskConfiguration.Identifier, taskId)
                .Set(TaskConfiguration.Task, GenericType<KMeansMasterTask>.Class)
                .Build();
        }
        else
        {
            // Configure Slave Task; the same id is used for the task
            // configuration and for the communication group registration.
            taskId = Constants.SlaveTaskIdPrefix + activeContext.Id;
            taskConfiguration = TaskConfiguration.ConfigurationModule
                .Set(TaskConfiguration.Identifier, taskId)
                .Set(TaskConfiguration.Task, GenericType<KMeansSlaveTask>.Class)
                .Build();
        }

        _commGroup.AddTask(taskId);
        _groupCommTaskStarter.QueueTask(taskConfiguration, activeContext);
    }

    /// <summary>
    /// Requests the evaluators for the job once the driver has started.
    /// </summary>
    /// <param name="value">The driver started event.</param>
    public void OnNext(IDriverStarted value)
    {
        // Ask for one evaluator per expected task (numPartitions slaves + 1 master).
        // The task starter is constructed with the same count, so the request
        // must state it explicitly instead of relying on the builder's default
        // (NOTE(review): the default is presumably a single evaluator - confirm).
        var request = _evaluatorRequestor.NewBuilder()
            .SetNumber(_totalEvaluators)
            .SetCores(EvaluatorCores)
            .SetMegabytes(EvaluatorMegabytes)
            .Build();

        _evaluatorRequestor.Submit(request);
    }

    /// <summary>
    /// Not implemented; always throws.
    /// </summary>
    /// <param name="error">The reported error.</param>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public void OnError(Exception error)
    {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Not implemented; always throws.
    /// </summary>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public void OnCompleted()
    {
        throw new NotImplementedException();
    }
}
/// <summary>
/// Named parameter holding the number of data partitions.
/// It is injected into the <c>KMeansDriverHandlers</c> constructor.
/// </summary>
[NamedParameter("Number of partitions")]
public class NumPartitions : Name<int> { }
}