#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2021/5/11 20:18
# @Author  : 程婷婷
# @FileName: FlairClassifyModel.py
# @Software: PyCharm
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from model.base.views.model.BaseModel import BaseModel


class FlairClassifyModel(BaseModel):
    """Text classifier built from flair's ``TextClassifier`` and ``ModelTrainer``.

    Configuration loading is delegated to ``BaseModel``; this class reads
    ``self.model_config['model_path']`` as the save path handed to
    ``ModelTrainer.train``.
    """

    def __init__(self, config_path):
        """Initialise the model from the configuration at ``config_path``.

        All work is delegated to ``BaseModel.__init__``.
        """
        super().__init__(config_path)

    def building_model(
        self,
        corpus,
        document_embeddings,
        label_dict,
        loss_weights,
        *,
        learning_rate=3e-5,
        mini_batch_size=16,
        mini_batch_chunk_size=2,
        max_epochs=3,
        scheduler=OneCycleLR,
        monitor_train=True,
        monitor_test=True,
        checkpoint=True,
        **train_kwargs,
    ):
        """Build a flair ``TextClassifier`` and train it on ``corpus``.

        The keyword-only training options default to the values this method
        previously hard-coded, so existing calls behave exactly as before.

        Args:
            corpus: Training data handed to ``ModelTrainer`` (presumably a
                flair ``Corpus`` -- confirm against callers).
            document_embeddings: Document embeddings used by the classifier
                (presumably flair document/transformer embeddings).
            label_dict: Label dictionary passed to ``TextClassifier`` as
                ``label_dictionary``.
            loss_weights: Passed unchanged to ``TextClassifier`` as
                ``loss_weights`` (per-class weights, or ``None``).
            learning_rate: Learning rate for the ``Adam`` optimizer. The small
                default suits fine-tuning transformer embeddings.
            mini_batch_size: Number of sentences per optimisation step.
            mini_batch_chunk_size: Splits each mini-batch into chunks of this
                size to reduce memory use. Set it, or lower it, if the
                transformer is too much for your machine.
            max_epochs: Training stops after this many epochs.
            scheduler: Learning-rate scheduler class passed to
                ``trainer.train``.
            monitor_train: Forwarded to ``trainer.train``.
            monitor_test: Forwarded to ``trainer.train``.
            checkpoint: Forwarded to ``trainer.train``.
            **train_kwargs: Any further ``ModelTrainer.train`` options,
                forwarded unchanged.

        Returns:
            A ``(classifier, trainer)`` tuple: the ``TextClassifier`` that was
            handed to the trainer, and the ``ModelTrainer`` after ``train``
            has run.
        """
        # Downstream classifier on top of the document embeddings.
        classifier = TextClassifier(
            document_embeddings,
            label_dictionary=label_dict,
            loss_weights=loss_weights,
        )

        # The optimizer class is given to the trainer, which instantiates it.
        trainer = ModelTrainer(classifier, corpus, optimizer=Adam)
        model_save_path = self.model_config['model_path']
        trainer.train(
            str(model_save_path),
            learning_rate=learning_rate,
            mini_batch_size=mini_batch_size,
            scheduler=scheduler,
            mini_batch_chunk_size=mini_batch_chunk_size,
            max_epochs=max_epochs,
            monitor_train=monitor_train,
            monitor_test=monitor_test,
            checkpoint=checkpoint,
            **train_kwargs,
        )
        return classifier, trainer
