| // Licensed to the Apache Software Foundation (ASF) under one or more |
| // contributor license agreements. See the NOTICE file distributed with |
| // this work for additional information regarding copyright ownership. |
| // The ASF licenses this file to You under the Apache License, Version 2.0 |
| // (the "License"); you may not use this file except in compliance with |
| // the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| = Split the dataset into test and train datasets |
| |
| Data splitting divides the data stored in a cache into two parts: the training part, which is used to train the model, and the test part, which is used to estimate the model quality. |
| |
| All `fit()` methods have a parameter that accepts a filter condition to apply to the cache. |
| |
| [NOTE] |
| ==== |
| Because dataset operations are distributed and lazy, the dataset split is a lazy operation too. It is defined as a filter condition that is applied to the initial cache to form both the train and test datasets. |
| ==== |
| |
| In the example below, the model is trained on only 75% of the initial dataset. The filter parameter value is the result of `split.getTrainFilter()`, which accepts or rejects each row of the initial dataset to decide whether it is used during training. |
| |
| |
| [source, java] |
| ---- |
| // Define the cache. |
| IgniteCache<Integer, Vector> dataCache = ...; |
| |
| // Define the fraction of the initial dataset that forms the train sub-set (0.75 = 75%). |
| TrainTestSplit<Integer, Vector> split = new TrainTestDatasetSplitter<>().split(0.75); |
| |
| // Train the model only on the rows accepted by the train filter. |
| IgniteModel<Vector, Double> mdl = trainer |
|     .fit(ignite, dataCache, split.getTrainFilter(), vectorizer); |
| ---- |
| |
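The idea of expressing a split as a pair of complementary filters can be illustrated without Ignite. The following is a minimal standalone sketch, not the library's implementation: the class `SplitFilterSketch` and its methods `toUnitInterval` and `rangeFilter` are hypothetical names, and hashing the key with SHA-256 is an assumption chosen to make the row-to-subset assignment reproducible. Each key is mapped to a point in `[0, 1)`; the train filter accepts points below 0.75 and the test filter accepts the rest, so no data is copied and every row belongs to exactly one sub-set.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.function.Predicate;

/** Hypothetical illustration of a filter-based split; not an Ignite class. */
class SplitFilterSketch {
    /** Maps a key to a reproducible point in [0, 1) using the first 8 bytes of its SHA-256 digest. */
    public static double toUnitInterval(int key) {
        try {
            MessageDigest md = MessageDigest.getInstance("SHA-256");
            byte[] digest = md.digest(Integer.toString(key).getBytes(StandardCharsets.UTF_8));
            long bits = ByteBuffer.wrap(digest).getLong();

            // Keep the top 53 bits so the result is an exact double in [0, 1).
            return (bits >>> 11) * 0x1.0p-53;
        }
        catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
    }

    /** Returns a filter that accepts keys whose mapped point falls into [from, to). */
    public static Predicate<Integer> rangeFilter(double from, double to) {
        return key -> {
            double p = toUnitInterval(key);
            return p >= from && p < to;
        };
    }

    public static void main(String[] args) {
        // Two complementary filters over the same keys: 75% train, 25% test.
        Predicate<Integer> trainFilter = rangeFilter(0.0, 0.75);
        Predicate<Integer> testFilter = rangeFilter(0.75, 1.0);

        int train = 0;
        int test = 0;
        for (int key = 0; key < 10_000; key++) {
            if (trainFilter.test(key)) train++;
            if (testFilter.test(key)) test++;
        }

        System.out.println("train=" + train + ", test=" + test);
    }
}
```

Because the assignment depends only on the key, both filters can be evaluated lazily and independently on any node, and they give the same answer for a given row every time.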
| |
| `split.getTestFilter()` can be used to validate the model on the test data. |
| The example below works with the cache directly: it prints the predicted and actual regression values for the test sub-set of the initial dataset. |
| |
| |
| [source, java] |
| ---- |
| // Define the cache query and set the filter. |
| ScanQuery<Integer, Vector> qry = new ScanQuery<>(); |
| qry.setFilter(split.getTestFilter()); |
| |
| |
| try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(qry)) { |
|     for (Cache.Entry<Integer, Vector> observation : observations) { |
|         Vector val = observation.getValue(); |
|         Vector inputs = val.copyOfRange(1, val.size()); |
|         double groundTruth = val.get(0); |
| |
|         double prediction = mdl.predict(inputs); |
| |
|         System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); |
|     } |
| } |
| ---- |
| |
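Printed pairs can also be reduced to a single quality estimate. The sketch below is one possible way to do that, computing the mean squared error over test rows. It is a standalone illustration that does not use Ignite: the class `TestSetMseSketch` and the method `meanSquaredError` are hypothetical names, plain `double[]` rows stand in for `Vector`, and a `ToDoubleFunction` stands in for the trained model. As in the example above, the label is assumed to be at index 0 and the features to start at index 1.

```java
import java.util.Arrays;
import java.util.function.ToDoubleFunction;

/** Hypothetical illustration of scoring a model on test rows; not an Ignite class. */
class TestSetMseSketch {
    /**
     * Computes the mean squared error of the model over the given rows.
     * Each row holds the label at index 0 and the features from index 1 onward.
     */
    public static double meanSquaredError(double[][] rows, ToDoubleFunction<double[]> model) {
        if (rows.length == 0)
            throw new IllegalArgumentException("At least one test row is required.");

        double sumSquaredErrors = 0.0;
        for (double[] row : rows) {
            double groundTruth = row[0];
            double[] inputs = Arrays.copyOfRange(row, 1, row.length);

            double error = model.applyAsDouble(inputs) - groundTruth;
            sumSquaredErrors += error * error;
        }

        return sumSquaredErrors / rows.length;
    }

    public static void main(String[] args) {
        // Two test rows: {label, feature}.
        double[][] testRows = {{3.0, 1.0}, {5.0, 2.0}};

        // Stand-in for a trained model: y = 2 * x + 1.
        ToDoubleFunction<double[]> model = inputs -> 2.0 * inputs[0] + 1.0;

        System.out.println("MSE = " + meanSquaredError(testRows, model));
    }
}
```

In the cache-based example, the same accumulation could be done inside the loop over `observations` instead of printing each pair.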
| |