|  | // Licensed to the Apache Software Foundation (ASF) under one | 
|  | // or more contributor license agreements.  See the NOTICE file | 
|  | // distributed with this work for additional information | 
|  | // regarding copyright ownership.  The ASF licenses this file | 
|  | // to you under the Apache License, Version 2.0 (the | 
|  | // "License"); you may not use this file except in compliance | 
|  | // with the License.  You may obtain a copy of the License at | 
|  | // | 
|  | //   http://www.apache.org/licenses/LICENSE-2.0 | 
|  | // | 
|  | // Unless required by applicable law or agreed to in writing, | 
|  | // software distributed under the License is distributed on an | 
|  | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | 
|  | // KIND, either express or implied.  See the License for the | 
|  | // specific language governing permissions and limitations | 
|  | // under the License. | 
|  |  | 
|  | using System.Collections.Generic; | 
|  | using System.IO; | 
|  | using System.Linq; | 
|  | using Org.Apache.REEF.Tang.Annotations; | 
|  |  | 
|  | namespace Org.Apache.REEF.Examples.MachineLearning.KMeans | 
|  | { | 
|  | /// <summary> | 
|  | /// This is the legacy KmeansTask implemented when group communications are not available | 
|  | /// It is still being used for plain KMeans without REEF, we probably want to refactor it later | 
|  | /// to reflect that | 
|  | /// </summary> | 
|  | public class LegacyKMeansTask | 
|  | { | 
|  | private readonly int _clustersNum; | 
|  | private readonly DataPartitionCache _dataPartition; | 
|  | private readonly string _kMeansExecutionDirectory; | 
|  |  | 
|  | private Centroids _centroids; | 
|  | private List<PartialMean> _partialMeans; | 
|  |  | 
|  | [Inject] | 
|  | public LegacyKMeansTask( | 
|  | DataPartitionCache dataPartition, | 
|  | [Parameter(Value = typeof(KMeansConfiguratioinOptions.K))] int clustersNumber, | 
|  | [Parameter(Value = typeof(KMeansConfiguratioinOptions.ExecutionDirectory))] string executionDirectory) | 
|  | { | 
|  | _dataPartition = dataPartition; | 
|  | _clustersNum = clustersNumber; | 
|  | _kMeansExecutionDirectory = executionDirectory; | 
|  | if (_centroids == null) | 
|  | { | 
|  | string centroidFile = Path.Combine(_kMeansExecutionDirectory, Constants.CentroidsFile); | 
|  | _centroids = new Centroids(DataPartitionCache.ReadDataFile(centroidFile)); | 
|  | } | 
|  | } | 
|  |  | 
|  | public static float ComputeLossFunction(List<DataVector> centroids, List<DataVector> labeledData) | 
|  | { | 
|  | float d = 0; | 
|  | for (int i = 0; i < centroids.Count; i++) | 
|  | { | 
|  | DataVector centroid = centroids[i]; | 
|  | List<DataVector> slice = labeledData.Where(v => v.Label == centroid.Label).ToList(); | 
|  | d += centroid.DistanceTo(slice); | 
|  | } | 
|  | return d; | 
|  | } | 
|  |  | 
|  | public byte[] CallWithWritingToFileSystem(byte[] memento) | 
|  | { | 
|  | string centroidFile = Path.Combine(_kMeansExecutionDirectory, Constants.CentroidsFile); | 
|  | _centroids = new Centroids(DataPartitionCache.ReadDataFile(centroidFile)); | 
|  |  | 
|  | _dataPartition.LabelData(_centroids); | 
|  | _partialMeans = ComputePartialMeans(); | 
|  |  | 
|  | // should be replaced with Group Communication | 
|  | using (StreamWriter writer = new StreamWriter( | 
|  | File.OpenWrite(Path.Combine(_kMeansExecutionDirectory, Constants.DataDirectory, Constants.PartialMeanFilePrefix + _dataPartition.Partition)))) | 
|  | { | 
|  | for (int i = 0; i < _partialMeans.Count; i++) | 
|  | { | 
|  | writer.WriteLine(_partialMeans[i].ToString()); | 
|  | } | 
|  | writer.Close(); | 
|  | } | 
|  |  | 
|  | return null; | 
|  | } | 
|  |  | 
|  | public List<PartialMean> ComputePartialMeans() | 
|  | { | 
|  | List<PartialMean> partialMeans = new PartialMean[_clustersNum].ToList(); | 
|  | for (int i = 0; i < _clustersNum; i++) | 
|  | { | 
|  | List<DataVector> slices = _dataPartition.DataVectors.Where(d => d.Label == i).ToList(); | 
|  | DataVector average = new DataVector(_dataPartition.DataVectors[0].Dimension); | 
|  |  | 
|  | if (slices.Count > 1) | 
|  | { | 
|  | average = DataVector.Mean(slices); | 
|  | } | 
|  | average.Label = i; | 
|  | partialMeans[i] = new PartialMean(average, slices.Count); | 
|  | } | 
|  | return partialMeans; | 
|  | } | 
|  |  | 
|  | public void Dispose() | 
|  | { | 
|  | } | 
|  | } | 
|  | } |