Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(WIP)import export #131

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions jubatus/core/anomaly/anomaly_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "../common/exception.hpp"
#include "../common/jsonconfig.hpp"
#include "../nearest_neighbor/nearest_neighbor_factory.hpp"
#include "../unlearner/unlearner_config.hpp"
#include "../unlearner/unlearner_factory.hpp"
#include "../storage/column_table.hpp"
#include "../recommender/recommender_factory.hpp"
Expand Down Expand Up @@ -100,10 +101,11 @@ shared_ptr<anomaly_base> anomaly_factory::create_anomaly(
<< common::exception::error_message(
"unlearner is set but unlearner_parameter is not found"));
}
shared_ptr<unlearner::unlearner_config_base> unl_conf(
unlearner::create_unlearner_config(*conf.unlearner,
*conf.unlearner_parameter));
jubatus::util::lang::shared_ptr<unlearner::unlearner_base> unlearner(
unlearner::create_unlearner(
*conf.unlearner,
*conf.unlearner_parameter));
unlearner::create_unlearner(unl_conf));
return shared_ptr<anomaly_base>(
new light_lof(conf, id, nearest_neighbor_engine, unlearner));
}
Expand Down
18 changes: 7 additions & 11 deletions jubatus/core/classifier/arow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,20 @@
#include "jubatus/util/concurrent/lock.h"
#include "classifier_util.hpp"
#include "../common/exception.hpp"
#include "../storage/storage_base.hpp"

using std::string;
using jubatus::core::storage_ptr;

