Skip to content

Commit

Permalink
export_model interface of driver::classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
KUMAZAKI Hiroki committed Apr 2, 2015
1 parent 07fba27 commit 4f25024
Show file tree
Hide file tree
Showing 51 changed files with 782 additions and 330 deletions.
8 changes: 5 additions & 3 deletions jubatus/core/anomaly/anomaly_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "../common/exception.hpp"
#include "../common/jsonconfig.hpp"
#include "../nearest_neighbor/nearest_neighbor_factory.hpp"
#include "../unlearner/unlearner_config.hpp"
#include "../unlearner/unlearner_factory.hpp"
#include "../storage/column_table.hpp"
#include "../recommender/recommender_factory.hpp"
Expand Down Expand Up @@ -100,10 +101,11 @@ shared_ptr<anomaly_base> anomaly_factory::create_anomaly(
<< common::exception::error_message(
"unlearner is set but unlearner_parameter is not found"));
}
shared_ptr<unlearner::unlearner_config_base> unl_conf(
unlearner::create_unlearner_config(*conf.unlearner,
*conf.unlearner_parameter));
jubatus::util::lang::shared_ptr<unlearner::unlearner_base> unlearner(
unlearner::create_unlearner(
*conf.unlearner,
*conf.unlearner_parameter));
unlearner::create_unlearner(unl_conf));
return shared_ptr<anomaly_base>(
new light_lof(conf, id, nearest_neighbor_engine, unlearner));
}
Expand Down
18 changes: 7 additions & 11 deletions jubatus/core/classifier/arow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,20 @@
#include "jubatus/util/concurrent/lock.h"
#include "classifier_util.hpp"
#include "../common/exception.hpp"
#include "../storage/storage_base.hpp"

using std::string;
using jubatus::core::storage_ptr;

