| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
#include "singa/model/metric.h"

#include <algorithm>
#include <functional>
#include <utility>
#include <vector>
| namespace singa { |
| |
| Tensor Accuracy::Match(const Tensor& predict, const vector<int>& target) { |
| Tensor prediction(predict.shape()); |
| prediction.CopyData(predict); |
| size_t batchsize = target.size(); |
| size_t nb_classes = prediction.Size() / batchsize; |
| // each row of prediction is the prob distribution for one sample |
| CHECK_EQ(prediction.shape().at(0), batchsize); |
| // TODO(wangwei) CloneToDevice(host); |
| const float* prob = prediction.data<float>(); |
| float* score = new float[batchsize]; |
| memset(score, 0, batchsize * sizeof(float)); |
| for (size_t b = 0; b < batchsize; b++) { |
| vector<std::pair<float, int>> prob_class; |
| for (size_t c = 0; c < nb_classes; c++) { |
| prob_class.push_back(std::make_pair(prob[b * nb_classes + c], c)); |
| } |
| std::partial_sort(prob_class.begin(), prob_class.begin() + top_k_, |
| prob_class.end(), std::greater<std::pair<float, int>>()); |
| |
| for (size_t k = 0; k < top_k_; k++) |
| if (prob_class.at(k).second == target.at(b)) score[b] = 1; |
| } |
| Tensor ret(Shape{batchsize}); |
| ret.CopyDataFromHostPtr(score, batchsize); |
| delete [] score; |
| return ret; |
| } |
| |
| // TODO(wangwei) consider multi-label cases, where target is of shape |
| // nb_samples * nb_classes |
| Tensor Accuracy::Forward(const Tensor& prediction, const Tensor& t) { |
| Tensor target(t.shape(), t.data_type()); |
| target.CopyData(t); |
| vector<int> target_vec; |
| // TODO(wangwei) copy target to host. |
| const int* target_value = target.data<int>(); |
| for (size_t i = 0; i < target.Size(); i++) |
| target_vec.push_back(target_value[i]); |
| return Match(prediction, target_vec); |
| } |
| |
| } // namespace singa |