namespace jubatus {
namespace core {
namespace classifier {

arow::arow(storage_ptr storage)
: linear_classifier(storage) {
}

arow::arow(
const classifier_config& config,
storage_ptr storage)
: linear_classifier(storage),
config_(config) {
arow::arow(float regularization_weight)
: linear_classifier(),
regularization_weight_(regularization_weight) {

if (!(0.f < config.regularization_weight)) {
if (!(0.f < regularization_weight_)) {
throw JUBATUS_EXCEPTION(
common::invalid_parameter("0.0 < regularization_weight"));
}
Expand All @@ -58,7 +54,7 @@ void arow::train(const common::sfv_t& sfv, const string& label) {
return;
}

float beta = 1.f / (variance + 1.f / config_.regularization_weight);
float beta = 1.f / (variance + 1.f / regularization_weight_);
float alpha = (1.f - margin) * beta; // max(0, 1 - margin) = 1 - margin
update(sfv, alpha, beta, label, incorrect_label);
}
Expand Down
5 changes: 2 additions & 3 deletions jubatus/core/classifier/arow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ namespace classifier {

class arow : public linear_classifier {
public:
explicit arow(storage_ptr storage);
arow(const classifier_config& config, storage_ptr storage);
explicit arow(float regularization_weight);
void train(const common::sfv_t& fv, const std::string& label);
std::string name() const;
private:
Expand All @@ -38,7 +37,7 @@ class arow : public linear_classifier {
float beta,
const std::string& pos_label,
const std::string& neg_label);
classifier_config config_;
float regularization_weight_;
};

} // namespace classifier
Expand Down
5 changes: 5 additions & 0 deletions jubatus/core/classifier/classifier_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class classifier_base {
virtual void set_label_unlearner(
jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
label_unlearner) = 0;
virtual jubatus::util::lang::shared_ptr<unlearner::unlearner_base>
get_label_unlearner() const = 0;

virtual bool delete_label(const std::string& label) = 0;
virtual std::vector<std::string> get_labels() const = 0;
Expand All @@ -61,6 +63,9 @@ class classifier_base {
virtual void unpack(msgpack::object o) = 0;
virtual void clear() = 0;

void import_model(msgpack::object src);
void export_model(framework::packer& dst) const;

virtual framework::mixable* get_mixable() = 0;
};

Expand Down
70 changes: 70 additions & 0 deletions jubatus/core/classifier/classifier_config.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Jubatus: Online machine learning framework for distributed environment
// Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation.
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License version 2.1 as published by the Free Software Foundation.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

#include "classifier_config.hpp"

#include <string>
#include "jubatus/util/data/serialization.h"
#include "jubatus/util/data/optional.h"
#include "jubatus/util/lang/shared_ptr.h"
#include "../common/jsonconfig.hpp"
#include "../unlearner/unlearner_config.hpp"
#include "../storage/column_table.hpp"
#include "../nearest_neighbor/nearest_neighbor_base.hpp"

using jubatus::util::lang::shared_ptr;
using jubatus::core::common::jsonconfig::config_cast_check;

namespace jubatus {
namespace core {
namespace classifier {
// Builds the method-specific configuration object held in conf_.
//
// `method` is the classifier method name and `param` is the parameter block
// of the configuration:
//   - perceptron / PA / passive_aggressive: the block must be JSON null;
//   - PA1 / PA2 / CW / AROW / NHERD (and their long names): the block is
//     mandatory and an unlearning_classifier_config is created;
//   - NN / nearest_neighbor: the block is mandatory and a
//     nearest_neighbor_classifier_config is created.
//
// Throws common::config_exception when the presence of the parameter block
// does not match the method, and common::unsupported_method when the method
// name is not one of the names listed above.
classifier_config::classifier_config(const std::string& method,
                                     const common::jsonconfig::config& param) {
  const bool parameterless =
      method == "perceptron" ||
      method == "PA" || method == "passive_aggressive";
  const bool parameterized =
      method == "PA1" || method == "passive_aggressive_1" ||
      method == "PA2" || method == "passive_aggressive_2" ||
      method == "CW" || method == "confidence_weighted" ||
      method == "AROW" || method == "arow" ||
      method == "NHERD" || method == "normal_herd";
  const bool neighbor_based =
      method == "NN" || method == "nearest_neighbor";

  // Unknown names are rejected before the parameter block is inspected.
  if (!parameterless && !parameterized && !neighbor_based) {
    throw JUBATUS_EXCEPTION(
        common::unsupported_method("classifier(" + method + ")"));
  }

  const bool has_parameter =
      param.type() != jubatus::util::text::json::json::Null;

  if (parameterless) {
    if (has_parameter) {
      throw JUBATUS_EXCEPTION(
          common::config_exception() << common::exception::error_message(
              "parameter block is specified in config"));
    }
    conf_.reset(new detail::unlearning_classifier_config(method, param));
    return;
  }

  // Every remaining method needs its parameter block.
  if (!has_parameter) {
    throw JUBATUS_EXCEPTION(
        common::config_exception() << common::exception::error_message(
            "parameter block is not specified in config"));
  }

  if (neighbor_based) {
    conf_.reset(new detail::nearest_neighbor_classifier_config(method, param));
  } else {
    conf_.reset(new detail::unlearning_classifier_config(method, param));
  }
}

} // namespace classifier
} // namespace core
} // namespace jubatus
86 changes: 81 additions & 5 deletions jubatus/core/classifier/classifier_config.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Jubatus: Online machine learning framework for distributed environment
// Copyright (C) 2012 Preferred Networks and Nippon Telegraph and Telephone Corporation.
// Copyright (C) 2015 Preferred Networks and Nippon Telegraph and Telephone Corporation.
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
Expand All @@ -17,22 +17,98 @@
#ifndef JUBATUS_CORE_CLASSIFIER_CLASSIFIER_CONFIG_HPP_
#define JUBATUS_CORE_CLASSIFIER_CLASSIFIER_CONFIG_HPP_

#include <string>
#include "jubatus/util/data/serialization.h"
#include "jubatus/util/data/optional.h"
#include "jubatus/util/lang/shared_ptr.h"
#include "../unlearner/unlearner_config.hpp"

namespace jubatus {
namespace core {
namespace classifier {

struct classifier_config {
classifier_config()
: regularization_weight(1.0f) {
// Root of the classifier configuration hierarchy; it stores only the name of
// the selected method.
struct classifier_config_base {
  explicit classifier_config_base(const std::string& method)
      : method_(method) {
  }

  // Virtual so that derived configurations can be destroyed through a
  // pointer to this base.
  virtual ~classifier_config_base() {}

  // Passes the method name to the archive under the key "method".
  template<typename Ar>
  void serialize(Ar& ar) {
    ar & JUBA_NAMED_MEMBER("method", method_);
  }

  std::string method_;
};

namespace detail {
// Configuration that adds a regularization weight (default 1.0f) to the
// method name stored by classifier_config_base.
struct classifier_parameter : public classifier_config_base {
  explicit classifier_parameter(const std::string& method)
      : classifier_config_base(method),
        regularization_weight(1.0f) {
  }
  float regularization_weight;

  // Serializes the base fields first, then regularization_weight exactly
  // once (it used to be handed to the archive twice, once before the base
  // fields and once after).
  template<typename Ar>
  void serialize(Ar& ar) {
    classifier_config_base::serialize(ar);
    ar & JUBA_MEMBER(regularization_weight);
  }
};

// Configuration for classifiers that may carry an unlearner configuration in
// addition to the fields of classifier_parameter.
struct unlearning_classifier_config : public classifier_parameter {
  // `method` is forwarded to classifier_parameter. `param` is not read yet.
  //
  // No null check is done on `param` here: classifier_config's constructor
  // builds this object for perceptron/PA only when the parameter block IS
  // null, and already rejects a null block for the methods that require one.
  // Throwing on null here made the parameterless methods impossible to
  // construct.
  unlearning_classifier_config(const std::string& method,
                               const common::jsonconfig::config& param)
      : classifier_parameter(method) {
    // TODO // NOLINT
  }
  // Null until an unlearner configuration is attached; nothing in this
  // struct assigns it.
  util::lang::shared_ptr<unlearner::unlearner_config_base> unlearner_config_;

  // Serializes the method name and, when present, the unlearner
  // configuration.
  // NOTE(review): this calls classifier_config_base::serialize rather than
  // classifier_parameter::serialize, so regularization_weight is not
  // serialized here -- confirm that this is intentional.
  template<typename Ar>
  void serialize(Ar& ar) {
    classifier_config_base::serialize(ar);
    // Guard against the unset pointer instead of dereferencing null.
    if (unlearner_config_) {
      unlearner_config_->serialize(ar);
    }
  }
};

// Configuration for the nearest-neighbor classifier methods.
struct nearest_neighbor_classifier_config : public classifier_config_base {
  // NOTE(review): duplicates classifier_config_base::method_ and is
  // presumably serialized under the same "method" key -- consider removing
  // one of the two.
  std::string method;
  int nearest_neighbor_num;
  float local_sensitivity;
  // Null unless an unlearner configuration is attached.
  util::lang::shared_ptr<unlearner::unlearner_config_base> unlearner_config_;

  // `method` is stored both in the base class and in the `method` member so
  // the two never disagree. `param` is not parsed yet, so the numeric fields
  // get deterministic placeholder values (0 and 0.f) instead of being left
  // uninitialized; serialize() reads them, and reading an indeterminate
  // value is undefined behavior.
  nearest_neighbor_classifier_config(const std::string& method,
                                     const common::jsonconfig::config& param)
      : classifier_config_base(method),
        method(method),
        nearest_neighbor_num(0),
        local_sensitivity(0.f),
        unlearner_config_() {
    // TODO // NOLINT
  }

  // Serializes the base fields, the nearest-neighbor fields and, when
  // present, the unlearner configuration.
  template<typename Ar>
  void serialize(Ar& ar) {
    classifier_config_base::serialize(ar);
    ar & JUBA_MEMBER(method)
        & JUBA_MEMBER(nearest_neighbor_num)
        & JUBA_MEMBER(local_sensitivity);
    if (unlearner_config_) {
      unlearner_config_->serialize(ar);
    }
  }
};

} // namespace detail

// Owning handle to a method-specific classifier configuration.
struct classifier_config {
  // Concrete configuration; null when the object was default-constructed.
  util::lang::shared_ptr<classifier_config_base> conf_;

  // Creates the configuration that matches `method` from the parameter block
  // `param` (defined out of line).
  classifier_config(const std::string& method,
                    const common::jsonconfig::config& param);

  // Leaves conf_ null.
  classifier_config() {
  }

  // Serializes the held configuration. A default-constructed object holds
  // nothing, so it is skipped rather than dereferencing a null pointer.
  // NOTE(review): the call goes through a classifier_config_base pointer and
  // serialize() is a non-virtual template, so only the base-class fields
  // appear to be serialized here -- confirm whether the derived fields are
  // expected.
  template<typename Ar>
  void serialize(Ar& ar) {
    if (conf_) {
      conf_->serialize(ar);
    }
  }
};

Expand Down
Loading