Source code for pyformlang.fst.fst

""" Finite State Transducer """
import json
from typing import Any, Iterable

import networkx as nx
from networkx.drawing.nx_pydot import write_dot

from pyformlang.indexed_grammar import DuplicationRule, ProductionRule, \
    EndRule, ConsumptionRule, IndexedGrammar, Rules


class FST:
    """ Representation of a Finite State Transducer

    The transducer is stored as sets of states, input symbols and output
    symbols, plus a transition table mapping ``(state, input_symbol)`` to
    a list of ``(next_state, output_symbols)`` pairs. The string
    ``"epsilon"`` is used as the marker of an epsilon transition.
    """

    def __init__(self):
        self._states = set()  # Set of states
        self._input_symbols = set()  # Set of input symbols
        self._output_symbols = set()  # Set of output symbols
        # Dict from _states x _input_symbols U {epsilon} into a subset of
        # _states X _output_symbols*
        self._delta = {}
        self._start_states = set()
        self._final_states = set()  # _final_states is final states

    @property
    def states(self):
        """ Get the states of the FST

        Returns
        ----------
        states : set of any
            The states
        """
        return self._states

    @property
    def input_symbols(self):
        """ Get the input symbols of the FST

        Returns
        ----------
        input_symbols : set of any
            The input symbols of the FST
        """
        return self._input_symbols

    @property
    def output_symbols(self):
        """ Get the output symbols of the FST

        Returns
        ----------
        output_symbols : set of any
            The output symbols of the FST
        """
        return self._output_symbols

    @property
    def start_states(self):
        """ Get the start states of the FST

        Returns
        ----------
        start_states : set of any
            The start states of the FST
        """
        return self._start_states

    @property
    def final_states(self):
        """ Get the final states of the FST

        Returns
        ----------
        final_states : set of any
            The final states of the FST
        """
        return self._final_states

    @property
    def transitions(self):
        """Gives the transitions as a dictionary

        Returns
        ----------
        transitions : dict
            Maps ``(s_from, input_symbol)`` to a list of
            ``(s_to, output_symbols)`` pairs. This is the internal table,
            not a copy.
        """
        return self._delta

    def get_number_transitions(self) -> int:
        """ Get the number of transitions in the FST

        Returns
        ----------
        n_transitions : int
            The number of transitions
        """
        return sum(len(x) for x in self._delta.values())

    def add_transition(self,
                       s_from: Any,
                       input_symbol: Any,
                       s_to: Any,
                       output_symbols: Iterable[Any]):
        """ Add a transition to the FST

        Both states are registered as states of the FST. The input symbol
        and every output symbol are registered too, unless they are equal
        to ``"epsilon"``.

        Parameters
        -----------
        s_from : any
            The source state
        input_symbol : any
            The symbol to read
        s_to : any
            The destination state
        output_symbols : iterable of Any
            The symbols to output
        """
        self._states.add(s_from)
        self._states.add(s_to)
        if input_symbol != "epsilon":
            self._input_symbols.add(input_symbol)
        for output_symbol in output_symbols:
            if output_symbol != "epsilon":
                self._output_symbols.add(output_symbol)
        head = (s_from, input_symbol)
        self._delta.setdefault(head, []).append((s_to, output_symbols))

    def add_transitions(self, transitions_list):
        """ Adds several transitions to the FST

        Parameters
        ----------
        transitions_list : list of tuples
            The tuples have the form (s_from, in_symbol, s_to, out_symbols)
        """
        for s_from, input_symbol, s_to, output_symbols in transitions_list:
            self.add_transition(
                s_from, input_symbol, s_to, output_symbols
            )

    def add_start_state(self, start_state: Any):
        """ Add a start state

        Parameters
        ----------
        start_state : any
            The start state
        """
        self._states.add(start_state)
        self._start_states.add(start_state)

    def add_final_state(self, final_state: Any):
        """ Add a final state

        Parameters
        ----------
        final_state : any
            The final state to add
        """
        self._final_states.add(final_state)
        self._states.add(final_state)

    def translate(self,
                  input_word: Iterable[Any],
                  max_length: int = -1) -> Iterable[Any]:
        """ Translate a string into another using the FST

        This is a generator: it yields one output word (a list of output
        symbols) for every distinct accepting run found.

        Parameters
        ----------
        input_word : iterable of any
            The word to translate
        max_length : int, optional
            The maximum size of the output word, to prevent infinite \
            generation due to epsilon transitions. -1 means no bound.

        Returns
        ----------
        output_word : iterable of any
            The translation of the input word
        """
        # Materialise the input once so that any iterable (including a
        # one-shot iterator) can be indexed by position.
        word = tuple(input_word)
        n_symbols = len(word)
        # Stack of (position in the input, generated so far, current_state)
        to_process = [(0, [], start_state)
                      for start_state in self._start_states]
        # Configurations already explored; a set keeps the check cheap.
        seen = set()
        while to_process:
            position, generated, current_state = to_process.pop()
            configuration = (current_state, position, tuple(generated))
            if configuration in seen:
                continue
            seen.add(configuration)
            if position == n_symbols and \
                    current_state in self._final_states:
                yield generated
            # We try to read an input
            if position < n_symbols:
                for next_state, output_string in self._delta.get(
                        (current_state, word[position]), []):
                    # list(...) so that tuple/str outputs concatenate too
                    to_process.append(
                        (position + 1,
                         generated + list(output_string),
                         next_state))
            # We try to read an epsilon transition
            if max_length == -1 or len(generated) < max_length:
                for next_state, output_string in self._delta.get(
                        (current_state, "epsilon"), []):
                    to_process.append(
                        (position,
                         generated + list(output_string),
                         next_state))

    def intersection(self, indexed_grammar):
        """ Compute the intersection with an other object

        Equivalent to:
          >> fst & indexed_grammar

        Parameters
        ----------
        indexed_grammar : IndexedGrammar
            The indexed grammar to intersect with; its ``rules``
            attribute is read.

        Returns
        ----------
        grammar : IndexedGrammar
            The indexed grammar built from the generated rules, after
            ``remove_useless_rules()``.
        """
        rules = indexed_grammar.rules
        new_rules = [EndRule("T", "epsilon")]
        self._extract_consumption_rules_intersection(rules, new_rules)
        self._extract_indexed_grammar_rules_intersection(rules, new_rules)
        self._extract_terminals_intersection(rules, new_rules)
        self._extract_epsilon_transitions_intersection(new_rules)
        self._extract_fst_delta_intersection(new_rules)
        self._extract_fst_epsilon_intersection(new_rules)
        self._extract_fst_duplication_rules_intersection(new_rules)
        rules = Rules(new_rules, rules.optim)
        return IndexedGrammar(rules).remove_useless_rules()

    def _extract_fst_duplication_rules_intersection(self, new_rules):
        """ Append one duplication rule from "S" per (final, start) pair """
        for state_p in self._final_states:
            for start_state in self._start_states:
                new_rules.append(DuplicationRule(
                    "S",
                    str((start_state, "S", state_p)),
                    "T"))

    def _extract_fst_epsilon_intersection(self, new_rules):
        """ Append an end rule (p, epsilon, p) -> epsilon per state p """
        for state_p in self._states:
            new_rules.append(EndRule(
                str((state_p, "epsilon", state_p)),
                "epsilon"))

    def _extract_fst_delta_intersection(self, new_rules):
        """ Append one end rule per transition stored in the FST """
        for key, pair in self._delta.items():
            state_p = key[0]
            terminal = key[1]
            for transition in pair:
                state_q = transition[0]
                symbol = transition[1]
                new_rules.append(EndRule(str((state_p, terminal, state_q)),
                                         symbol))

    def _extract_epsilon_transitions_intersection(self, new_rules):
        """ Append the epsilon duplication rules for all state triples """
        for state_p in self._states:
            for state_q in self._states:
                for state_r in self._states:
                    new_rules.append(DuplicationRule(
                        str((state_p, "epsilon", state_q)),
                        str((state_p, "epsilon", state_r)),
                        str((state_r, "epsilon", state_q))))

    def _extract_indexed_grammar_rules_intersection(self, rules, new_rules):
        """ Append, for each grammar rule, its copies indexed by states

        Duplication rules are expanded over all state triples, production
        and end rules over all state pairs.
        """
        for rule in rules.rules:
            if rule.is_duplication():
                for state_p in self._states:
                    for state_q in self._states:
                        for state_r in self._states:
                            new_rules.append(DuplicationRule(
                                str((state_p, rule.left_term, state_q)),
                                str((state_p, rule.right_terms[0], state_r)),
                                str((state_r, rule.right_terms[1], state_q))))
            elif rule.is_production():
                for state_p in self._states:
                    for state_q in self._states:
                        new_rules.append(ProductionRule(
                            str((state_p, rule.left_term, state_q)),
                            str((state_p, rule.right_term, state_q)),
                            str(rule.production)))
            elif rule.is_end_rule():
                for state_p in self._states:
                    for state_q in self._states:
                        new_rules.append(DuplicationRule(
                            str((state_p, rule.left_term, state_q)),
                            str((state_p, rule.right_term, state_q)),
                            "T"))

    def _extract_terminals_intersection(self, rules, new_rules):
        """ Append, per terminal and state triple, two duplication rules

        One places the epsilon part before the terminal, the other after.
        """
        terminals = rules.terminals
        for terminal in terminals:
            for state_p in self._states:
                for state_q in self._states:
                    for state_r in self._states:
                        new_rules.append(DuplicationRule(
                            str((state_p, terminal, state_q)),
                            str((state_p, "epsilon", state_r)),
                            str((state_r, terminal, state_q))))
                        new_rules.append(DuplicationRule(
                            str((state_p, terminal, state_q)),
                            str((state_p, terminal, state_r)),
                            str((state_r, "epsilon", state_q))))

    def _extract_consumption_rules_intersection(self, rules, new_rules):
        """ Append each consumption rule once per pair of states """
        consumptions = rules.consumption_rules
        for consumption_rule in consumptions:
            for consumption in consumptions[consumption_rule]:
                for state_r in self._states:
                    for state_s in self._states:
                        new_rules.append(ConsumptionRule(
                            consumption.f_parameter,
                            str((state_r, consumption.left_term, state_s)),
                            str((state_r, consumption.right, state_s))))

    def __and__(self, other):
        """ Operator form of :meth:`intersection` """
        return self.intersection(other)

    def union(self, other_fst):
        """ Makes the union of two fst

        Parameters
        ----------
        other_fst : :class:`~pyformlang.fst.FST`
            The other FST

        Returns
        -------
        union_fst : :class:`~pyformlang.fst.FST`
            A new FST which is the union of the two given FST
        """
        state_renaming = self._get_state_renaming(other_fst)
        union_fst = FST()
        # pylint: disable=protected-access
        self._copy_into(union_fst, state_renaming, 0)
        other_fst._copy_into(union_fst, state_renaming, 1)
        return union_fst

    def __or__(self, other_fst):
        """ Makes the union of two fst

        Parameters
        ----------
        other_fst : :class:`~pyformlang.fst.FST`
            The other FST

        Returns
        -------
        union_fst : :class:`~pyformlang.fst.FST`
            A new FST which is the union of the two given FST
        """
        return self.union(other_fst)

    def _copy_into(self, union_fst, state_renaming, idx):
        """ Copy start/final states and transitions into ``union_fst`` """
        self._add_extremity_states_to(union_fst, state_renaming, idx)
        self._add_transitions_to(union_fst, state_renaming, idx)

    def _add_transitions_to(self, union_fst, state_renaming, idx):
        """ Copy every transition into ``union_fst`` with renamed states """
        for head, transition in self.transitions.items():
            s_from, input_symbol = head
            for s_to, output_symbols in transition:
                union_fst.add_transition(
                    state_renaming.get_name(s_from, idx),
                    input_symbol,
                    state_renaming.get_name(s_to, idx),
                    output_symbols)

    def _add_extremity_states_to(self, union_fst, state_renaming, idx):
        """ Copy the start and the final states into ``union_fst`` """
        self._add_start_states_to(union_fst, state_renaming, idx)
        self._add_final_states_to(union_fst, state_renaming, idx)

    def _add_final_states_to(self, union_fst, state_renaming, idx):
        """ Copy the (renamed) final states into ``union_fst`` """
        for state in self.final_states:
            union_fst.add_final_state(state_renaming.get_name(state, idx))

    def _add_start_states_to(self, union_fst, state_renaming, idx):
        """ Copy the (renamed) start states into ``union_fst`` """
        for state in self.start_states:
            union_fst.add_start_state(state_renaming.get_name(state, idx))

    def concatenate(self, other_fst):
        """ Makes the concatenation of two fst

        The start states come from this FST and the final states from the
        other one. Every final state of this FST is linked to every start
        state of the other one by an epsilon transition with no output.

        Parameters
        ----------
        other_fst : :class:`~pyformlang.fst.FST`
            The other FST

        Returns
        -------
        fst_concatenate : :class:`~pyformlang.fst.FST`
            A new FST which is the concatenation of the two given FST
        """
        state_renaming = self._get_state_renaming(other_fst)
        fst_concatenate = FST()
        self._add_start_states_to(fst_concatenate, state_renaming, 0)
        # pylint: disable=protected-access
        other_fst._add_final_states_to(fst_concatenate, state_renaming, 1)
        self._add_transitions_to(fst_concatenate, state_renaming, 0)
        other_fst._add_transitions_to(fst_concatenate, state_renaming, 1)
        for final_state in self.final_states:
            for start_state in other_fst.start_states:
                fst_concatenate.add_transition(
                    state_renaming.get_name(final_state, 0),
                    "epsilon",
                    state_renaming.get_name(start_state, 1),
                    []
                )
        return fst_concatenate

    def __add__(self, other):
        """ Makes the concatenation of two fst

        Parameters
        ----------
        other : :class:`~pyformlang.fst.FST`
            The other FST

        Returns
        -------
        fst_concatenate : :class:`~pyformlang.fst.FST`
            A new FST which is the concatenation of the two given FST
        """
        return self.concatenate(other)

    def _get_state_renaming(self, other_fst):
        """ Build the renaming of the states of this FST and ``other_fst``

        This FST is registered under index 0 and ``other_fst`` under
        index 1.
        """
        state_renaming = FSTStateRemaining()
        state_renaming.add_states(self.states, 0)
        state_renaming.add_states(other_fst.states, 1)
        return state_renaming

    def kleene_star(self):
        """ Computes the kleene star of the FST

        The result is a copy of this FST with additional epsilon
        transitions (with no output) in both directions between every
        final state and every start state.

        Returns
        -------
        fst_star : :class:`~pyformlang.fst.FST`
            A FST representing the kleene star of the FST
        """
        fst_star = FST()
        state_renaming = FSTStateRemaining()
        state_renaming.add_states(self.states, 0)
        self._add_extremity_states_to(fst_star, state_renaming, 0)
        self._add_transitions_to(fst_star, state_renaming, 0)
        # final -> start
        for final_state in self.final_states:
            for start_state in self.start_states:
                fst_star.add_transition(
                    state_renaming.get_name(final_state, 0),
                    "epsilon",
                    state_renaming.get_name(start_state, 0),
                    []
                )
        # start -> final
        for start_state in self.start_states:
            for final_state in self.final_states:
                fst_star.add_transition(
                    state_renaming.get_name(start_state, 0),
                    "epsilon",
                    state_renaming.get_name(final_state, 0),
                    []
                )
        return fst_star

    def to_networkx(self) -> "nx.MultiDiGraph":
        """
        Transform the current fst into a networkx graph

        Every state becomes a node carrying ``is_start``, ``is_final``,
        ``peripheries`` and ``label`` attributes. Every start state gets
        an extra node named ``"starting_" + str(state)`` with an edge to
        it. Every transition becomes an edge whose ``label`` is the JSON
        of the input symbol, then ``" -> "``, then the JSON of the output
        symbols.

        Returns
        -------
        graph : networkx.MultiDiGraph
            A networkx MultiDiGraph representing the fst
        """
        graph = nx.MultiDiGraph()
        for state in self._states:
            graph.add_node(state,
                           is_start=state in self.start_states,
                           is_final=state in self.final_states,
                           peripheries=2 if state in self.final_states else 1,
                           label=state)
            if state in self.start_states:
                graph.add_node("starting_" + str(state),
                               label="",
                               shape=None,
                               height=.0,
                               width=.0)
                graph.add_edge("starting_" + str(state), state)
        for s_from, input_symbol in self._delta:
            for s_to, output_symbols in self._delta[(s_from, input_symbol)]:
                graph.add_edge(
                    s_from,
                    s_to,
                    label=(json.dumps(input_symbol) + " -> " +
                           json.dumps(output_symbols)))
        return graph

    @staticmethod
    def _parse_edge_label(label):
        """ Split an edge label written by to_networkx into its two parts

        The first JSON value is decoded on its own rather than splitting
        on ``" -> "``, so that symbols which themselves contain ``" -> "``
        are read back correctly.

        Parameters
        ----------
        label : str
            A label of the form ``<json input> -> <json outputs>``

        Returns
        ----------
        in_symbol, out_symbols : tuple
            The decoded input symbol and output symbols

        Raises
        ----------
        ValueError
            If the label does not have the expected form
        """
        text = label.lstrip()
        in_symbol, end = json.JSONDecoder().raw_decode(text)
        rest = text[end:].lstrip()
        if not rest.startswith("->"):
            raise ValueError("Malformed transition label: " + repr(label))
        out_symbols = json.loads(rest[len("->"):])
        return in_symbol, out_symbols

    @classmethod
    def from_networkx(cls, graph):
        """
        Import a networkx graph into an finite state transducer. \
        The imported graph requires to have the good format, i.e. to come \
        from the function to_networkx

        Edges without a ``label`` attribute are ignored. Nodes are marked
        as start or final states according to their ``is_start`` and
        ``is_final`` attributes.

        Parameters
        ----------
        graph :
            The graph representation of the FST

        Returns
        -------
        enfa :
            A FST read from the graph

        TODO
        -------
        * Explain the format
        """
        # cls() rather than FST() so that subclasses get their own type
        fst = cls()
        for s_from in graph:
            for s_to in graph[s_from]:
                for transition in graph[s_from][s_to].values():
                    if "label" in transition:
                        in_symbol, out_symbols = cls._parse_edge_label(
                            transition["label"])
                        fst.add_transition(s_from,
                                           in_symbol,
                                           s_to,
                                           out_symbols)
        for node in graph.nodes:
            if graph.nodes[node].get("is_start", False):
                fst.add_start_state(node)
            if graph.nodes[node].get("is_final", False):
                fst.add_final_state(node)
        return fst

    def write_as_dot(self, filename):
        """
        Write the FST in dot format into a file

        Parameters
        ----------
        filename : str
            The filename where to write the dot file
        """
        write_dot(self.to_networkx(), filename)
