From 4f250248b62ee6592a679060fc12bb13778f70cf Mon Sep 17 00:00:00 2001 From: KUMAZAKI Hiroki Date: Thu, 12 Jun 2014 21:18:10 +0900 Subject: [PATCH] export_model interface of driver::classifier --- jubatus/core/anomaly/anomaly_factory.cpp | 8 +- jubatus/core/classifier/arow.cpp | 18 +- jubatus/core/classifier/arow.hpp | 6 +- jubatus/core/classifier/classifier_base.hpp | 14 ++ jubatus/core/classifier/classifier_config.cpp | 69 ++++++ jubatus/core/classifier/classifier_config.hpp | 83 ++++++- .../core/classifier/classifier_factory.cpp | 226 ++++++------------ .../core/classifier/classifier_factory.hpp | 5 +- jubatus/core/classifier/classifier_test.cpp | 2 +- .../core/classifier/confidence_weighted.cpp | 20 +- .../core/classifier/confidence_weighted.hpp | 8 +- jubatus/core/classifier/linear_classifier.cpp | 24 +- jubatus/core/classifier/linear_classifier.hpp | 8 +- .../nearest_neighbor_classifier.cpp | 13 + .../nearest_neighbor_classifier.hpp | 6 +- jubatus/core/classifier/normal_herd.cpp | 16 +- jubatus/core/classifier/normal_herd.hpp | 8 +- .../core/classifier/passive_aggressive.cpp | 5 +- .../core/classifier/passive_aggressive.hpp | 2 +- .../core/classifier/passive_aggressive_1.cpp | 17 +- .../core/classifier/passive_aggressive_1.hpp | 9 +- .../core/classifier/passive_aggressive_2.cpp | 16 +- .../core/classifier/passive_aggressive_2.hpp | 9 +- jubatus/core/classifier/perceptron.cpp | 5 +- jubatus/core/classifier/perceptron.hpp | 2 +- jubatus/core/classifier/wscript | 1 + jubatus/core/common/byte_buffer.hpp | 7 + jubatus/core/common/export_model.hpp | 36 +++ jubatus/core/common/key_manager.hpp | 6 + jubatus/core/driver/classifier.cpp | 147 +++++++++--- jubatus/core/driver/classifier.hpp | 36 ++- jubatus/core/driver/recommender.hpp | 1 + jubatus/core/framework/mixable_helper.hpp | 4 + .../core/fv_converter/converter_config.hpp | 19 +- .../fv_converter/datum_to_fv_converter.cpp | 1 + .../fv_converter/datum_to_fv_converter.hpp | 4 + jubatus/core/fv_converter/weight_manager.hpp | 3 + .../core/recommender/recommender_factory.cpp | 8 +- jubatus/core/storage/local_storage.cpp | 4 + jubatus/core/storage/local_storage.hpp | 4 + .../core/storage/local_storage_mixture.cpp | 24 ++ .../core/storage/local_storage_mixture.hpp | 3 + jubatus/core/storage/storage_base.hpp | 3 + jubatus/core/unlearner/lru_unlearner.cpp | 38 ++- jubatus/core/unlearner/lru_unlearner.hpp | 15 +- jubatus/core/unlearner/random_unlearner.cpp | 41 +++- jubatus/core/unlearner/random_unlearner.hpp | 14 +- jubatus/core/unlearner/unlearner_base.hpp | 8 + jubatus/core/unlearner/unlearner_config.hpp | 47 ++++ jubatus/core/unlearner/unlearner_factory.cpp | 33 ++- jubatus/core/unlearner/unlearner_factory.hpp | 6 +- 51 files changed, 782 insertions(+), 330 deletions(-) create mode 100644 jubatus/core/classifier/classifier_config.cpp create mode 100644 jubatus/core/common/export_model.hpp create mode 100644 jubatus/core/unlearner/unlearner_config.hpp diff --git a/jubatus/core/anomaly/anomaly_factory.cpp b/jubatus/core/anomaly/anomaly_factory.cpp index 1fc1bc35..d2924fe9 100644 --- a/jubatus/core/anomaly/anomaly_factory.cpp +++ b/jubatus/core/anomaly/anomaly_factory.cpp @@ -24,6 +24,7 @@ #include "../common/exception.hpp" #include "../common/jsonconfig.hpp" #include "../nearest_neighbor/nearest_neighbor_factory.hpp" +#include "../unlearner/unlearner_config.hpp" #include "../unlearner/unlearner_factory.hpp" #include "../storage/column_table.hpp" #include "../recommender/recommender_factory.hpp" @@ -100,10 +101,11 @@ shared_ptr anomaly_factory::create_anomaly( << common::exception::error_message( "unlearner is set but unlearner_parameter is not found")); } + shared_ptr unl_conf( + unlearner::create_unlearner_config(*conf.unlearner, + *conf.unlearner_parameter)); jubatus::util::lang::shared_ptr unlearner( - unlearner::create_unlearner( - *conf.unlearner, - *conf.unlearner_parameter)); + unlearner::create_unlearner(unl_conf)); return shared_ptr( new light_lof(conf, id, nearest_neighbor_engine, unlearner)); } diff --git a/jubatus/core/classifier/arow.cpp b/jubatus/core/classifier/arow.cpp index 4719e923..2812b584 100644 --- a/jubatus/core/classifier/arow.cpp +++ b/jubatus/core/classifier/arow.cpp @@ -23,24 +23,20 @@ #include "jubatus/util/concurrent/lock.h" #include "classifier_util.hpp" #include "../common/exception.hpp" +#include "../storage/storage_base.hpp" using std::string; +using jubatus::core::storage_ptr; namespace jubatus { namespace core { namespace classifier { -arow::arow(storage_ptr storage) - : linear_classifier(storage) { -} - -arow::arow( - const classifier_config& config, - storage_ptr storage) - : linear_classifier(storage), - config_(config) { +arow::arow(float regularization_weight) + : linear_classifier(), + regularization_weight_(regularization_weight) { - if (!(0.f < config.regularization_weight)) { + if (!(0.f < regularization_weight_)) { throw JUBATUS_EXCEPTION( common::invalid_parameter("0.0 < regularization_weight")); } @@ -58,7 +54,7 @@ void arow::train(const common::sfv_t& sfv, const string& label) { return; } - float beta = 1.f / (variance + 1.f / config_.regularization_weight); + float beta = 1.f / (variance + 1.f / regularization_weight_); float alpha = (1.f - margin) * beta; // max(0, 1 - margin) = 1 - margin update(sfv, alpha, beta, label, incorrect_label); } diff --git a/jubatus/core/classifier/arow.hpp b/jubatus/core/classifier/arow.hpp index 4c74f3c0..321dcc6c 100644 --- a/jubatus/core/classifier/arow.hpp +++ b/jubatus/core/classifier/arow.hpp @@ -24,11 +24,11 @@ namespace jubatus { namespace core { namespace classifier { +struct classifier_parameter; class arow : public linear_classifier { public: - explicit arow(storage_ptr storage); - arow(const classifier_config& config, storage_ptr storage); + arow(float regularization_weight); void train(const common::sfv_t& fv, const std::string& label); std::string name() const; private: @@ -38,7 +38,7 @@ class arow : public linear_classifier { float beta, const std::string& pos_label, const std::string& neg_label); - classifier_config config_; + float regularization_weight_; }; } // namespace classifier diff --git a/jubatus/core/classifier/classifier_base.hpp b/jubatus/core/classifier/classifier_base.hpp index 826a5bb4..28e1eac3 100644 --- a/jubatus/core/classifier/classifier_base.hpp +++ b/jubatus/core/classifier/classifier_base.hpp @@ -29,8 +29,17 @@ #include "../unlearner/unlearner_base.hpp" #include "classifier_type.hpp" +namespace msgpack { + +} namespace jubatus { namespace core { +namespace common { +class byte_buffer; +} +namespace framework { +class jubatus_packer; +} namespace classifier { class classifier_base { @@ -48,6 +57,8 @@ class classifier_base { virtual void set_label_unlearner( jubatus::util::lang::shared_ptr label_unlearner) = 0; + virtual jubatus::util::lang::shared_ptr + get_label_unlearner() const = 0; virtual bool delete_label(const std::string& label) = 0; virtual std::vector get_labels() const = 0; @@ -61,6 +72,9 @@ class classifier_base { virtual void unpack(msgpack::object o) = 0; virtual void clear() = 0; + void import_model(msgpack::object& src); + void export_model(framework::packer& dst) const; + virtual framework::mixable* get_mixable() = 0; }; diff --git a/jubatus/core/classifier/classifier_config.cpp b/jubatus/core/classifier/classifier_config.cpp new file mode 100644 index 00000000..ff3ae5d7 --- /dev/null +++ b/jubatus/core/classifier/classifier_config.cpp @@ -0,0 +1,69 @@ +// Jubatus: Online machine learning framework for distributed environment +// Copyright (C) 2012 Preferred Networks and Nippon Telegraph and Telephone Corporation. +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License version 2.1 as published by the Free Software Foundation. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + +#include "classifier_config.hpp" + +#include "jubatus/util/data/serialization.h" +#include "jubatus/util/data/optional.h" +#include "jubatus/util/lang/shared_ptr.h" +#include "../common/jsonconfig.hpp" +#include "../unlearner/unlearner_config.hpp" +#include "../storage/column_table.hpp" +#include "../nearest_neighbor/nearest_neighbor_base.hpp" + +using jubatus::util::lang::shared_ptr; +using jubatus::core::common::jsonconfig::config_cast_check; + +namespace jubatus { +namespace core { +namespace classifier { +classifier_config::classifier_config(const std::string& method, + const common::jsonconfig::config& param) { + if (method == "perceptron" || + method == "PA" || method == "passive_aggressive") { + if (param.type() != jubatus::util::text::json::json::Null) { + throw JUBATUS_EXCEPTION( + common::config_exception() << common::exception::error_message( + "parameter block is specified in config")); + } + conf_.reset(new detail::unlearning_classifier_config(method, param)); + } else if (method == "PA1" || method == "passive_aggressive_1" || + method == "PA2" || method == "passive_aggressive_2" || + method == "CW" || method == "confidence_weighted" || + method == "AROW" || method == "arow" || + method == "NHERD" || method == "normal_herd") { + if (param.type() == jubatus::util::text::json::json::Null) { + throw JUBATUS_EXCEPTION( + common::config_exception() << common::exception::error_message( + "parameter block is not specified in config")); + } + conf_.reset(new detail::unlearning_classifier_config(method, param)); + } else if (method == "NN" || method == "nearest_neighbor") { + if (param.type() == jubatus::util::text::json::json::Null) { + throw JUBATUS_EXCEPTION( + common::config_exception() << common::exception::error_message( + "parameter block is not specified in config")); + } + conf_.reset(new detail::nearest_neighbor_classifier_config(method, param)); + } else { + throw JUBATUS_EXCEPTION( + common::unsupported_method("classifier(" + method + ")")); + } +} + +} // namespace classifier +} // namespace core +} // namespace jubatus diff --git a/jubatus/core/classifier/classifier_config.hpp b/jubatus/core/classifier/classifier_config.hpp index 6e81ce30..f511c8c1 100644 --- a/jubatus/core/classifier/classifier_config.hpp +++ b/jubatus/core/classifier/classifier_config.hpp @@ -18,21 +18,96 @@ #define JUBATUS_CORE_CLASSIFIER_CLASSIFIER_CONFIG_HPP_ #include "jubatus/util/data/serialization.h" +#include "jubatus/util/data/optional.h" +#include "jubatus/util/lang/shared_ptr.h" +#include "../unlearner/unlearner_config.hpp" namespace jubatus { namespace core { namespace classifier { -struct classifier_config { - classifier_config() - : regularization_weight(1.0f) { +struct classifier_config_base { + std::string method_; + virtual ~classifier_config_base() {} + classifier_config_base(const std::string& method) + : method_(method) { + } + + template + void serialize(Ar& ar) { + ar & JUBA_NAMED_MEMBER("method", method_); } +}; +namespace detail { +struct classifier_parameter : public classifier_config_base { + classifier_parameter(const std::string& method) + : classifier_config_base(method), + regularization_weight(1.0f) { + } float regularization_weight; template void serialize(Ar& ar) { - ar & JUBA_NAMED_MEMBER("regularization_weight", regularization_weight); + classifier_config_base::serialize(ar); + ar & JUBA_MEMBER(regularization_weight); + } +}; + +struct unlearning_classifier_config : public classifier_parameter { + unlearning_classifier_config(const std::string& method, + const common::jsonconfig::config& param) + : classifier_parameter(method) { + // TODO + if (param.type() == jubatus::util::text::json::json::Null) { + throw JUBATUS_EXCEPTION( + common::config_exception() << common::exception::error_message( + "parameter block is not specified in config")); + } + } + util::lang::shared_ptr unlearner_config_; + + template + void serialize(Ar& ar) { + classifier_config_base::serialize(ar); + unlearner_config_->serialize(ar); + } +}; + +struct nearest_neighbor_classifier_config : public classifier_config_base { + std::string method; + int nearest_neighbor_num; + float local_sensitivity; + util::lang::shared_ptr unlearner_config_; + nearest_neighbor_classifier_config(const std::string& method, + const common::jsonconfig::config& param) + : classifier_config_base(method) { + // TODO + } + + template + void serialize(Ar& ar) { + classifier_config_base::serialize(ar); + ar & JUBA_MEMBER(method) + & JUBA_MEMBER(nearest_neighbor_num) + & JUBA_MEMBER(local_sensitivity); + if (unlearner_config_) { + unlearner_config_->serialize(ar); + } + } +}; + +} // namespace detail + +struct classifier_config { + util::lang::shared_ptr conf_; + classifier_config(const std::string& method, + const common::jsonconfig::config& param); + classifier_config() { + } + template + void serialize(Ar& ar) { + conf_->serialize(ar); } }; diff --git a/jubatus/core/classifier/classifier_factory.cpp b/jubatus/core/classifier/classifier_factory.cpp index 628cbdd9..18041339 100644 --- a/jubatus/core/classifier/classifier_factory.cpp +++ b/jubatus/core/classifier/classifier_factory.cpp @@ -7,183 +7,107 @@ // // This library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public // License along with this library; if not, write to the Free Software -// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA #include "classifier_factory.hpp" #include #include "classifier.hpp" +#include "classifier_config.hpp" #include "../common/exception.hpp" -#include "../common/jsonconfig.hpp" #include "../storage/storage_base.hpp" #include "../unlearner/unlearner_factory.hpp" +#include "../unlearner/unlearner_config.hpp" #include "../nearest_neighbor/nearest_neighbor_factory.hpp" -using jubatus::core::common::jsonconfig::config; -using jubatus::core::common::jsonconfig::config_cast_check; using jubatus::util::lang::shared_ptr; - namespace jubatus { namespace core { namespace classifier { namespace { -struct unlearner_config { - jubatus::util::data::optional unlearner; - jubatus::util::data::optional unlearner_parameter; - - template - void serialize(Ar& ar) { - ar & JUBA_MEMBER(unlearner) & JUBA_MEMBER(unlearner_parameter); - } -}; - -struct unlearning_classifier_config - : public classifier_config, unlearner_config { - template - void serialize(Ar& ar) { - classifier_config::serialize(ar); - unlearner_config::serialize(ar); - } -}; - -struct nearest_neighbor_classifier_config - : public unlearner_config { - std::string method; - config parameter; - int nearest_neighbor_num; - float local_sensitivity; - - template - void serialize(Ar& ar) { - ar & JUBA_MEMBER(method) - & JUBA_MEMBER(parameter) - & JUBA_MEMBER(nearest_neighbor_num) - & JUBA_MEMBER(local_sensitivity); - unlearner_config::serialize(ar); - } -}; - jubatus::util::lang::shared_ptr -create_unlearner(const unlearner_config& conf) { - if (conf.unlearner) { - if (!conf.unlearner_parameter) { - throw JUBATUS_EXCEPTION(common::exception::runtime_error( - "Unlearner is set but unlearner_parameter is not found")); - } - return unlearner::create_unlearner( - *conf.unlearner, *conf.unlearner_parameter); - } else { - return jubatus::util::lang::shared_ptr(); - } +create_unlearner(const detail::unlearning_classifier_config& conf) { + if (conf.unlearner_config_) { + return unlearner::create_unlearner(conf.unlearner_config_); + } else { + return jubatus::util::lang::shared_ptr(); + } } - -} // namespace +} // namespace shared_ptr classifier_factory::create_classifier( - const std::string& name, - const common::jsonconfig::config& param, - jubatus::util::lang::shared_ptr storage) { - jubatus::util::lang::shared_ptr unlearner; - shared_ptr res; - if (name == "perceptron") { - // perceptron doesn't have parameter - if (param.type() != jubatus::util::text::json::json::Null) { - unlearner_config conf = config_cast_check(param); - unlearner = create_unlearner(conf); - } - res.reset(new perceptron(storage)); - } else if (name == "PA" || name == "passive_aggressive") { - // passive_aggressive doesn't have parameter - if (param.type() != jubatus::util::text::json::json::Null) { - unlearner_config conf = config_cast_check(param); - unlearner = create_unlearner(conf); - } - res.reset(new passive_aggressive(storage)); - } else if (name == "PA1" || name == "passive_aggressive_1") { - if (param.type() == jubatus::util::text::json::json::Null) { - throw JUBATUS_EXCEPTION( - common::config_exception() << common::exception::error_message( - "parameter block is not specified in config")); - } - unlearning_classifier_config conf - = config_cast_check(param); - unlearner = create_unlearner(conf); - res.reset(new passive_aggressive_1(conf, storage)); - } else if (name == "PA2" || name == "passive_aggressive_2") { - if (param.type() == jubatus::util::text::json::json::Null) { - throw JUBATUS_EXCEPTION( - common::config_exception() << common::exception::error_message( - "parameter block is not specified in config")); - } - unlearning_classifier_config conf - = config_cast_check(param); - unlearner = create_unlearner(conf); - res.reset(new passive_aggressive_2(conf, storage)); - } else if (name == "CW" || name == "confidence_weighted") { - if (param.type() == jubatus::util::text::json::json::Null) { - throw JUBATUS_EXCEPTION( - common::config_exception() << common::exception::error_message( - "parameter block is not specified in config")); - } - unlearning_classifier_config conf - = config_cast_check(param); - unlearner = create_unlearner(conf); - res.reset(new confidence_weighted(conf, storage)); - } else if (name == "AROW" || name == "arow") { - if (param.type() == jubatus::util::text::json::json::Null) { - throw JUBATUS_EXCEPTION( - common::config_exception() << common::exception::error_message( - "parameter block is not specified in config")); - } - unlearning_classifier_config conf - = config_cast_check(param); - unlearner = create_unlearner(conf); - res.reset(new arow(conf, storage)); - } else if (name == "NHERD" || name == "normal_herd") { - if (param.type() == jubatus::util::text::json::json::Null) { - throw JUBATUS_EXCEPTION( - common::config_exception() << common::exception::error_message( - "parameter block is not specified in config")); - } - unlearning_classifier_config conf - = config_cast_check(param); - unlearner = create_unlearner(conf); - res.reset(new normal_herd(conf, storage)); - } else if (name == "NN" || name == "nearest_neighbor") { - if (param.type() == jubatus::util::text::json::json::Null) { - throw JUBATUS_EXCEPTION( - common::config_exception() << common::exception::error_message( - "parameter block is not specified in config")); + const classifier_config& conf) { + const classifier_config_base* conf_base = conf.conf_.get(); + jubatus::util::lang::shared_ptr unlearner; + shared_ptr res; + + const detail::unlearning_classifier_config* uconf = + dynamic_cast(conf_base); + const detail::nearest_neighbor_classifier_config* nconf = + dynamic_cast(conf_base); + { // unlearner + if (uconf) { + unlearner = create_unlearner(uconf->unlearner_config_); + } else if (nconf) { + unlearner = create_unlearner(nconf->unlearner_config_); + } else { + // no unlearner } - nearest_neighbor_classifier_config conf - = config_cast_check(param); - unlearner = create_unlearner(conf); - shared_ptr table(new storage::column_table); - shared_ptr - nearest_neighbor_engine(nearest_neighbor::create_nearest_neighbor( - conf.method, conf.parameter, table, "")); - res.reset( - new nearest_neighbor_classifier(nearest_neighbor_engine, - conf.nearest_neighbor_num, - conf.local_sensitivity)); - } else { - throw JUBATUS_EXCEPTION( - common::unsupported_method("classifier(" + name + ")")); - } + } - if (unlearner) { - res->set_label_unlearner(unlearner); - } - return res; + if (conf_base->method_ == "perceptron") { + // perceptron doesn't have parameter + JUBATUS_ASSERT(uconf != NULL); + res.reset(new perceptron()); + } else if (conf_base->method_ == "PA" || conf_base->method_ == "passive_aggressive") { + // passive_aggressive doesn't have parameter + JUBATUS_ASSERT(uconf != NULL); + res.reset(new passive_aggressive()); + } else if (conf_base->method_ == "PA1" || conf_base->method_ == "passive_aggressive_1") { + JUBATUS_ASSERT(uconf != NULL); + res.reset(new passive_aggressive_1(uconf->regularization_weight)); + } else if (conf_base->method_ == "PA2" || conf_base->method_ == "passive_aggressive_2") { + JUBATUS_ASSERT(uconf != NULL); + res.reset(new passive_aggressive_2(uconf->regularization_weight)); + } else if (conf_base->method_ == "CW" || conf_base->method_ == "confidence_weighted") { + JUBATUS_ASSERT(uconf != NULL); + res.reset(new confidence_weighted(uconf->regularization_weight)); + } else if (conf_base->method_ == "AROW" || conf_base->method_ == "arow") { + JUBATUS_ASSERT(uconf != NULL); + res.reset(new arow(uconf->regularization_weight)); + } else if (conf_base->method_ == "NHERD" || conf_base->method_ == "normal_herd") { + JUBATUS_ASSERT(uconf != NULL); + res.reset(new normal_herd(uconf->regularization_weight)); + } else if (conf_base->method_ == "NN" || conf_base->method_ == "nearest_neighbor") { + JUBATUS_ASSERT(nconf != NULL); + shared_ptr table(new storage::column_table); + /* // TODO + shared_ptr + nearest_neighbor_engine(nearest_neighbor::create_nearest_neighbor( + conf_base->method_, nconf->parameter, table, "")); + */ + shared_ptr nearest_neighbor_engine; + res.reset( + new nearest_neighbor_classifier(nearest_neighbor_engine, + nconf->nearest_neighbor_num, + nconf->local_sensitivity)); + } else { + throw JUBATUS_EXCEPTION( + common::unsupported_method("classifier(" + conf_base->method_ + ")")); + } + if (unlearner) { + res->set_label_unlearner(unlearner); + } + return res; } -} // namespace classifier -} // namespace core -} // namespace jubatus +} // namespace classifier +} // namespace core +} // namespace jubatus diff --git a/jubatus/core/classifier/classifier_factory.hpp b/jubatus/core/classifier/classifier_factory.hpp index 1632e6e0..208a25c7 100644 --- a/jubatus/core/classifier/classifier_factory.hpp +++ b/jubatus/core/classifier/classifier_factory.hpp @@ -40,13 +40,12 @@ class config; namespace classifier { class classifier_base; +class classifier_config; class classifier_factory { public: static jubatus::util::lang::shared_ptr create_classifier( - const std::string& name, - const common::jsonconfig::config& param, - jubatus::util::lang::shared_ptr storage); + const classifier_config& conf); }; } // namespace classifier diff --git a/jubatus/core/classifier/classifier_test.cpp b/jubatus/core/classifier/classifier_test.cpp index 0fb53333..a33265f9 100644 --- a/jubatus/core/classifier/classifier_test.cpp +++ b/jubatus/core/classifier/classifier_test.cpp @@ -249,7 +249,7 @@ INSTANTIATE_TYPED_TEST_CASE_P(cl, classifier_test, classifier_types); TEST(classifier_config_test, regularization_weight) { storage_ptr s(new local_storage); - classifier_config c; + classifier_parameter c; c.regularization_weight = std::numeric_limits::quiet_NaN(); ASSERT_THROW(passive_aggressive_1 p1(c, s), common::invalid_parameter); diff --git a/jubatus/core/classifier/confidence_weighted.cpp b/jubatus/core/classifier/confidence_weighted.cpp index ebf873cd..3fc8e7e0 100644 --- a/jubatus/core/classifier/confidence_weighted.cpp +++ b/jubatus/core/classifier/confidence_weighted.cpp @@ -23,6 +23,7 @@ #include "jubatus/util/concurrent/lock.h" #include "classifier_util.hpp" #include "../common/exception.hpp" +#include "../storage/local_storage_mixture.hpp" using std::string; @@ -30,17 +31,10 @@ namespace jubatus { namespace core { namespace classifier { -confidence_weighted::confidence_weighted(storage_ptr storage) - : linear_classifier(storage) { -} - -confidence_weighted::confidence_weighted( - const classifier_config& config, - storage_ptr storage) - : linear_classifier(storage), - config_(config) { - - if (!(0.f < config.regularization_weight)) { +confidence_weighted::confidence_weighted(float regularization_weight) + : linear_classifier(), + regularization_weight_(regularization_weight) { + if (!(0.f < regularization_weight_)) { throw JUBATUS_EXCEPTION( common::invalid_parameter("0.0 < regularization_weight")); } @@ -49,7 +43,7 @@ confidence_weighted::confidence_weighted( void confidence_weighted::train(const common::sfv_t& sfv, const string& label) { check_touchable(label); - const float C = config_.regularization_weight; + const float C = regularization_weight_; string incorrect_label; float variance = 0.f; float margin = -calc_margin_and_variance(sfv, label, incorrect_label, @@ -81,7 +75,7 @@ void confidence_weighted::update( storage::val2_t neg_val(0.f, 1.f); ClassifierUtil::get_two(val2, pos_label, neg_label, pos_val, neg_val); - const float C = config_.regularization_weight; + const float C = regularization_weight_; float covar_pos_step = 2.f * step_width * val * val * C; float covar_neg_step = 2.f * step_width * val * val * C; diff --git a/jubatus/core/classifier/confidence_weighted.hpp b/jubatus/core/classifier/confidence_weighted.hpp index 433bc32c..35a1715a 100644 --- a/jubatus/core/classifier/confidence_weighted.hpp +++ b/jubatus/core/classifier/confidence_weighted.hpp @@ -24,13 +24,11 @@ namespace jubatus { namespace core { namespace classifier { +struct classifier_parameter; class confidence_weighted : public linear_classifier { public: - explicit confidence_weighted(storage_ptr storage); - confidence_weighted( - const classifier_config& config, - storage_ptr storage); + confidence_weighted(float regularization_weight); void train(const common::sfv_t& fv, const std::string& label); std::string name() const; private: @@ -39,7 +37,7 @@ class confidence_weighted : public linear_classifier { float step_weigth, const std::string& pos_label, const std::string& neg_label); - classifier_config config_; + float regularization_weight_; }; } // namespace classifier diff --git a/jubatus/core/classifier/linear_classifier.cpp b/jubatus/core/classifier/linear_classifier.cpp index 27e3a789..180b614a 100644 --- a/jubatus/core/classifier/linear_classifier.cpp +++ b/jubatus/core/classifier/linear_classifier.cpp @@ -29,6 +29,8 @@ #include "../common/exception.hpp" #include "classifier_util.hpp" +#include "../storage/storage_base.hpp" +#include "../storage/local_storage_mixture.hpp" using std::string; using std::vector; @@ -39,8 +41,9 @@ namespace jubatus { namespace core { namespace classifier { -linear_classifier::linear_classifier(storage_ptr storage) - : storage_(storage), mixable_storage_(storage_) { +linear_classifier::linear_classifier() + : storage_(new storage::local_storage_mixture()), + mixable_storage_(storage_) { } linear_classifier::~linear_classifier() { @@ -65,6 +68,12 @@ void linear_classifier::set_label_unlearner( unlearner_ = label_unlearner; } +jubatus::util::lang::shared_ptr +linear_classifier::get_label_unlearner() const { + return unlearner_; +} + + void linear_classifier::classify_with_scores( const common::sfv_t& sfv, classify_result& scores) const { @@ -206,10 +215,21 @@ float linear_classifier::squared_norm(const common::sfv_t& fv) { void linear_classifier::pack(framework::packer& pk) const { storage_->pack(pk); } + void linear_classifier::unpack(msgpack::object o) { storage_->unpack(o); } +void linear_classifier::export_model(framework::packer& pk) const { + pk.pack_array(2); // [storage_, unlearner_] + storage_->export_model(pk); + unlearner_->export_model(pk); +} + +void linear_classifier::import_model(msgpack::object o) { + // TODO +} + framework::mixable* linear_classifier::get_mixable() { return &mixable_storage_; } diff --git a/jubatus/core/classifier/linear_classifier.hpp b/jubatus/core/classifier/linear_classifier.hpp index 020ea160..bb31e5ec 100644 --- a/jubatus/core/classifier/linear_classifier.hpp +++ b/jubatus/core/classifier/linear_classifier.hpp @@ -37,18 +37,16 @@ namespace classifier { class linear_classifier : public classifier_base { public: - explicit linear_classifier(storage_ptr storage); virtual ~linear_classifier(); virtual void train(const common::sfv_t& fv, const std::string& label) = 0; + linear_classifier(); void set_label_unlearner( jubatus::util::lang::shared_ptr label_unlearner); jubatus::util::lang::shared_ptr - label_unlearner() const { - return unlearner_; - } + get_label_unlearner() const; std::string classify(const common::sfv_t& fv) const; void classify_with_scores(const common::sfv_t& fv, @@ -69,6 +67,8 @@ class linear_classifier : public classifier_base { void pack(framework::packer& pk) const; void unpack(msgpack::object o); + void export_model(framework::packer& pk) const; + void import_model(msgpack::object o); framework::mixable* get_mixable(); diff --git a/jubatus/core/classifier/nearest_neighbor_classifier.cpp b/jubatus/core/classifier/nearest_neighbor_classifier.cpp index 3285866b..287bc076 100644 --- a/jubatus/core/classifier/nearest_neighbor_classifier.cpp +++ b/jubatus/core/classifier/nearest_neighbor_classifier.cpp @@ -107,6 +107,11 @@ void nearest_neighbor_classifier::set_label_unlearner( unlearner_ = label_unlearner; } +shared_ptr +nearest_neighbor_classifier::get_label_unlearner() const { + return unlearner_; +} + std::string nearest_neighbor_classifier::classify( const common::sfv_t& fv) const { classify_result result; @@ -232,6 +237,14 @@ void nearest_neighbor_classifier::unpack(msgpack::object o) { } } +void nearest_neighbor_classifier::export_model(framework::packer& pk) const { + // TODO +} +void nearest_neighbor_classifier::import_model(msgpack::object o) { + // TODO +} + + framework::mixable* nearest_neighbor_classifier::get_mixable() { return nearest_neighbor_engine_->get_mixable(); } diff --git a/jubatus/core/classifier/nearest_neighbor_classifier.hpp b/jubatus/core/classifier/nearest_neighbor_classifier.hpp index 07c02778..057d468d 100644 --- a/jubatus/core/classifier/nearest_neighbor_classifier.hpp +++ b/jubatus/core/classifier/nearest_neighbor_classifier.hpp @@ -40,7 +40,7 @@ namespace classifier { class nearest_neighbor_classifier : public classifier_base { public: nearest_neighbor_classifier( - jubatus::util::lang::shared_ptr + util::lang::shared_ptr nearest_neighbor_engine, size_t k, float alpha); @@ -49,6 +49,8 @@ class nearest_neighbor_classifier : public classifier_base { void set_label_unlearner( jubatus::util::lang::shared_ptr label_unlearner); + jubatus::util::lang::shared_ptr + get_label_unlearner() const; std::string classify(const common::sfv_t& fv) const; void classify_with_scores(const common::sfv_t& fv, @@ -65,6 +67,8 @@ class nearest_neighbor_classifier : public classifier_base { void pack(framework::packer& pk) const; void unpack(msgpack::object o); + void export_model(framework::packer& pk) const; + void import_model(msgpack::object o); framework::mixable* get_mixable(); diff --git a/jubatus/core/classifier/normal_herd.cpp b/jubatus/core/classifier/normal_herd.cpp index f305fbd2..5228a6cf 100644 --- a/jubatus/core/classifier/normal_herd.cpp +++ b/jubatus/core/classifier/normal_herd.cpp @@ -30,18 +30,10 @@ namespace jubatus { namespace core { namespace classifier { -normal_herd::normal_herd(storage_ptr storage) - : linear_classifier(storage) { - config_.regularization_weight = 0.1f; -} - -normal_herd::normal_herd( - const classifier_config& config, - storage_ptr storage) - : linear_classifier(storage), - config_(config) { +normal_herd::normal_herd(float regularization_weight) + : regularization_weight_(regularization_weight) { - if (!(0.f < config.regularization_weight)) { + if (!(0.f < regularization_weight_)) { throw JUBATUS_EXCEPTION( common::invalid_parameter("0.0 < regularization_weight")); } @@ -81,7 +73,7 @@ void normal_herd::update( float val_covariance_pos = val * pos_val.v2; float val_covariance_neg = val * neg_val.v2; - const float C = config_.regularization_weight; + const float C = regularization_weight_; storage_->set2_nolock( feature, pos_label, diff --git a/jubatus/core/classifier/normal_herd.hpp b/jubatus/core/classifier/normal_herd.hpp index c591a9d3..264d3615 100644 --- a/jubatus/core/classifier/normal_herd.hpp +++ b/jubatus/core/classifier/normal_herd.hpp @@ -24,13 +24,11 @@ namespace jubatus { namespace core { namespace classifier { +struct classifier_parameter; class normal_herd : public linear_classifier { public: - explicit normal_herd(storage_ptr storage); - normal_herd( - const classifier_config& config, - storage_ptr storage); + normal_herd(float regularization_weight); void train(const common::sfv_t& fv, const std::string& label); std::string name() const; private: @@ -40,7 +38,7 @@ class normal_herd : public linear_classifier { float variance, const std::string& pos_label, const std::string& neg_label); - classifier_config config_; + float regularization_weight_; }; } // namespace classifier diff --git a/jubatus/core/classifier/passive_aggressive.cpp b/jubatus/core/classifier/passive_aggressive.cpp index 29f7dd3a..5557c926 100644 --- a/jubatus/core/classifier/passive_aggressive.cpp +++ b/jubatus/core/classifier/passive_aggressive.cpp @@ -17,6 +17,7 @@ #include "passive_aggressive.hpp" #include +#include "../storage/local_storage_mixture.hpp" using std::string; @@ -24,8 +25,8 @@ namespace jubatus { namespace core { namespace classifier { -passive_aggressive::passive_aggressive(storage_ptr storage) - : linear_classifier(storage) { +passive_aggressive::passive_aggressive() + : linear_classifier() { } void passive_aggressive::train(const common::sfv_t& sfv, const string& label) { diff --git a/jubatus/core/classifier/passive_aggressive.hpp b/jubatus/core/classifier/passive_aggressive.hpp index 94fda067..063bc6d7 100644 --- a/jubatus/core/classifier/passive_aggressive.hpp +++ b/jubatus/core/classifier/passive_aggressive.hpp @@ -27,7 +27,7 @@ namespace classifier { class passive_aggressive : public linear_classifier { public: - explicit passive_aggressive(storage_ptr storage); + explicit passive_aggressive(); void train(const common::sfv_t& fv, const std::string& label); std::string name() const; }; diff --git a/jubatus/core/classifier/passive_aggressive_1.cpp b/jubatus/core/classifier/passive_aggressive_1.cpp index 83691001..744241d0 100644 --- a/jubatus/core/classifier/passive_aggressive_1.cpp +++ b/jubatus/core/classifier/passive_aggressive_1.cpp @@ -20,6 +20,7 @@ #include #include "../common/exception.hpp" +#include "../storage/local_storage_mixture.hpp" using std::string; using std::min; @@ -28,17 +29,11 @@ namespace jubatus { namespace core { namespace classifier { -passive_aggressive_1::passive_aggressive_1(storage_ptr storage) - : linear_classifier(storage) { -} - passive_aggressive_1::passive_aggressive_1( - const classifier_config& config, - storage_ptr storage) - : linear_classifier(storage), - config_(config) { - - if (!(0.f < config.regularization_weight)) { + float regularization_weight) + : linear_classifier(), + regularization_weight_(regularization_weight) { + if (!(0.f < regularization_weight_)) { throw JUBATUS_EXCEPTION( common::invalid_parameter("0.0 < regularization_weight")); } @@ -63,7 +58,7 @@ void passive_aggressive_1::train(const common::sfv_t& sfv, update_weight( sfv, - min(config_.regularization_weight, loss / (2 * sfv_norm)), + min(regularization_weight_, loss / (2 * sfv_norm)), label, incorrect_label); touch(label); diff --git a/jubatus/core/classifier/passive_aggressive_1.hpp b/jubatus/core/classifier/passive_aggressive_1.hpp index 9608d0ef..9da174ee 100644 --- a/jubatus/core/classifier/passive_aggressive_1.hpp +++ b/jubatus/core/classifier/passive_aggressive_1.hpp @@ -18,23 +18,20 @@ #define JUBATUS_CORE_CLASSIFIER_PASSIVE_AGGRESSIVE_1_HPP_ #include - #include "linear_classifier.hpp" namespace jubatus { namespace core { namespace classifier { +struct classifier_parameter; class passive_aggressive_1 : public linear_classifier { public: - explicit passive_aggressive_1(storage_ptr storage); - passive_aggressive_1( - const classifier_config& config, - storage_ptr storage); + passive_aggressive_1(float regularization_weight); void train(const common::sfv_t& fv, const std::string& label); std::string name() const; private: - classifier_config config_; + float regularization_weight_; }; } // namespace classifier diff --git a/jubatus/core/classifier/passive_aggressive_2.cpp b/jubatus/core/classifier/passive_aggressive_2.cpp index d590f5db..3c187340 100644 --- a/jubatus/core/classifier/passive_aggressive_2.cpp +++ b/jubatus/core/classifier/passive_aggressive_2.cpp @@ -20,6 +20,7 @@ #include #include "../common/exception.hpp" +#include "../storage/local_storage_mixture.hpp" using std::string; @@ -27,17 +28,10 @@ namespace jubatus { namespace core { namespace classifier { -passive_aggressive_2::passive_aggressive_2(storage_ptr storage) - : linear_classifier(storage) { -} - -passive_aggressive_2::passive_aggressive_2( - const classifier_config& config, - storage_ptr storage) - : linear_classifier(storage), - config_(config) { +passive_aggressive_2::passive_aggressive_2(float regularization_weight) + : regularization_weight_(regularization_weight) { - if (!(0.f < config.regularization_weight)) { + if (!(0.f < regularization_weight_)) { throw JUBATUS_EXCEPTION( common::invalid_parameter("0.0 < regularization_weight")); } @@ -62,7 +56,7 @@ void passive_aggressive_2::train(const common::sfv_t& sfv, } update_weight( sfv, - loss / (2 * sfv_norm + 1 / (2 * config_.regularization_weight)), + loss / (2 * sfv_norm + 1 / (2 * regularization_weight_)), label, incorrect_label); touch(label); diff --git a/jubatus/core/classifier/passive_aggressive_2.hpp b/jubatus/core/classifier/passive_aggressive_2.hpp index f4e3bc01..e64172dd 100644 --- a/jubatus/core/classifier/passive_aggressive_2.hpp +++ b/jubatus/core/classifier/passive_aggressive_2.hpp @@ -18,8 +18,8 @@ #define JUBATUS_CORE_CLASSIFIER_PASSIVE_AGGRESSIVE_2_HPP_ #include - #include "linear_classifier.hpp" +#include "classifier_config.hpp" namespace jubatus { namespace core { @@ -27,15 +27,12 @@ namespace classifier { class passive_aggressive_2 : public linear_classifier { public: - explicit passive_aggressive_2(storage_ptr storage); - passive_aggressive_2( - const classifier_config& config, - storage_ptr storage); + passive_aggressive_2(float regularization_weight); void train(const common::sfv_t& sfv, const std::string& label); std::string name() const; private: - classifier_config config_; + float regularization_weight_; }; } // namespace classifier diff --git a/jubatus/core/classifier/perceptron.cpp b/jubatus/core/classifier/perceptron.cpp index e0815065..ed329f38 100644 --- a/jubatus/core/classifier/perceptron.cpp +++ b/jubatus/core/classifier/perceptron.cpp @@ -17,6 +17,7 @@ #include "perceptron.hpp" #include +#include "../storage/local_storage_mixture.hpp" using std::string; @@ -24,8 +25,8 @@ namespace jubatus { namespace core { namespace classifier { -perceptron::perceptron(storage_ptr storage) - : linear_classifier(storage) { +perceptron::perceptron() + : linear_classifier() { } void perceptron::train(const common::sfv_t& sfv, const std::string& label) { diff --git a/jubatus/core/classifier/perceptron.hpp b/jubatus/core/classifier/perceptron.hpp index b7a80f8f..1de4ba67 100644 --- a/jubatus/core/classifier/perceptron.hpp +++ b/jubatus/core/classifier/perceptron.hpp @@ -27,7 +27,7 @@ namespace classifier { class perceptron : public linear_classifier { public: - explicit perceptron(storage_ptr storage); + explicit perceptron(); void train(const common::sfv_t& sfv, const std::string& label); std::string name() const; }; diff --git a/jubatus/core/classifier/wscript b/jubatus/core/classifier/wscript index aea91d09..c7479293 100644 --- a/jubatus/core/classifier/wscript +++ b/jubatus/core/classifier/wscript @@ -15,6 +15,7 @@ def build(bld): 'normal_herd.cpp', 'nearest_neighbor_classifier.cpp', 'classifier_factory.cpp', + 'classifier_config.cpp', ] headers = [ 'classifier_base.hpp', diff --git a/jubatus/core/common/byte_buffer.hpp b/jubatus/core/common/byte_buffer.hpp index 9becfbf7..2a4fe4be 100644 --- a/jubatus/core/common/byte_buffer.hpp +++ b/jubatus/core/common/byte_buffer.hpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "jubatus/util/lang/shared_ptr.h" @@ -41,6 +42,12 @@ class byte_buffer { buf_.reset(new std::vector(first, first+size)); } + void write(const char* buf, unsigned int len) { + const size_t old_tail = buf_->size(); + buf_->resize(buf_->size() + len); + std::memcpy(buf_->data() + old_tail, buf, len); + } + // following member functions are implicily defined: // byte_buffer(const byte_buffer& b) = default; // byte_buffer& operator=(const byte_buffer& b) = default; diff --git a/jubatus/core/common/export_model.hpp b/jubatus/core/common/export_model.hpp new file mode 100644 index 00000000..b353cda6 --- /dev/null +++ b/jubatus/core/common/export_model.hpp @@ -0,0 +1,36 @@ +// Jubatus: Online machine learning framework for distributed environment +// Copyright (C) 2014 Preferred Infrastructure and Nippon Telegraph and Telephone Corporation. +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License version 2.1 as published by the Free Software Foundation. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + +#ifndef JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP__ +#define JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP__ + +#include +#include "../framework/packer.hpp" + +#define JUBATUS_EXPORT_MODEL(...) \ + void export_model(framework::packer& pk) const { \ + msgpack::type::make_define(__VA_ARGS__).msgpack_pack(pk); \ + } +#define JUBATUS_IMPORT_MODEL(...) \ + void import_model(msgpack::object o) { \ + this->clear(); \ + msgpack::type::make_define(__VA_ARGS__).msgpack_unpack(o); \ + } + +#define JUBATUS_PORTING_MODEL(...) \ + JUBATUS_IMPORT_MODEL(__VA_ARGS__) \ + JUBATUS_EXPORT_MODEL(__VA_ARGS__) +#endif // JUBATUS_CORE_COMMON_EXPORT_MODEL_HPP_ diff --git a/jubatus/core/common/key_manager.hpp b/jubatus/core/common/key_manager.hpp index 749a12d7..9bbe5bfa 100644 --- a/jubatus/core/common/key_manager.hpp +++ b/jubatus/core/common/key_manager.hpp @@ -46,6 +46,12 @@ class key_manager { key2id_.swap(km.key2id_); id2key_.swap(km.id2key_); } + key_manager& operator=(const key_manager& rhs) { + key2id_ = rhs.key2id_; + id2key_ = rhs.id2key_; + next_id_ = rhs.next_id_; + return *this; + } size_t size() const { return key2id_.size(); diff --git a/jubatus/core/driver/classifier.cpp b/jubatus/core/driver/classifier.cpp index b4293ca0..8bb17c49 100644 --- a/jubatus/core/driver/classifier.cpp +++ b/jubatus/core/driver/classifier.cpp @@ -7,47 +7,68 @@ // // This library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public // License along with this library; if not, write to the Free Software -// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA #include "classifier.hpp" #include #include +#include #include #include +#include "jubatus/util/text/json.h" +#include "jubatus/util/lang/cast.h" #include "../classifier/classifier_factory.hpp" #include "../classifier/classifier_base.hpp" #include "../common/vector_util.hpp" +#include "../common/jsonconfig.hpp" +#include "../common/byte_buffer.hpp" +#include "../framework/stream_writer.hpp" +#include "../framework/packer.hpp" #include "../fv_converter/datum.hpp" #include "../fv_converter/datum_to_fv_converter.hpp" #include "../fv_converter/converter_config.hpp" #include "../storage/storage_factory.hpp" +#include "../fv_converter/factory.hpp" +#include "../fv_converter/so_factory.hpp" using std::string; using std::vector; +using std::make_pair; +using jubatus::util::text::json::json; using jubatus::util::lang::shared_ptr; +using jubatus::util::lang::lexical_cast; using jubatus::core::fv_converter::weight_manager; +using jubatus::core::common::jsonconfig::config_cast_check; using jubatus::core::fv_converter::mixable_weight_manager; namespace jubatus { namespace core { namespace driver { -classifier::classifier( - shared_ptr classifier_method, - shared_ptr converter) - : converter_(converter) - , classifier_(classifier_method) - , wm_(mixable_weight_manager::model_ptr(new weight_manager)) { +shared_ptr +generate_fv_converter(const classifier_driver_config& conf, + const fv_converter::factory_extender* extender) { + return fv_converter::make_fv_converter(conf.converter, extender); +} + +classifier::classifier(const classifier_driver_config& conf) + : conf_(conf), + extender_(new fv_converter::so_factory()), + converter_(fv_converter::make_fv_converter(conf.converter, + extender_.get())), + classifier_( + core::classifier::classifier_factory::create_classifier(*conf.parameter)), + wm_(mixable_weight_manager::model_ptr(new weight_manager)) { + register_mixable(classifier_->get_mixable()); register_mixable(&wm_); - converter_->set_weight_manager(wm_.get_model()); } @@ -55,40 +76,40 @@ classifier::~classifier() { } void classifier::train(const string& label, const fv_converter::datum& data) { - common::sfv_t v; - converter_->convert_and_update_weight(data, v); - common::sort_and_merge(v); - classifier_->train(v, label); + common::sfv_t v; + converter_->convert_and_update_weight(data, v); + common::sort_and_merge(v); + classifier_->train(v, label); } jubatus::core::classifier::classify_result classifier::classify( - const fv_converter::datum& data) const { - common::sfv_t v; - converter_->convert(data, v); + const fv_converter::datum& data) const { + common::sfv_t v; + converter_->convert(data, v); - jubatus::core::classifier::classify_result scores; - classifier_->classify_with_scores(v, scores); - return scores; + jubatus::core::classifier::classify_result scores; + classifier_->classify_with_scores(v, scores); + return scores; } void classifier::get_status(std::map& status) const { - classifier_->get_status(status); + classifier_->get_status(status); } bool classifier::delete_label(const std::string& label) { - return classifier_->delete_label(label); + return classifier_->delete_label(label); } void classifier::clear() { - classifier_->clear(); - converter_->clear_weights(); + classifier_->clear(); + converter_->clear_weights(); } std::vector classifier::get_labels() const { - return classifier_->get_labels(); + return classifier_->get_labels(); } bool classifier::set_label(const std::string& label) { - return classifier_->set_label(label); + return classifier_->set_label(label); } void classifier::pack(framework::packer& pk) const { @@ -109,6 +130,76 @@ void classifier::unpack(msgpack::object o) { wm_.get_model()->unpack(o.via.array.ptr[1]); } -} // namespace driver -} // namespace core -} // namespace jubatus +struct versioned_model { + // this struct is version compatible + std::string version; + common::byte_buffer buffer; + versioned_model(const std::string& v, const common::byte_buffer& b) + : version(v), buffer(b) { + } + MSGPACK_DEFINE(version, buffer); +}; + +void classifier::swap(driver::classifier& other) { + conf_.swap(other.conf_); + extender_.swap(other.extender_); + converter_.swap(other.converter_); + classifier_.swap(other.classifier_); + wm_.swap(other.wm_); +} + +void classifier::import_model(common::byte_buffer& src) { + msgpack::unpacked packed; + msgpack::unpack(&packed, src.ptr(), src.size()); + msgpack::object o = packed.get(); + + if (o.type != msgpack::type::ARRAY || o.via.array.size != 2) { + throw msgpack::type_error(); + } + + // config_file + common::jsonconfig::config + jsonconf(lexical_cast(o.via.array.ptr[0].as())); + classifier_driver_config new_config( + config_cast_check(jsonconf)); + classifier new_classifier(new_config); + this->swap(new_classifier); + + // config_data + std::string model_payload(o.via.array.ptr[1].as()); + msgpack::unpacked packed_data; + msgpack::unpack(&packed_data, &model_payload[0], model_payload.size()); + msgpack::object model_obj; + classifier_->import_model(model_obj); + wm_.get_model()->import_model(model_obj); + + // TODO: re-register mixer and re-register weight_manager to fvconveter is needed here + +} + +common::byte_buffer classifier::export_model() const { + common::byte_buffer result; + framework::stream_writer writer(result); + framework::jubatus_packer jp(writer); + core::framework::packer dst(jp); + dst.pack_array(2); // [config_file, model_data] + + { // config_file + std::string config_json; + classifier_driver_config nonconst_conf(*const_cast(&conf_)); + nonconst_conf.serialize(config_json); + dst.pack(config_json); + } + + { // model_data + dst.pack_array(2); // [classifier_, wm_] + classifier_->export_model(dst); + wm_.get_model()->export_model(dst); + } + return result; +} + + +} // namespace driver +} // namespace core +} // namespace jubatus diff --git a/jubatus/core/driver/classifier.hpp b/jubatus/core/driver/classifier.hpp index 625a0ec5..5da77ff8 100644 --- a/jubatus/core/driver/classifier.hpp +++ b/jubatus/core/driver/classifier.hpp @@ -20,10 +20,16 @@ #include #include #include +#include "jubatus/util/text/json.h" +#include "jubatus/util/data/serialization.h" #include "jubatus/util/lang/shared_ptr.h" +#include "../common/jsonconfig.hpp" +#include "../common/byte_buffer.hpp" #include "../classifier/classifier_type.hpp" +#include "../classifier/classifier_config.hpp" #include "../framework/mixable.hpp" #include "../fv_converter/mixable_weight_manager.hpp" +#include "../fv_converter/converter_config.hpp" #include "driver.hpp" namespace jubatus { @@ -37,16 +43,29 @@ class classifier_base; } // namespace classifier namespace driver { +struct classifier_driver_config { + std::string method; + jubatus::util::data::optional parameter; + core::fv_converter::converter_config converter; + + void swap(classifier_driver_config& other) { + method.swap(other.method); + parameter.swap(other.parameter); + converter.swap(other.converter); + } + + template + void serialize(Ar& ar) { + ar & JUBA_MEMBER(method) & JUBA_MEMBER(parameter) & JUBA_MEMBER(converter); + } +}; + class classifier : public driver_base { public: typedef core::classifier::classifier_base classifier_base; // TODO(suma): where is the owner of model, mixer, and converter? - classifier( - jubatus::util::lang::shared_ptr - classifier_method, - jubatus::util::lang::shared_ptr - converter); + classifier(const classifier_driver_config& conf); virtual ~classifier(); void train(const std::string&, const fv_converter::datum&); @@ -61,10 +80,17 @@ class classifier : public driver_base { void pack(framework::packer& pk) const; void unpack(msgpack::object o); + void import_model(common::byte_buffer& src); + common::byte_buffer export_model() const; + std::vector get_labels() const; bool set_label(const std::string& label); + void swap(classifier& other); + private: + classifier_driver_config conf_; + jubatus::util::lang::shared_ptr extender_; jubatus::util::lang::shared_ptr converter_; jubatus::util::lang::shared_ptr classifier_; diff --git a/jubatus/core/driver/recommender.hpp b/jubatus/core/driver/recommender.hpp index 29e57b09..ee004482 100644 --- a/jubatus/core/driver/recommender.hpp +++ b/jubatus/core/driver/recommender.hpp @@ -21,6 +21,7 @@ #include #include #include "jubatus/util/lang/shared_ptr.h" +#include "../common/byte_buffer.hpp" #include "../recommender/recommender_base.hpp" #include "../framework/mixable.hpp" #include "../framework/diffv.hpp" diff --git a/jubatus/core/framework/mixable_helper.hpp b/jubatus/core/framework/mixable_helper.hpp index c3b141f6..9059a6e2 100644 --- a/jubatus/core/framework/mixable_helper.hpp +++ b/jubatus/core/framework/mixable_helper.hpp @@ -102,6 +102,10 @@ class linear_mixable_helper : public linear_mixable { return model_->get_version(); } + void swap(linear_mixable_helper& other) { + model_.swap(other.model_); + } + private: struct internal_diff_object : diff_object_raw { void convert_binary(packer& pk) const { diff --git a/jubatus/core/fv_converter/converter_config.hpp b/jubatus/core/fv_converter/converter_config.hpp index f6ed97b2..846e36ac 100644 --- a/jubatus/core/fv_converter/converter_config.hpp +++ b/jubatus/core/fv_converter/converter_config.hpp @@ -112,11 +112,11 @@ struct combination_rule { struct converter_config { jubatus::util::data::optional > - string_filter_types; + string_filter_types; jubatus::util::data::optional > string_filter_rules; jubatus::util::data::optional > - num_filter_types; + num_filter_types; jubatus::util::data::optional > num_filter_rules; jubatus::util::data::optional > string_types; @@ -136,6 +136,21 @@ struct converter_config { jubatus::util::data::optional hash_max_size; + void swap(converter_config& other) { + string_filter_types.swap(other.string_filter_types); + string_filter_rules.swap(string_filter_rules); + num_filter_types.swap(num_filter_types); + num_filter_rules.swap(num_filter_rules); + string_types.swap(string_types); + string_rules.swap(string_rules); + num_types.swap(num_types); + num_rules.swap(num_rules); + binary_types.swap(binary_types); + binary_rules.swap(binary_rules); + combination_types.swap(combination_types); + combination_rules.swap(combination_rules); + } + friend class jubatus::util::data::serialization::access; template void serialize(Archive& ar) { diff --git a/jubatus/core/fv_converter/datum_to_fv_converter.cpp b/jubatus/core/fv_converter/datum_to_fv_converter.cpp index 57103d20..6dd9333f 100644 --- a/jubatus/core/fv_converter/datum_to_fv_converter.cpp +++ b/jubatus/core/fv_converter/datum_to_fv_converter.cpp @@ -658,6 +658,7 @@ void datum_to_fv_converter::clear_weights() { pimpl_->clear_weights(); } + } // namespace fv_converter } // namespace core } // namespace jubatus diff --git a/jubatus/core/fv_converter/datum_to_fv_converter.hpp b/jubatus/core/fv_converter/datum_to_fv_converter.hpp index 20f96ff6..8a9e2b49 100644 --- a/jubatus/core/fv_converter/datum_to_fv_converter.hpp +++ b/jubatus/core/fv_converter/datum_to_fv_converter.hpp @@ -25,6 +25,7 @@ #include "jubatus/util/lang/scoped_ptr.h" #include "../common/type.hpp" #include "../framework/mixable.hpp" +#include "../framework/packer.hpp" namespace jubatus { namespace core { @@ -119,6 +120,9 @@ class datum_to_fv_converter { void set_weight_manager(jubatus::util::lang::shared_ptr wm); void clear_weights(); + void swap(datum_to_fv_converter& other) { + pimpl_.swap(other.pimpl_); + } private: jubatus::util::lang::scoped_ptr pimpl_; diff --git a/jubatus/core/fv_converter/weight_manager.hpp b/jubatus/core/fv_converter/weight_manager.hpp index 914ba174..b2e3c7c3 100644 --- a/jubatus/core/fv_converter/weight_manager.hpp +++ b/jubatus/core/fv_converter/weight_manager.hpp @@ -28,6 +28,7 @@ #include "../framework/model.hpp" #include "../common/type.hpp" #include "../common/version.hpp" +#include "../common/export_model.hpp" #include "keyword_weights.hpp" namespace jubatus { @@ -94,6 +95,8 @@ class weight_manager : public framework::model { } MSGPACK_DEFINE(version_, diff_weights_, master_weights_); + JUBATUS_EXPORT_MODEL(version_, master_weights_); + void import_model(msgpack::object o); void pack(framework::packer& pk) const { util::concurrent::scoped_lock lk(mutex_); diff --git a/jubatus/core/recommender/recommender_factory.cpp b/jubatus/core/recommender/recommender_factory.cpp index 63e1ea97..04feb571 100644 --- a/jubatus/core/recommender/recommender_factory.cpp +++ b/jubatus/core/recommender/recommender_factory.cpp @@ -82,9 +82,11 @@ shared_ptr recommender_factory::create_recommender( common::config_exception() << common::exception::error_message( "unlearner is set but unlearner_parameter is not found")); } - shared_ptr unl(unlearner::create_unlearner( - *conf.unlearner, common::jsonconfig::config( - *conf.unlearner_parameter))); + shared_ptr uconf = + unlearner::create_unlearner_config(*conf.unlearner, + *conf.unlearner_parameter); + shared_ptr unl( + unlearner::create_unlearner(uconf)); return shared_ptr( new nearest_neighbor_recommender(nearest_neighbor_engine, unl)); } diff --git a/jubatus/core/storage/local_storage.cpp b/jubatus/core/storage/local_storage.cpp index 1db63cda..5f10d7ea 100644 --- a/jubatus/core/storage/local_storage.cpp +++ b/jubatus/core/storage/local_storage.cpp @@ -285,6 +285,10 @@ void local_storage::unpack(msgpack::object o) { o.convert(this); } +void local_storage::import_model(msgpack::object o) { + o.convert(this); +} + std::string local_storage::type() const { return "local_storage"; } diff --git a/jubatus/core/storage/local_storage.hpp b/jubatus/core/storage/local_storage.hpp index 5a0b51cc..4e783b8a 100644 --- a/jubatus/core/storage/local_storage.hpp +++ b/jubatus/core/storage/local_storage.hpp @@ -26,6 +26,7 @@ #include "storage_base.hpp" #include "../common/key_manager.hpp" #include "../common/version.hpp" +#include "../common/export_model.hpp" namespace jubatus { namespace core { @@ -100,12 +101,15 @@ class local_storage : public storage_base { void pack(framework::packer& packer) const; void unpack(msgpack::object o); + storage::version get_version() const { return storage::version(); } std::string type() const; MSGPACK_DEFINE(tbl_, class2id_); + JUBATUS_EXPORT_MODEL(tbl_, class2id_); + void import_model(msgpack::object o); private: // map_features3_t tbl_; diff --git a/jubatus/core/storage/local_storage_mixture.cpp b/jubatus/core/storage/local_storage_mixture.cpp index a22e024a..c106def8 100644 --- a/jubatus/core/storage/local_storage_mixture.cpp +++ b/jubatus/core/storage/local_storage_mixture.cpp @@ -339,6 +339,30 @@ void local_storage_mixture::unpack(msgpack::object o) { o.convert(this); } +void local_storage_mixture::export_model(framework::packer& pk) const { + pk.pack_array(3); // tbl_(diff_), class2id_, model_version_ + if (tbl_.empty()) { // standalone jubatus + pk.pack(tbl_diff_); + } else { + pk.pack(tbl_); + } + pk.pack(class2id_); + pk.pack(model_version_); +} + +void local_storage_mixture::import_model(msgpack::object o) { + if(o.type != msgpack::type::ARRAY) { + throw msgpack::type_error(); + } + const size_t size = o.via.array.size; + JUBATUS_ASSERT_EQ(3, size, "importing array length must be 3"); + o.via.array.ptr[0].convert(&tbl_); + o.via.array.ptr[1].convert(&class2id_); + o.via.array.ptr[2].convert(&model_version_); + tbl_diff_.clear(); +} + + std::string local_storage_mixture::type() const { return "local_storage_mixture"; } diff --git a/jubatus/core/storage/local_storage_mixture.hpp b/jubatus/core/storage/local_storage_mixture.hpp index a982e67a..2cf05d5f 100644 --- a/jubatus/core/storage/local_storage_mixture.hpp +++ b/jubatus/core/storage/local_storage_mixture.hpp @@ -102,6 +102,9 @@ class local_storage_mixture : public storage_base { void pack(framework::packer& packer) const; void unpack(msgpack::object o); + void export_model(framework::packer& pk) const; + void import_model(msgpack::object o); + version get_version() const { return model_version_; } diff --git a/jubatus/core/storage/storage_base.hpp b/jubatus/core/storage/storage_base.hpp index 1db776bd..7cf65318 100644 --- a/jubatus/core/storage/storage_base.hpp +++ b/jubatus/core/storage/storage_base.hpp @@ -30,6 +30,7 @@ #include "../common/exception.hpp" #include "../common/type.hpp" #include "../framework/model.hpp" +#include "../framework/packer.hpp" namespace jubatus { namespace core { @@ -82,6 +83,8 @@ class storage_base : public framework::model { virtual void pack(framework::packer& packer) const = 0; virtual void unpack(msgpack::object o) = 0; + virtual void export_model(framework::packer& pk) const = 0; + virtual void import_model(msgpack::object o) = 0; virtual version get_version() const = 0; diff --git a/jubatus/core/unlearner/lru_unlearner.cpp b/jubatus/core/unlearner/lru_unlearner.cpp index fadffe7a..6bbc38d3 100644 --- a/jubatus/core/unlearner/lru_unlearner.cpp +++ b/jubatus/core/unlearner/lru_unlearner.cpp @@ -17,12 +17,15 @@ #include "lru_unlearner.hpp" #include +#include #include "jubatus/util/data/unordered_set.h" +#include "../common/assert.hpp" // TODO(kmaehashi) move key_matcher to common #include "../fv_converter/key_matcher.hpp" #include "../fv_converter/key_matcher_factory.hpp" #include "../common/exception.hpp" +#include "../common/unordered_set.hpp" using jubatus::util::data::unordered_set; using jubatus::core::fv_converter::key_matcher_factory; @@ -30,19 +33,26 @@ using jubatus::core::fv_converter::key_matcher_factory; namespace jubatus { namespace core { namespace unlearner { +namespace { +const lru_unlearner::lru_unlearner_config& +as_lru_unlearner_config(const unlearner_config_base& conf) { + return dynamic_cast(conf); +} +} // namespace -lru_unlearner::lru_unlearner(const config& conf) - : max_size_(conf.max_size) { - if (conf.max_size <= 0) { +lru_unlearner::lru_unlearner(const unlearner_config_base& conf) + : max_size_(as_lru_unlearner_config(conf).max_size) { + const lru_unlearner_config& lconfig = as_lru_unlearner_config(conf); + if (lconfig.max_size <= 0) { throw JUBATUS_EXCEPTION( common::config_exception() << common::exception::error_message( "max_size must be a positive integer")); } - entry_map_.reserve(max_size_); + entry_map_.reserve(lconfig.max_size); - if (conf.sticky_pattern) { + if (lconfig.sticky_pattern) { key_matcher_factory f; - sticky_matcher_ = f.create_matcher(*conf.sticky_pattern); + sticky_matcher_ = f.create_matcher(*lconfig.sticky_pattern); } } @@ -126,6 +136,22 @@ bool lru_unlearner::exists_in_memory(const std::string& id) const { return entry_map_.count(id) > 0 || sticky_ids_.count(id) > 0; } +void lru_unlearner::export_model(framework::packer& pk) const { + pk.pack_array(1); // [lru_] + pk.pack(lru_); +} + +void lru_unlearner::import_model(msgpack::object o) { + if(o.type != msgpack::type::ARRAY) { + throw msgpack::type_error(); + } + JUBATUS_ASSERT_EQ(1, + o.via.array.size, + "importing lru_unlearner length must be 1"); + o.via.array.ptr[0].convert(&lru_); + rebuild_entry_map(); +} + // private void lru_unlearner::rebuild_entry_map() { diff --git a/jubatus/core/unlearner/lru_unlearner.hpp b/jubatus/core/unlearner/lru_unlearner.hpp index 13b8cf37..6f146c19 100644 --- a/jubatus/core/unlearner/lru_unlearner.hpp +++ b/jubatus/core/unlearner/lru_unlearner.hpp @@ -26,24 +26,27 @@ #include "jubatus/util/data/optional.h" #include "jubatus/util/lang/shared_ptr.h" #include "unlearner_base.hpp" +#include "../common/porting_model.hpp" +#include "../common/unordered_map.hpp" +#include "unlearner_config.hpp" namespace jubatus { namespace core { namespace fv_converter { class key_matcher; -} +} // namespace fv_converter namespace unlearner { // Unlearner based on Least Recently Used algorithm. class lru_unlearner : public unlearner_base { public: - struct config { + struct lru_unlearner_config : public unlearner_config_base { int32_t max_size; jubatus::util::data::optional sticky_pattern; template void serialize(Ar& ar) { - ar & JUBA_MEMBER(max_size) & JUBA_MEMBER(sticky_pattern); + ar & JUBA_MEMBER(name) & JUBA_MEMBER(max_size) & JUBA_MEMBER(sticky_pattern); } }; @@ -54,16 +57,18 @@ class lru_unlearner : public unlearner_base { void clear() { lru_.clear(); entry_map_.clear(); - sticky_ids_.clear(); } - explicit lru_unlearner(const config& conf); + explicit lru_unlearner(const unlearner_config_base& conf); bool can_touch(const std::string& id); bool touch(const std::string& id); bool remove(const std::string& id); bool exists_in_memory(const std::string& id) const; + void export_model(framework::packer& pk) const; + void import_model(msgpack::object o); + private: typedef std::list lru; typedef jubatus::util::data::unordered_map diff --git a/jubatus/core/unlearner/random_unlearner.cpp b/jubatus/core/unlearner/random_unlearner.cpp index 834464f8..edabd225 100644 --- a/jubatus/core/unlearner/random_unlearner.cpp +++ b/jubatus/core/unlearner/random_unlearner.cpp @@ -19,25 +19,32 @@ #include #include #include "../common/exception.hpp" +#include "../common/assert.hpp" namespace jubatus { namespace core { namespace unlearner { -random_unlearner::random_unlearner(const config& conf) - : max_size_(conf.max_size) { - if (conf.max_size <= 0) { +const random_unlearner::random_unlearner_config& +as_random_unlearner_config(const unlearner_config_base& orig) { + return dynamic_cast(orig); +} + +random_unlearner::random_unlearner(const unlearner_config_base& conf) + : max_size_(as_random_unlearner_config(conf).max_size) { + const random_unlearner_config& rconf = as_random_unlearner_config(conf); + if (rconf.max_size <= 0) { throw JUBATUS_EXCEPTION( common::config_exception() << common::exception::error_message( "max_size must be a positive integer")); } - if (conf.seed) { - if (*conf.seed < 0 || std::numeric_limits::max() < *conf.seed) { + if (rconf.seed) { + if (*rconf.seed < 0 || std::numeric_limits::max() < *rconf.seed) { throw JUBATUS_EXCEPTION( common::config_exception() << common::exception::error_message( "unlearner seed must be within unsigned 32 bit integer")); } - mtr_ = jubatus::util::math::random::mtrand(*conf.seed); + mtr_ = jubatus::util::math::random::mtrand(*rconf.seed); } id_map_.reserve(max_size_); ids_.reserve(max_size_); @@ -89,6 +96,28 @@ bool random_unlearner::remove(const std::string& id) { return true; } +void random_unlearner::export_model(framework::packer& pk) const { + pk.pack_array(1); // [ids_] + pk.pack(ids_); +} +void random_unlearner::import_model(msgpack::object o) { + if(o.type != msgpack::type::ARRAY) { + throw msgpack::type_error(); + } + JUBATUS_ASSERT_EQ(1, + o.via.array.size, + "importing random_unlearner length must be 1"); + o.via.array.ptr[0].convert(&ids_); + rebuild_map(); +} + +void random_unlearner::rebuild_map() { + id_map_.clear(); + for (size_t i = 0; i < ids_.size(); ++i) { + id_map_[ids_[i]] = i; + } +} + bool random_unlearner::exists_in_memory(const std::string& id) const { return id_map_.count(id) > 0; } diff --git a/jubatus/core/unlearner/random_unlearner.hpp b/jubatus/core/unlearner/random_unlearner.hpp index b26f3170..91327f54 100644 --- a/jubatus/core/unlearner/random_unlearner.hpp +++ b/jubatus/core/unlearner/random_unlearner.hpp @@ -19,11 +19,15 @@ #include #include +#include #include "jubatus/util/data/optional.h" #include "jubatus/util/data/serialization.h" #include "jubatus/util/data/unordered_map.h" #include "jubatus/util/math/random.h" #include "unlearner_base.hpp" +#include "unlearner_config.hpp" +#include "../common/porting_model.hpp" +#include "../common/unordered_set.hpp" namespace jubatus { namespace core { @@ -32,13 +36,13 @@ namespace unlearner { // Unlearner that chooses an item to be removed by uniformly random sampling. class random_unlearner : public unlearner_base { public: - struct config { + struct random_unlearner_config : public unlearner_config_base { int32_t max_size; jubatus::util::data::optional seed; template void serialize(Ar& ar) { - ar & JUBA_MEMBER(max_size) & JUBA_MEMBER(seed); + ar & JUBA_MEMBER(name) & JUBA_MEMBER(max_size) & JUBA_MEMBER(seed); } }; @@ -51,14 +55,18 @@ class random_unlearner : public unlearner_base { ids_.clear(); } - explicit random_unlearner(const config& conf); + explicit random_unlearner(const unlearner_config_base& conf); bool can_touch(const std::string& id); bool touch(const std::string& id); bool remove(const std::string& id); bool exists_in_memory(const std::string& id) const; + void export_model(framework::packer& pk) const; + void import_model(msgpack::object o); + private: + void rebuild_map(); /** * Map of ID and its position in ids_. */ diff --git a/jubatus/core/unlearner/unlearner_base.hpp b/jubatus/core/unlearner/unlearner_base.hpp index 39051aca..bbbbcd64 100644 --- a/jubatus/core/unlearner/unlearner_base.hpp +++ b/jubatus/core/unlearner/unlearner_base.hpp @@ -19,6 +19,11 @@ #include #include "jubatus/util/lang/function.h" +#include "jubatus/core/framework/packer.hpp" + +namespace msgpack { +class object; +} namespace jubatus { namespace core { @@ -72,6 +77,9 @@ class unlearner_base { // touched and not unlearned since then, it returns true. virtual bool exists_in_memory(const std::string& id) const = 0; + virtual void export_model(framework::packer& pk) const = 0; + virtual void import_model(msgpack::object o) = 0; + protected: void unlearn(const std::string& id) const { callback_(id); diff --git a/jubatus/core/unlearner/unlearner_config.hpp b/jubatus/core/unlearner/unlearner_config.hpp new file mode 100644 index 00000000..51516472 --- /dev/null +++ b/jubatus/core/unlearner/unlearner_config.hpp @@ -0,0 +1,47 @@ +// Jubatus: Online machine learning framework for distributed environment +// Copyright (C) 2013 Preferred Networks and Nippon Telegraph and Telephone Corporation. +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License version 2.1 as published by the Free Software Foundation. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + +#ifndef JUBATUS_CORE_UNLEARNER_UNLEARNER_CONFIG_HPP_ +#define JUBATUS_CORE_UNLEARNER_UNLEARNER_CONFIG_HPP_ + +#include +#include "../common/jsonconfig.hpp" +#include "jubatus/util/lang/shared_ptr.h" +#include "jubatus/util/data/serialization.h" + +namespace jubatus { +namespace core { +namespace unlearner { + +struct unlearner_config_base { + std::string name; + virtual ~unlearner_config_base() = 0; + + template + void serialize(Ar& ar) { + ar & JUBA_NAMED_MEMBER("name", name); + } +}; + +util::lang::shared_ptr +create_unlearner_config(const std::string, + const common::jsonconfig::config& param); + +} // namespace unlearner +} // namespace core +} // namespace jubatus + +#endif // JUBATUS_CORE_UNLEARNER_UNLEARNER_FACTORY_HPP_ diff --git a/jubatus/core/unlearner/unlearner_factory.cpp b/jubatus/core/unlearner/unlearner_factory.cpp index 0d969db8..6354fd6d 100644 --- a/jubatus/core/unlearner/unlearner_factory.cpp +++ b/jubatus/core/unlearner/unlearner_factory.cpp @@ -29,19 +29,30 @@ namespace core { namespace unlearner { shared_ptr create_unlearner( - const std::string& name, - const common::jsonconfig::config& config) { - if (name == "lru") { - return shared_ptr( - new lru_unlearner(common::jsonconfig::config_cast_check< - lru_unlearner::config>(config))); - } else if (name == "random") { - return shared_ptr( - new random_unlearner(common::jsonconfig::config_cast_check< - random_unlearner::config>(config))); + const shared_ptr conf) { + if (conf->name == "lru") { + lru_unlearner::lru_unlearner_config* lconf = + dynamic_cast(conf.get()); + if (lconf) { + return shared_ptr( + new lru_unlearner(*lconf)); + } else { + throw JUBATUS_EXCEPTION(common::unsupported_method( + "invaild lru unlearner config")); + } + } else if (conf->name == "random") { + random_unlearner::random_unlearner_config* rconf = + dynamic_cast(conf.get()); + if (rconf) { + return shared_ptr( + new random_unlearner(*rconf)); + } else { + throw JUBATUS_EXCEPTION(common::unsupported_method( + "invaild random unlearner config")); + } } else { throw JUBATUS_EXCEPTION(common::unsupported_method( - "unlearner(" + name + ')')); + "unlearner(" + conf->name + ')')); } } diff --git a/jubatus/core/unlearner/unlearner_factory.hpp b/jubatus/core/unlearner/unlearner_factory.hpp index 4049ae56..90791a34 100644 --- a/jubatus/core/unlearner/unlearner_factory.hpp +++ b/jubatus/core/unlearner/unlearner_factory.hpp @@ -20,7 +20,8 @@ #include #include "../common/jsonconfig.hpp" #include "jubatus/util/lang/shared_ptr.h" - +#include "jubatus/util/data/optional.h" +#include "unlearner_config.hpp" namespace jubatus { namespace core { @@ -29,8 +30,7 @@ namespace unlearner { class unlearner_base; util::lang::shared_ptr create_unlearner( - const std::string& name, - const common::jsonconfig::config& config); + util::lang::shared_ptr config); } // namespace unlearner } // namespace core