| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.flink.ml.clustering.kmeans; |
| |
| import org.apache.flink.api.common.functions.MapFunction; |
| import org.apache.flink.api.common.functions.MapPartitionFunction; |
| import org.apache.flink.api.common.functions.ReduceFunction; |
| import org.apache.flink.api.common.state.ListState; |
| import org.apache.flink.api.common.state.ListStateDescriptor; |
| import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo; |
| import org.apache.flink.api.common.typeinfo.TypeInformation; |
| import org.apache.flink.api.java.tuple.Tuple2; |
| import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; |
| import org.apache.flink.api.java.typeutils.TupleTypeInfo; |
| import org.apache.flink.iteration.DataStreamList; |
| import org.apache.flink.iteration.IterationBody; |
| import org.apache.flink.iteration.IterationBodyResult; |
| import org.apache.flink.iteration.IterationConfig; |
| import org.apache.flink.iteration.IterationListener; |
| import org.apache.flink.iteration.Iterations; |
| import org.apache.flink.iteration.ReplayableDataStreamList; |
| import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; |
| import org.apache.flink.iteration.datacache.nonkeyed.OperatorScopeManagedMemoryManager; |
| import org.apache.flink.iteration.operator.OperatorStateUtils; |
| import org.apache.flink.ml.api.Estimator; |
| import org.apache.flink.ml.common.datastream.DataStreamUtils; |
| import org.apache.flink.ml.common.distance.DistanceMeasure; |
| import org.apache.flink.ml.common.iteration.ForwardInputsOfLastRound; |
| import org.apache.flink.ml.common.iteration.TerminateOnMaxIter; |
| import org.apache.flink.ml.linalg.BLAS; |
| import org.apache.flink.ml.linalg.DenseVector; |
| import org.apache.flink.ml.linalg.Vector; |
| import org.apache.flink.ml.linalg.VectorWithNorm; |
| import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; |
| import org.apache.flink.ml.linalg.typeinfo.VectorWithNormSerializer; |
| import org.apache.flink.ml.param.Param; |
| import org.apache.flink.ml.util.ParamUtils; |
| import org.apache.flink.ml.util.ReadWriteUtils; |
| import org.apache.flink.runtime.jobgraph.OperatorID; |
| import org.apache.flink.runtime.state.StateInitializationContext; |
| import org.apache.flink.runtime.state.StateSnapshotContext; |
| import org.apache.flink.streaming.api.datastream.DataStream; |
| import org.apache.flink.streaming.api.operators.AbstractStreamOperator; |
| import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; |
| import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; |
| import org.apache.flink.streaming.runtime.tasks.StreamTask; |
| import org.apache.flink.table.api.Table; |
| import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; |
| import org.apache.flink.table.api.internal.TableImpl; |
| import org.apache.flink.util.Collector; |
| import org.apache.flink.util.Preconditions; |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Objects; |
| |
| import static org.apache.flink.iteration.utils.DataStreamUtils.setManagedMemoryWeight; |
| |
| /** |
| * An Estimator which implements the k-means clustering algorithm. |
| * |
| * <p>See https://en.wikipedia.org/wiki/K-means_clustering. |
| */ |
| public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMeans> { |
| private final Map<Param<?>, Object> paramMap = new HashMap<>(); |
| |
| public KMeans() { |
| ParamUtils.initializeMapWithDefaultValues(paramMap, this); |
| } |
| |
| @Override |
| public KMeansModel fit(Table... inputs) { |
| Preconditions.checkArgument(inputs.length == 1); |
| |
| StreamTableEnvironment tEnv = |
| (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); |
| DataStream<DenseVector> points = |
| tEnv.toDataStream(inputs[0]) |
| .map(row -> ((Vector) row.getField(getFeaturesCol())).toDense()); |
| |
| DataStream<DenseVector[]> initCentroids = selectRandomCentroids(points, getK(), getSeed()); |
| |
| IterationConfig config = |
| IterationConfig.newBuilder() |
| .setOperatorLifeCycle(IterationConfig.OperatorLifeCycle.ALL_ROUND) |
| .build(); |
| |
| IterationBody body = |
| new KMeansIterationBody( |
| getMaxIter(), DistanceMeasure.getInstance(getDistanceMeasure())); |
| |
| DataStream<KMeansModelData> finalModelData = |
| Iterations.iterateBoundedStreamsUntilTermination( |
| DataStreamList.of(initCentroids), |
| ReplayableDataStreamList.notReplay(points), |
| config, |
| body) |
| .get(0); |
| |
| Table finalModelDataTable = tEnv.fromDataStream(finalModelData); |
| KMeansModel model = new KMeansModel().setModelData(finalModelDataTable); |
| ParamUtils.updateExistingParams(model, paramMap); |
| return model; |
| } |
| |
| @Override |
| public void save(String path) throws IOException { |
| ReadWriteUtils.saveMetadata(this, path); |
| } |
| |
| public static KMeans load(StreamTableEnvironment tEnv, String path) throws IOException { |
| return ReadWriteUtils.loadStageParam(path); |
| } |
| |
| @Override |
| public Map<Param<?>, Object> getParamMap() { |
| return paramMap; |
| } |
| |
| private static class KMeansIterationBody implements IterationBody { |
| private final int maxIterationNum; |
| private final DistanceMeasure distanceMeasure; |
| |
| public KMeansIterationBody(int maxIterationNum, DistanceMeasure distanceMeasure) { |
| this.maxIterationNum = maxIterationNum; |
| this.distanceMeasure = distanceMeasure; |
| } |
| |
| @Override |
| public IterationBodyResult process( |
| DataStreamList variableStreams, DataStreamList dataStreams) { |
| DataStream<DenseVector[]> centroids = variableStreams.get(0); |
| DataStream<DenseVector> points = dataStreams.get(0); |
| |
| DataStream<Integer> terminationCriteria = |
| centroids.flatMap(new TerminateOnMaxIter(maxIterationNum)); |
| |
| DataStream<Tuple2<Integer[], DenseVector[]>> centroidIdAndPoints = |
| points.connect(centroids.broadcast()) |
| .transform( |
| "CentroidsUpdateAccumulator", |
| new TupleTypeInfo<>( |
| BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO, |
| ObjectArrayTypeInfo.getInfoFor( |
| DenseVectorTypeInfo.INSTANCE)), |
| new CentroidsUpdateAccumulator(distanceMeasure)); |
| |
| setManagedMemoryWeight(centroidIdAndPoints, 100); |
| |
| int parallelism = centroidIdAndPoints.getParallelism(); |
| DataStream<KMeansModelData> newModelData = |
| centroidIdAndPoints |
| .countWindowAll(parallelism) |
| .reduce(new CentroidsUpdateReducer()) |
| .map(new ModelDataGenerator()); |
| |
| DataStream<DenseVector[]> newCentroids = |
| newModelData.map(x -> x.centroids).setParallelism(1); |
| |
| DataStream<KMeansModelData> finalModelData = |
| newModelData.flatMap(new ForwardInputsOfLastRound<>()); |
| |
| return new IterationBodyResult( |
| DataStreamList.of(newCentroids), |
| DataStreamList.of(finalModelData), |
| terminationCriteria); |
| } |
| } |
| |
| private static class CentroidsUpdateReducer |
| implements ReduceFunction<Tuple2<Integer[], DenseVector[]>> { |
| @Override |
| public Tuple2<Integer[], DenseVector[]> reduce( |
| Tuple2<Integer[], DenseVector[]> tuple2, Tuple2<Integer[], DenseVector[]> t1) |
| throws Exception { |
| for (int i = 0; i < tuple2.f0.length; i++) { |
| tuple2.f0[i] += t1.f0[i]; |
| BLAS.axpy(1.0, t1.f1[i], tuple2.f1[i]); |
| } |
| |
| return tuple2; |
| } |
| } |
| |
| private static class ModelDataGenerator |
| implements MapFunction<Tuple2<Integer[], DenseVector[]>, KMeansModelData> { |
| @Override |
| public KMeansModelData map(Tuple2<Integer[], DenseVector[]> tuple2) throws Exception { |
| double[] weights = new double[tuple2.f0.length]; |
| for (int i = 0; i < tuple2.f0.length; i++) { |
| BLAS.scal(1.0 / tuple2.f0[i], tuple2.f1[i]); |
| weights[i] = tuple2.f0[i]; |
| } |
| |
| return new KMeansModelData(tuple2.f1, new DenseVector(weights)); |
| } |
| } |
| |
| private static class CentroidsUpdateAccumulator |
| extends AbstractStreamOperator<Tuple2<Integer[], DenseVector[]>> |
| implements TwoInputStreamOperator< |
| DenseVector, DenseVector[], Tuple2<Integer[], DenseVector[]>>, |
| IterationListener<Tuple2<Integer[], DenseVector[]>> { |
| |
| private final DistanceMeasure distanceMeasure; |
| |
| private ListState<DenseVector[]> centroids; |
| |
| private ListStateWithCache<VectorWithNorm> points; |
| |
| public CentroidsUpdateAccumulator(DistanceMeasure distanceMeasure) { |
| super(); |
| this.distanceMeasure = distanceMeasure; |
| } |
| |
| @Override |
| public void initializeState(StateInitializationContext context) throws Exception { |
| super.initializeState(context); |
| |
| TypeInformation<DenseVector[]> type = |
| ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE); |
| |
| centroids = |
| context.getOperatorStateStore() |
| .getListState(new ListStateDescriptor<>("centroids", type)); |
| |
| final StreamTask<?, ?> containingTask = getContainingTask(); |
| final OperatorID operatorID = config.getOperatorID(); |
| final OperatorScopeManagedMemoryManager manager = |
| OperatorScopeManagedMemoryManager.getOrCreate(containingTask, operatorID); |
| final String stateKey = "points-state"; |
| manager.register(stateKey, 1.); |
| points = |
| new ListStateWithCache<>( |
| new VectorWithNormSerializer(), stateKey, context, this); |
| } |
| |
| @Override |
| public void snapshotState(StateSnapshotContext context) throws Exception { |
| super.snapshotState(context); |
| points.snapshotState(context); |
| } |
| |
| @Override |
| public void processElement1(StreamRecord<DenseVector> streamRecord) throws Exception { |
| points.add(new VectorWithNorm(streamRecord.getValue())); |
| } |
| |
| @Override |
| public void processElement2(StreamRecord<DenseVector[]> streamRecord) throws Exception { |
| Preconditions.checkState(!centroids.get().iterator().hasNext()); |
| centroids.add(streamRecord.getValue()); |
| } |
| |
| @Override |
| public void onEpochWatermarkIncremented( |
| int epochWatermark, |
| Context context, |
| Collector<Tuple2<Integer[], DenseVector[]>> out) |
| throws Exception { |
| DenseVector[] centroidValues = |
| Objects.requireNonNull( |
| OperatorStateUtils.getUniqueElement(centroids, "centroids") |
| .orElse(null)); |
| |
| VectorWithNorm[] centroidsWithNorm = new VectorWithNorm[centroidValues.length]; |
| for (int i = 0; i < centroidsWithNorm.length; i++) { |
| centroidsWithNorm[i] = new VectorWithNorm(centroidValues[i]); |
| } |
| |
| DenseVector[] newCentroids = new DenseVector[centroidValues.length]; |
| Integer[] counts = new Integer[centroidValues.length]; |
| Arrays.fill(counts, 0); |
| for (int i = 0; i < centroidValues.length; i++) { |
| newCentroids[i] = new DenseVector(centroidValues[i].size()); |
| } |
| |
| for (VectorWithNorm point : points.get()) { |
| int closestCentroidId = distanceMeasure.findClosest(centroidsWithNorm, point); |
| BLAS.axpy(1.0, point.vector, newCentroids[closestCentroidId]); |
| counts[closestCentroidId]++; |
| } |
| |
| output.collect(new StreamRecord<>(Tuple2.of(counts, newCentroids))); |
| |
| centroids.clear(); |
| } |
| |
| @Override |
| public void onIterationTerminated( |
| Context context, Collector<Tuple2<Integer[], DenseVector[]>> collector) { |
| centroids.clear(); |
| points.clear(); |
| } |
| } |
| |
| public static DataStream<DenseVector[]> selectRandomCentroids( |
| DataStream<DenseVector> data, int k, long seed) { |
| DataStream<DenseVector[]> resultStream = |
| DataStreamUtils.mapPartition( |
| DataStreamUtils.sample(data, k, seed), |
| new MapPartitionFunction<DenseVector, DenseVector[]>() { |
| @Override |
| public void mapPartition( |
| Iterable<DenseVector> iterable, |
| Collector<DenseVector[]> collector) { |
| List<DenseVector> list = new ArrayList<>(); |
| iterable.iterator().forEachRemaining(list::add); |
| collector.collect(list.toArray(new DenseVector[0])); |
| } |
| }); |
| resultStream.getTransformation().setParallelism(1); |
| return resultStream; |
| } |
| } |