Source code for rupo.api

# -*- coding: utf-8 -*-
# Автор: Гусев Илья
# Описание: Набор внешних методов для работы с библиотекой.

import os
from typing import List, Tuple, Dict

from rulm.language_model import LanguageModel

from rupo.files.reader import FileType, Reader
from rupo.files.writer import Writer
from rupo.main.markup import Markup
from rupo.metre.metre_classifier import MetreClassifier, ClassificationResult
from rupo.rhymes.rhymes import Rhymes
from rupo.settings import ZALYZNYAK_DICT, CMU_DICT, DATA_DIR, DICT_DIR
from rupo.stress.predictor import StressPredictor, CombinedStressPredictor
from rupo.main.vocabulary import StressVocabulary, inflate_stress_vocabulary
from rupo.generate.generator import Generator

from allennlp.data.vocabulary import Vocabulary, DEFAULT_OOV_TOKEN
from allennlp.common.util import END_SYMBOL
from rulm.transform import ExcludeTransform
from russ.syllables import get_syllables


[docs]class Engine: def __init__(self, language="ru"): self.language = language # type: str self.vocabulary = None # type: StressVocabulary self.generator = None # type: Generator self.stress_predictors = dict() # type: Dict[str, StressPredictor]
[docs] def load(self, stress_model_path: str, zalyzniak_dict: str, raw_stress_dict_path=None, stress_trie_path=None): self.stress_predictors = dict() if not os.path.isdir(DATA_DIR): os.makedirs(DATA_DIR) if not os.path.isdir(DICT_DIR): os.makedirs(DICT_DIR) self.get_stress_predictor(self.language, stress_model_path, raw_stress_dict_path, stress_trie_path, zalyzniak_dict)
[docs] def get_vocabulary(self, dump_path: str, markup_path: str) -> StressVocabulary: if self.vocabulary is None: self.vocabulary = StressVocabulary() if os.path.isfile(dump_path): self.vocabulary.load(dump_path) elif markup_path is not None: self.vocabulary.parse(markup_path) return self.vocabulary
[docs] def get_generator(self, model_path: str, token_vocab_path: str, stress_vocab_dump_path: str) -> Generator: if self.generator is None: assert os.path.isdir(model_path) and os.path.isdir(token_vocab_path) vocabulary = Vocabulary.from_files(token_vocab_path) stress_vocabulary = StressVocabulary() if not os.path.isfile(stress_vocab_dump_path): stress_vocabulary = inflate_stress_vocabulary(vocabulary, self.get_stress_predictor()) stress_vocabulary.save(stress_vocab_dump_path) else: stress_vocabulary.load(stress_vocab_dump_path) eos_index = vocabulary.get_token_index(END_SYMBOL) unk_index = vocabulary.get_token_index(DEFAULT_OOV_TOKEN) exclude_transform = ExcludeTransform((unk_index, eos_index)) model = LanguageModel.load(model_path, vocabulary_dir=token_vocab_path, transforms=[exclude_transform, ]) self.generator = Generator(model, vocabulary, stress_vocabulary, eos_index) return self.generator
[docs] def get_stress_predictor(self, language="ru", stress_model_path: str=None, raw_stress_dict_path=None, stress_trie_path=None, zalyzniak_dict=ZALYZNYAK_DICT, cmu_dict=CMU_DICT): if self.stress_predictors.get(language) is None: self.stress_predictors[language] = CombinedStressPredictor(language, stress_model_path, raw_stress_dict_path, stress_trie_path, zalyzniak_dict, cmu_dict) return self.stress_predictors[language]
[docs] def get_stresses(self, word: str, language: str="ru") -> List[int]: """ :param word: слово. :param language: язык. :return: ударения слова. """ return self.get_stress_predictor(language).predict(word)
[docs] @staticmethod def get_word_syllables(word: str) -> List[str]: """ :param word: слово. :return: его слоги. """ return [syllable.text for syllable in get_syllables(word)]
[docs] @staticmethod def count_syllables(word: str) -> int: """ :param word: слово. :return: количество слогов в нём. """ return len(get_syllables(word))
[docs] def get_markup(self, text: str, language: str="ru") -> Markup: """ :param text: текст. :param language: язык. :return: его разметка по словарю. """ return Markup.process_text(text, self.get_stress_predictor(language))
[docs] def get_improved_markup(self, text: str, language: str="ru") -> Tuple[Markup, ClassificationResult]: """ :param text: текст. :param language: язык. :return: его разметка по словарю, классификатору метру и ML классификатору. """ markup = Markup.process_text(text, self.get_stress_predictor(language)) return MetreClassifier.improve_markup(markup)
[docs] def classify_metre(self, text: str, language: str="ru") -> str: """ :param text: текст. :param language: язык. :return: его метр. """ return MetreClassifier.classify_metre(Markup.process_text(text, self.get_stress_predictor(language))).metre
[docs] def generate_markups(self, input_path: str, input_type: FileType, output_path: str, output_type: FileType) -> None: """ Генерация разметок по текстам. :param input_path: путь к папке/файлу с текстом. :param input_type: тип файлов с текстов. :param output_path: путь к файлу с итоговыми разметками. :param output_type: тип итогового файла. """ markups = Reader.read_markups(input_path, input_type, False, self.get_stress_predictor()) writer = Writer(output_type, output_path) writer.open() for markup in markups: writer.write_markup(markup) writer.close()
[docs] def is_rhyme(self, word1: str, word2: str) -> bool: """ :param word1: первое слово. :param word2: второе слово. :return: рифмуются ли слова. """ markup_word1 = self.get_markup(word1).lines[0].words[0] markup_word1.set_stresses(self.get_stresses(word1)) markup_word2 = self.get_markup(word2).lines[0].words[0] markup_word2.set_stresses(self.get_stresses(word2)) return Rhymes.is_rhyme(markup_word1, markup_word2)
[docs] def generate_poem(self, model_path: str, token_vocab_path: str=None, stress_vocab_path: str=None, metre_schema: str="-+", rhyme_pattern: str="abab", n_syllables: int=8, sampling_k: int=None, beam_width: int=None, seed: int=1337, temperature: float=1.0, last_text: str="") -> str: """ Сгенерировать стих. Нужно задать либо sampling_k, либо beam_width. :param model_path: путь к модели. :param token_vocab_path: путь к словарю. :param stress_vocab_path: путь к словарю ударений. :param metre_schema: схема метра. :param rhyme_pattern: схема рифм. :param n_syllables: количество слогов в строке. :param sampling_k: top-k при семплинге :param beam_width: ширина лучевого поиска. :param seed: seed :param temperature: температура генерации :param last_text: последняя строчка :return: стих. None, если генерация не была успешной. """ token_vocab_path = token_vocab_path or os.path.join(model_path, "vocabulary") stress_vocab_path = stress_vocab_path or os.path.join(model_path, "stress.pickle") generator = self.get_generator(model_path, token_vocab_path, stress_vocab_path) poem = generator.generate_poem( metre_schema=metre_schema, rhyme_pattern=rhyme_pattern, n_syllables=n_syllables, sampling_k=sampling_k, beam_width=beam_width, temperature=temperature, seed=seed, last_text=last_text ) return poem
[docs] def get_word_rhymes(self, word: str, vocab_dump_path: str, markup_path: str=None) -> List[str]: """ Поиск рифмы для данного слова. :param word: слово. :param vocab_dump_path: путь, куда сохраняется словарь. :param markup_path: путь к разметкам. :return: список рифм. """ markup_word = self.get_markup(word).lines[0].words[0] markup_word.set_stresses(self.get_stresses(word)) rhymes = [] vocabulary = self.get_vocabulary(vocab_dump_path, markup_path) for i in range(vocabulary.size()): if Rhymes.is_rhyme(markup_word, vocabulary.get_word(i)): rhymes.append(vocabulary.get_word(i).text.lower()) return rhymes