blob: 1aaa3a3fd29c5b599b8cf3ef839f8b59ed421acc [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Org.Apache.REEF.Tang.Annotations;
namespace Org.Apache.REEF.Examples.MachineLearning.KMeans
{
/// <summary>
/// This is the legacy KmeansTask implemented when group communications are not available
/// It is still being used for plain KMeans without REEF, we probably want to refactor it later
/// to reflect that
/// </summary>
public class LegacyKMeansTask
{
private readonly int _clustersNum;
private readonly DataPartitionCache _dataPartition;
private readonly string _kMeansExecutionDirectory;
private Centroids _centroids;
private List<PartialMean> _partialMeans;
[Inject]
public LegacyKMeansTask(
DataPartitionCache dataPartition,
[Parameter(Value = typeof(KMeansConfiguratioinOptions.K))] int clustersNumber,
[Parameter(Value = typeof(KMeansConfiguratioinOptions.ExecutionDirectory))] string executionDirectory)
{
_dataPartition = dataPartition;
_clustersNum = clustersNumber;
_kMeansExecutionDirectory = executionDirectory;
if (_centroids == null)
{
string centroidFile = Path.Combine(_kMeansExecutionDirectory, Constants.CentroidsFile);
_centroids = new Centroids(DataPartitionCache.ReadDataFile(centroidFile));
}
}
public static float ComputeLossFunction(List<DataVector> centroids, List<DataVector> labeledData)
{
float d = 0;
for (int i = 0; i < centroids.Count; i++)
{
DataVector centroid = centroids[i];
List<DataVector> slice = labeledData.Where(v => v.Label == centroid.Label).ToList();
d += centroid.DistanceTo(slice);
}
return d;
}
public byte[] CallWithWritingToFileSystem(byte[] memento)
{
string centroidFile = Path.Combine(_kMeansExecutionDirectory, Constants.CentroidsFile);
_centroids = new Centroids(DataPartitionCache.ReadDataFile(centroidFile));
_dataPartition.LabelData(_centroids);
_partialMeans = ComputePartialMeans();
// should be replaced with Group Communication
using (StreamWriter writer = new StreamWriter(
File.OpenWrite(Path.Combine(_kMeansExecutionDirectory, Constants.DataDirectory, Constants.PartialMeanFilePrefix + _dataPartition.Partition))))
{
for (int i = 0; i < _partialMeans.Count; i++)
{
writer.WriteLine(_partialMeans[i].ToString());
}
writer.Close();
}
return null;
}
public List<PartialMean> ComputePartialMeans()
{
List<PartialMean> partialMeans = new PartialMean[_clustersNum].ToList();
for (int i = 0; i < _clustersNum; i++)
{
List<DataVector> slices = _dataPartition.DataVectors.Where(d => d.Label == i).ToList();
DataVector average = new DataVector(_dataPartition.DataVectors[0].Dimension);
if (slices.Count > 1)
{
average = DataVector.Mean(slices);
}
average.Label = i;
partialMeans[i] = new PartialMean(average, slices.Count);
}
return partialMeans;
}
public void Dispose()
{
}
}
}