blob: f97b76519e0e89910a074f4e00b6e46f64514a31 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2016 by Contributors
* \file caffe_fieldentry.h
* \brief Implement FieldEntry<caffe::LayerParameter>
* \author Haoran Wang
*/
#ifndef PLUGIN_CAFFE_CAFFE_FIELDENTRY_H_
#define PLUGIN_CAFFE_CAFFE_FIELDENTRY_H_
#include <caffe/proto/caffe.pb.h>
#include <dmlc/parameter.h>
#include <dmlc/base.h>
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dmlc/type_traits.h>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <cstddef>
#include <cstdlib>
#include <sstream>
#include <limits>
#include <map>
#include <set>
#include <typeinfo>
#include <string>
#include <vector>
#include <algorithm>
#include <utility>
#include <caffe/util/io.hpp>
namespace dmlc {
namespace parameter {
// Specialization of FieldEntry for caffe::LayerParameter fields
template<>
class FieldEntry<caffe::LayerParameter>
    : public FieldEntryBase<FieldEntry<caffe::LayerParameter>, caffe::LayerParameter> {
 public:
  // parent class
  typedef FieldEntryBase<FieldEntry<caffe::LayerParameter>, caffe::LayerParameter> Parent;

  /*!
   * \brief Parse protobuf text-format content into a message.
   * \param text text-format (prototxt) content
   * \param proto message to populate
   * \return the result of google::protobuf::TextFormat::ParseFromString
   */
  bool ReadProtoFromTextContent(const std::string& text,
                                ::google::protobuf::Message* proto) const {
    return google::protobuf::TextFormat::ParseFromString(text, proto);
  }

  /*!
   * \brief Customized set method for LayerParameter.
   * \param head pointer to the parameter struct that owns this field
   * \param value text-format caffe NetParameter containing exactly one layer
   */
  virtual void Set(void *head, const std::string &value) const {
    // Assign by value; no heap allocation is needed (and none is leaked).
    this->Get(head) = ParseSingleLayer(value);
  }

  /*!
   * \brief Print a LayerParameter as text-format NetParameter wrapping that single layer.
   *
   * The output uses the same format that Set() accepts, so a printed value can be
   * parsed back.
   * \param os output stream
   * \param value layer to print (received by value, so it may be consumed here)
   */
  virtual void PrintValue(std::ostream &os, caffe::LayerParameter value) const {  // NOLINT(*)
    caffe::NetParameter net_param;
    // `value` is already our private copy; swap it into the wrapper instead of copying again.
    net_param.add_layer()->Swap(&value);
    std::string text;
    google::protobuf::TextFormat::PrintToString(net_param, &text);
    os << text;
  }

  /*!
   * \brief Print the default value wrapped in single quotes.
   * \param os output stream
   */
  virtual void PrintDefaultValueString(std::ostream &os) const {  // NOLINT(*)
    // Works on a copy rather than lending the member object to a RepeatedPtrField,
    // which would otherwise own (and could delete) a non-heap object.
    os << '\'';
    PrintValue(os, default_value_);
    os << '\'';
  }

  /*!
   * \brief Override set_default to accept a text-format caffe NetParameter.
   * \param value text-format NetParameter containing exactly one layer
   * \return reference to self to allow chaining
   */
  inline FieldEntry<caffe::LayerParameter> &set_default(const std::string &value) {
    default_value_ = ParseSingleLayer(value);
    has_default_ = true;
    // return self to allow chaining
    return this->self();
  }

 private:
  /*!
   * \brief Parse a text-format NetParameter and return its only layer.
   *
   * Fails via CHECK if the text cannot be parsed or does not contain exactly one layer.
   * \param value text-format NetParameter
   * \return copy of the single layer
   */
  caffe::LayerParameter ParseSingleLayer(const std::string &value) const {
    caffe::NetParameter net_param;
    CHECK(ReadProtoFromTextContent(value, &net_param))
        << "Caffe Net Prototxt: " << value << " Initialized Failed";
    CHECK_EQ(net_param.layer_size(), 1)
        << "Prototxt " << value << " must contain exactly one layer";
    return net_param.layer(0);
  }
};
} // namespace parameter
} // namespace dmlc
#endif // PLUGIN_CAFFE_CAFFE_FIELDENTRY_H_