// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using Org.Apache.REEF.Common.Tasks;
using Org.Apache.REEF.Network.Group.Operators;
using Org.Apache.REEF.Network.Group.Task;
using Org.Apache.REEF.Network.NetworkService;
using Org.Apache.REEF.Tang.Annotations;
using Org.Apache.REEF.Utilities.Logging;

namespace Org.Apache.REEF.Examples.MachineLearning.KMeans
{
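    /// <summary>
    /// Master-side task of the KMeans example. It broadcasts the current centroids to all
    /// slave tasks, receives their reduced partial means, recomputes the centroids and the
    /// loss, and repeats until the loss stops decreasing.
    /// </summary>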
    public class KMeansMasterTask : ITask
    {
        private static readonly Logger Logger = Logger.GetLogger(typeof(KMeansMasterTask));

        private int _iteration = 0;
        private readonly ICommunicationGroupClient _commGroup;
        private readonly IBroadcastSender<Centroids> _dataBroadcastSender;
        private readonly IBroadcastSender<ControlMessage> _controlBroadcastSender;
        private readonly IReduceReceiver<ProcessedResults> _meansReducerReceiver;
        private readonly string _kMeansExecutionDirectory;
        private Centroids _centroids;
        private bool _isInitialIteration;
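
        /// <summary>
        /// Injectable constructor: validates the evaluator count and obtains the broadcast
        /// senders and the reduce receiver from the KMeans communication group.
        /// </summary>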
        [Inject]
        public KMeansMasterTask(
            [Parameter(typeof(KMeansConfiguratioinOptions.TotalNumEvaluators))] int totalNumEvaluators,
            [Parameter(Value = typeof(KMeansConfiguratioinOptions.ExecutionDirectory))] string executionDirectory,
            IGroupCommClient groupCommClient)
        {
            using (Logger.LogFunction("KMeansMasterTask"))
            {
                if (totalNumEvaluators <= 1)
                {
                    throw new ArgumentException("There must be more than one evaluator in total, but the number of evaluators provided is " + totalNumEvaluators);
                }

                _commGroup = groupCommClient.GetCommunicationGroup(Constants.KMeansCommunicationGroupName);
                _dataBroadcastSender = _commGroup.GetBroadcastSender<Centroids>(Constants.CentroidsBroadcastOperatorName);
                _meansReducerReceiver = _commGroup.GetReduceReceiver<ProcessedResults>(Constants.MeansReduceOperatorName);
                _controlBroadcastSender = _commGroup.GetBroadcastSender<ControlMessage>(Constants.ControlMessageBroadcastOperatorName);
                _kMeansExecutionDirectory = executionDirectory;
                _isInitialIteration = true;
            }
        }
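
        /// <summary>
        /// Runs the master loop: broadcasts the centroids, waits for the reduced partial means
        /// from the slave tasks, recomputes the centroids and the loss, and stops once the loss
        /// no longer decreases.
        /// </summary>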
        public byte[] Call(byte[] memento)
        {
            // TODO: loading the centroids belongs in a dedicated data loader layer; refactor once we have one
            string centroidFile = Path.Combine(_kMeansExecutionDirectory, Constants.CentroidsFile);
            _centroids = new Centroids(DataPartitionCache.ReadDataFile(centroidFile));

            float loss = float.MaxValue;
            float newLoss;

            while (true)
            {
                if (_isInitialIteration)
                {
                    // broadcast the initial centroids to all slave nodes
                    Logger.Log(Level.Info, "Broadcasting INITIAL centroids to all slave nodes: " + _centroids);
                    _isInitialIteration = false;
                }
                else
                {
                    // aggregate the partial means computed by the slave tasks into new centroids
                    ProcessedResults results = _meansReducerReceiver.Reduce();
                    _centroids = new Centroids(results.Means.Select(m => m.Mean).ToList());
                    Logger.Log(Level.Info, "Broadcasting new centroids to all slave nodes: " + _centroids);

                    newLoss = results.Loss;
                    Logger.Log(Level.Info, string.Format(CultureInfo.InvariantCulture, "The new loss value {0} at iteration {1}", newLoss, _iteration));
                    if (newLoss > loss)
                    {
                        _controlBroadcastSender.Send(ControlMessage.STOP);
                        throw new InvalidOperationException(
                            string.Format(CultureInfo.InvariantCulture, "The new loss {0} is larger than the previous loss {1}, but the loss function must be non-increasing across iterations", newLoss, loss));
                    }
                    else if (newLoss.Equals(loss))
                    {
                        Logger.Log(Level.Info, string.Format(CultureInfo.InvariantCulture, "KMeans clustering has converged with a loss value of {0} at iteration {1}", newLoss, _iteration));
                        break;
                    }
                    else
                    {
                        loss = newLoss;
                    }
                }

                _controlBroadcastSender.Send(ControlMessage.RECEIVE);
                _dataBroadcastSender.Send(_centroids);
                _iteration++;
            }

            _controlBroadcastSender.Send(ControlMessage.STOP);
            return null;
        }

        public void Dispose()
        {
        }
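
        /// <summary>
        /// Reduce function that merges the partial means and losses produced by the slave
        /// tasks into a single set of aggregated cluster means and a total loss.
        /// </summary>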
        public class AggregateMeans : IReduceFunction<ProcessedResults>
        {
            [Inject]
            public AggregateMeans()
            {
            }
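
            /// <summary>
            /// Sums the losses, groups the partial means by cluster label, and computes one
            /// aggregated mean per cluster.
            /// </summary>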
            public ProcessedResults Reduce(IEnumerable<ProcessedResults> elements)
            {
                List<PartialMean> aggregatedMeans = new List<PartialMean>();
                List<PartialMean> totalList = new List<PartialMean>();
                float aggregatedLoss = 0;

                foreach (var element in elements)
                {
                    totalList.AddRange(element.Means);
                    aggregatedLoss += element.Loss;
                }

                // we infer the value of K from the labeled data
                int clustersNum = totalList.Max(p => p.Mean.Label) + 1;
                for (int i = 0; i < clustersNum; i++)
                {
                    List<PartialMean> means = totalList.Where(m => m.Mean.Label == i).ToList();
                    aggregatedMeans.Add(new PartialMean(PartialMean.AggreatedMean(means), means.Count));
                }

                ProcessedResults returnMeans = new ProcessedResults(aggregatedMeans, aggregatedLoss);
                Logger.Log(Level.Info, "The true means aggregated by the reduce function: " + returnMeans);
                return returnMeans;
            }
        }
    }
}