class FSTStateRemaining:
    """Class for renaming the states in FST

    It keeps, for each ``(state, idx)`` pair, a name which is unique
    among all the names handed out so far, so that the states of several
    FSTs can be put in the same FST without clashing.
    """

    def __init__(self):
        # (original state, FST index) -> name to use
        self._state_renaming = {}
        # Every name already handed out
        self._seen_states = set()

    def add_state(self, state, idx):
        """
        Add a state

        If the state was not seen before, it keeps its name. Otherwise,
        it is renamed into the string form of the state followed by the
        smallest counter (starting at 0) which gives an unused name.

        Parameters
        ----------
        state : any
            The state to add
        idx : int
            The index of the FST
        """
        if state in self._seen_states:
            counter = 0
            # str(state) so that non-string states (e.g. integers) can be
            # renamed as well; for string states this changes nothing.
            base_name = str(state)
            new_state = base_name + str(counter)
            while new_state in self._seen_states:
                counter += 1
                new_state = base_name + str(counter)
            self._state_renaming[(state, idx)] = new_state
            self._seen_states.add(new_state)
        else:
            self._state_renaming[(state, idx)] = state
            self._seen_states.add(state)

    def add_states(self, states, idx):
        """
        Add states

        Parameters
        ----------
        states : iterable of any
            The states to add
        idx : int
            The index of the FST
        """
        for state in states:
            self.add_state(state, idx)

    def get_name(self, state, idx):
        """
        Get the renaming.

        Parameters
        ----------
        state : any
            The state to rename
        idx : int
            The index of the FST

        Returns
        -------
        new_name : any
            The new name of the state

        Raises
        -------
        KeyError
            If the state was never added for this index
        """
        return self._state_renaming[(state, idx)]