#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2021/5/11 20:18
# @Author  : 程婷婷
# @FileName: XgboostClassifyModel.py
# @Software: PyCharm
import scipy.sparse.csr
import scipy.sparse.csc
import pickle
import numpy as np
from xgboost import XGBClassifier
from model.base.views.model.BaseModel import BaseModel


class XgboostClassify(object):
    """Binary classifier wrapping ``xgboost.XGBClassifier``.

    Adds input normalisation, optional class-balancing sample weights,
    pickle persistence tagged with a ``signature``, and thresholded prediction.

    Args:
        label_dict: Mapping from class id (0 / 1) to the "real" label returned by
            ``predict(..., return_real_label=True)``. Must have exactly two entries.
        signature: Tag appended to the saved file name and stored in the pickle.
        lr: Passed to ``XGBClassifier`` as ``learning_rate``.
        reg_alpha: L1 regularisation passed to ``XGBClassifier``.
        reg_lambda: L2 regularisation passed to ``XGBClassifier``.
        objective: XGBoost objective string (default ``'binary:logitraw'``).
        with_sample_weight: If True, weight each sample by
            ``N / (2 * N_class)`` so both classes contribute equally.
        subsample: Row subsample ratio passed to ``XGBClassifier``.
        min_child_weight: Passed to ``XGBClassifier``.
        scale_pos_weight: Passed to ``XGBClassifier``.
        thres: Stored decision threshold. NOTE(review): ``predict`` uses its own
            ``thres`` argument (default 0.5) and does not read this attribute.
    """

    def __init__(self, label_dict, signature, lr=0.1, reg_alpha=0, reg_lambda=1, objective='binary:logitraw',
                 with_sample_weight=True, subsample=1, min_child_weight=1, scale_pos_weight=1, thres=0.5):
        self.lr = lr
        self.label_dict = label_dict
        self.signature = signature
        self.reg_alpha = reg_alpha
        self.reg_lambda = reg_lambda
        self.objective = objective
        self.with_sample_weight = with_sample_weight
        # Previously accepted but dropped; now stored and forwarded to XGBClassifier.
        self.subsample = subsample
        self.min_child_weight = min_child_weight
        self.scale_pos_weight = scale_pos_weight
        self.thres = thres
        self.clf = None  # trained XGBClassifier, set by train() / load()

    def set_signature(self, new_signature):
        """Replace the signature used to tag saved model files."""
        self.signature = new_signature

    @staticmethod
    def _to_feature_matrix(X):
        """Return X as a CSC sparse matrix if it is sparse, else as a NumPy array."""
        if scipy.sparse.issparse(X):
            # Any sparse format (CSR, CSC, COO, ...) is normalised to CSC.
            return X.tocsc()
        # np.asarray avoids copies where possible and, unlike
        # np.array(..., copy=False), does not raise under NumPy >= 2.
        return np.asarray(X)

    @staticmethod
    def _to_label_vector(Y):
        """Return Y (dense or sparse, 1-D or column-shaped) as a flat 1-D ndarray."""
        if scipy.sparse.issparse(Y):
            Y = Y.toarray()
        return np.asarray(Y).ravel()

    def train(self, X, Y, save_to=None):
        """Fit the XGBoost model on (X, Y) and optionally save it.

        Args:
            X: Feature matrix (array-like or scipy sparse matrix).
            Y: Labels, expected to be 0 / 1 (array-like or scipy sparse).
            save_to: If truthy, path prefix passed to ``save``.

        Returns:
            None. If Y contains fewer than two distinct labels, a message is
            printed and training is skipped (``self.clf`` is left unchanged).

        Raises:
            AssertionError: if ``label_dict`` does not have exactly two entries.
        """
        assert len(self.label_dict) == 2, 'It should have exactly two classes.'
        data = self._to_feature_matrix(X)
        label = self._to_label_vector(Y)
        # Fewer than two distinct labels (including empty input) cannot train a binary model.
        if np.unique(label).size < 2:
            print('Only contains one label, training stopped.')
            return

        self.clf = XGBClassifier(reg_alpha=self.reg_alpha, reg_lambda=self.reg_lambda, objective=self.objective,
                                 min_child_weight=self.min_child_weight, scale_pos_weight=self.scale_pos_weight,
                                 subsample=self.subsample, learning_rate=self.lr)
        if self.with_sample_weight:
            n_0 = np.sum(label == 0)
            n_1 = np.sum(label == 1)
            total = n_0 + n_1
            # Inverse-frequency weights: each class gets total weight N / 2.
            # Any label other than 0 gets the class-1 weight, as before.
            sample_weight = np.where(label == 0, total / (2. * n_0), total / (2. * n_1))
            self.clf.fit(data, label, sample_weight=sample_weight)
        else:
            self.clf.fit(data, label)
        if save_to:
            self.save(save_to)

    def save(self, save_to):
        """Pickle ``(clf, label_dict, signature)`` to ``'<save_to>-<signature>.xgb'``."""
        file_name = save_to + ('-%s.xgb' % self.signature)
        with open(file_name, 'wb') as f:
            pickle.dump((self.clf, self.label_dict, self.signature), f)

    @staticmethod
    def load(file_path):
        """Load a model written by ``save`` and return a new XgboostClassify.

        Only ``clf``, ``label_dict`` and ``signature`` are restored. Other
        hyper-parameters take their constructor defaults.

        Security: ``pickle.load`` can execute arbitrary code. Only load files
        from a trusted source.
        """
        with open(file_path, 'rb') as f:
            clf, label_dict, signature = pickle.load(f)
        # Bug fix: previously referenced the undefined name ``Xgboost``.
        xgb = XgboostClassify(label_dict, signature)
        xgb.clf = clf
        return xgb

    def predict(self, X, thres=0.5, return_real_label=False):
        """Predict class ids (or real labels) by thresholding column 1 of ``predict_pro``.

        Args:
            X: Features accepted by ``predict_pro``.
            thres: Samples whose column-1 score is ``>= thres`` are labelled 1.
            return_real_label: If True, map ids through ``label_dict``.

        Returns:
            An int64 ndarray of 0 / 1, or a list of ``label_dict`` values.
        """
        prob = self.predict_pro(X)
        label = (prob[:, 1] >= thres).astype(np.int64)
        if return_real_label:
            return [self.label_dict[l] for l in label.tolist()]
        return label

    def sigmoid(self, x):
        """Element-wise logistic function ``1 / (1 + exp(-x))``."""
        return 1 / (1 + np.exp(-x))

    def predict_pro(self, X):
        """Return ``sigmoid(clf.predict_proba(X))``, or None if untrained or X is empty.

        A 1-D ``X`` is treated as a single sample.

        NOTE(review): applying the sigmoid presumes ``predict_proba`` yields raw
        margins, as with the default ``objective='binary:logitraw'``. With a
        probability-producing objective (e.g. ``'binary:logistic'``) this would
        squash twice. Confirm before changing ``objective``.
        """
        X = self._to_feature_matrix(X)
        if self.clf is None:
            print('模型还没训练，请先训练模型')
            return None
        if X.shape[0] == 0:
            print('数据不能为空')
            return None
        if len(X.shape) == 1:
            # A single feature vector becomes a one-row matrix.
            X = X.reshape(1, -1)
        prob = self.clf.predict_proba(X)
        return self.sigmoid(np.asarray(prob))


class XgboostClassifyModel(BaseModel):
    """Builds and trains an XgboostClassify from the hyper-parameters in ``model_config``.

    ``self.model_config`` is presumably populated by ``BaseModel`` from
    ``config_path`` (not visible here; confirm).
    """

    # Config keys forwarded verbatim as XgboostClassify keyword arguments,
    # in the order they are read.
    _HYPER_PARAM_KEYS = ('lr', 'reg_alpha', 'reg_lambda', 'objective', 'with_sample_weight',
                         'subsample', 'thres', 'min_child_weight', 'scale_pos_weight')

    def __init__(self, config_path):
        super().__init__(config_path)

    def building_model(self, label_dict, signature, X_train, Y_train):
        """Create an XgboostClassify, train it on (X_train, Y_train), and return it.

        The trained model is saved under the prefix ``model_config['model_path']``.
        """
        config = self.model_config
        hyper_params = {key: config[key] for key in self._HYPER_PARAM_KEYS}
        classifier = XgboostClassify(label_dict, signature, **hyper_params)
        model_prefix = config['model_path']
        print('开始训练')
        classifier.train(X_train, Y_train, save_to=model_prefix)
        print('训练结束')
        return classifier