namespace jubatus {
namespace core {
namespace classifier {

arow::arow(storage_ptr storage)
: linear_classifier(storage) {
}

arow::arow(
const classifier_config& config,
storage_ptr storage)
: linear_classifier(storage),
config_(config) {
arow::arow(float regularization_weight)
: linear_classifier(),
regularization_weight_(regularization_weight) {

if (!(0.f < config.regularization_weight)) {
if (!(0.f < regularization_weight_)) {
throw JUBATUS_EXCEPTION(
common::invalid_parameter("0.0 < regularization_weight"));
}
Expand All @@ -58,7 +54,7 @@ void arow::train(const common::sfv_t& sfv, const string& label) {
return;
}

float beta = 1.f / (variance + 1.f / config_.regularization_weight);
float beta = 1.f / (variance + 1.f / regularization_weight_);
float alpha = (1.f - margin) * beta; // max(0, 1 - margin) = 1 - margin
update(sfv, alpha, beta, label, incorrect_label);
}
Expand Down
6 changes: 3 additions & 3 deletions jubatus/core/classifier/arow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
namespace jubatus {
namespace core {
namespace classifier {
struct classifier_parameter;

class arow : public linear_classifier {
public:
explicit arow(storage_ptr storage);
arow(const classifier_config& config, storage_ptr storage);
arow(float regularization_weight);
void train(const common::sfv_t& fv, const std::string& label);
std::string name() const;
private:
Expand All @@ -38,7 +38,7 @@ class arow : public linear_classifier {
float beta,
const std::string& pos_label,
const std::string& neg_label);
classifier_config config_;
float regularization_weight_;
};

} // namespace classifier
Expand Down
14 changes: 14 additions & 0 deletions jubatus/core/classifier/classifier_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,17 @@
#include "../unlearner/unlearner_base.hpp"
#include "classifier_type.hpp"

namespace msgpack {

}
namespace jubatus {
namespace core {
namespace common {
class byte_buffer;
}
namespace framework {
class jubatus_packer;
}
namespace classifier {

class classifier_base {
Expand All @@ -48,6 +57,8 @@ class classifier_base {
virtual void set_label_unlearner(
jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
label_unlearner) = 0;
virtual jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
get_label_unlearner() const = 0;

virtual bool delete_label(const std::string& label) = 0;
virtual std::vector<std::string> get_labels() const = 0;
Expand All @@ -61,6 +72,9 @@ class classifier_base {
virtual void unpack(msgpack::object o) = 0;
virtual void clear() = 0;

void import_model(msgpack::object& src);
void export_model(framework::packer& dst) const;

virtual framework::mixable* get_mixable() = 0;
};

Expand Down
69 changes: 69 additions & 0 deletions jubatus/core/classifier/classifier_config.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Jubatus: Online machine learning framework for distributed environment
// Copyright (C) 2012 Preferred Networks and Nippon Telegraph and Telephone Corporation.
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License version 2.1 as published by the Free Software Foundation.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

#include "classifier_config.hpp"

#include "jubatus/util/data/serialization.h"
#include "jubatus/util/data/optional.h"
#include "jubatus/util/lang/shared_ptr.h"
#include "../common/jsonconfig.hpp"
#include "../unlearner/unlearner_config.hpp"
#include "../storage/column_table.hpp"
#include "../nearest_neighbor/nearest_neighbor_base.hpp"

using jubatus::util::lang::shared_ptr;
using jubatus::core::common::jsonconfig::config_cast_check;

namespace jubatus {
namespace core {
namespace classifier {
classifier_config::classifier_config(const std::string& method,
const common::jsonconfig::config& param) {
if (method == "perceptron" ||
method == "PA" || method == "passive_aggressive") {
if (param.type() != jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is specified in config"));
}
conf_.reset(new detail::unlearning_classifier_config(method, param));
} else if (method == "PA1" || method == "passive_aggressive_1" ||
method == "PA2" || method == "passive_aggressive_2" ||
method == "CW" || method == "confidence_weighted" ||
method == "AROW" || method == "arow" ||
method == "NHERD" || method == "normal_herd") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
conf_.reset(new detail::unlearning_classifier_config(method, param));
} else if (method == "NN" || method == "nearest_neighbor") {
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
conf_.reset(new detail::nearest_neighbor_classifier_config(method, param));
} else {
throw JUBATUS_EXCEPTION(
common::unsupported_method("classifier(" + method + ")"));
}
}

} // namespace classifier
} // namespace core
} // namespace jubatus
83 changes: 79 additions & 4 deletions jubatus/core/classifier/classifier_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,96 @@
#define JUBATUS_CORE_CLASSIFIER_CLASSIFIER_CONFIG_HPP_

#include "jubatus/util/data/serialization.h"
#include "jubatus/util/data/optional.h"
#include "jubatus/util/lang/shared_ptr.h"
#include "../unlearner/unlearner_config.hpp"

namespace jubatus {
namespace core {
namespace classifier {

struct classifier_config {
classifier_config()
: regularization_weight(1.0f) {
struct classifier_config_base {
std::string method_;
virtual ~classifier_config_base() {}
classifier_config_base(const std::string& method)
: method_(method) {
}

template<typename Ar>
void serialize(Ar& ar) {
ar & JUBA_NAMED_MEMBER("method", method_);
}
};

namespace detail {
struct classifier_parameter : public classifier_config_base {
classifier_parameter(const std::string& method)
: classifier_config_base(method),
regularization_weight(1.0f) {
}
float regularization_weight;

template<typename Ar>
void serialize(Ar& ar) {
ar & JUBA_NAMED_MEMBER("regularization_weight", regularization_weight);
classifier_config_base::serialize(ar);
ar & JUBA_MEMBER(regularization_weight);
}
};

struct unlearning_classifier_config : public classifier_parameter {
unlearning_classifier_config(const std::string& method,
const common::jsonconfig::config& param)
: classifier_parameter(method) {
// TODO
if (param.type() == jubatus::util::text::json::json::Null) {
throw JUBATUS_EXCEPTION(
common::config_exception() << common::exception::error_message(
"parameter block is not specified in config"));
}
}
util::lang::shared_ptr<unlearner::unlearner_config_base> unlearner_config_;

template<typename Ar>
void serialize(Ar& ar) {
classifier_config_base::serialize(ar);
unlearner_config_->serialize(ar);
}
};

struct nearest_neighbor_classifier_config : public classifier_config_base {
std::string method;
int nearest_neighbor_num;
float local_sensitivity;
util::lang::shared_ptr<unlearner::unlearner_config_base> unlearner_config_;
nearest_neighbor_classifier_config(const std::string& method,
const common::jsonconfig::config& param)
: classifier_config_base(method) {
// TODO
}

template<typename Ar>
void serialize(Ar& ar) {
classifier_config_base::serialize(ar);
ar & JUBA_MEMBER(method)
& JUBA_MEMBER(nearest_neighbor_num)
& JUBA_MEMBER(local_sensitivity);
if (unlearner_config_) {
unlearner_config_->serialize(ar);
}
}
};

} // namespace detail

struct classifier_config {
util::lang::shared_ptr<classifier_config_base> conf_;
classifier_config(const std::string& method,
const common::jsonconfig::config& param);
classifier_config() {
}
template<typename Ar>
void serialize(Ar& ar) {
conf_->serialize(ar);
}
};

Expand Down
Loading

0 comments on commit 4f25024

Please sign in to comment.