
# standard imports
import typing

# bsfs imports
from bsfs import schema as bsc
from bsfs.query import ast
from bsfs.utils import errors

# inner-module imports
from .utils import GenHopName, Query

# exports
__all__: typing.Sequence[str] = ('Fetch',)  # public API of this module


## code ##

class Fetch():
    """Translate `bsfs.query.ast.fetch` structures into Sparql queries.

    The entry point is `__call__`, which walks a fetch expression tree and
    returns a `Query`. Variables for intermediate hops are drawn from `ngen`
    (prefix ``?fch``); user-supplied names are rejected if they would clash
    with that prefix.
    """

    def __init__(self, schema):
        # Schema used to validate the root type and to look up predicate ranges.
        self.schema = schema
        # Source of fresh Sparql variable names for intermediate hops.
        self.ngen = GenHopName(prefix='?fch')

    def __call__(
            self,
            root_type: bsc.Node,
            root: ast.fetch.FetchExpression,
            ) -> Query:
        """Translate the fetch expression *root* into a Sparql `Query`.

        The root entities are bound to the Sparql variable ``?ent``.

        Args:
            root_type: Node type of the root entities; must be part of the schema.
            root: Fetch expression to translate.

        Returns:
            A `Query` whose `select` is the set of ``(variable, name)`` pairs
            collected from *root* and whose `where` is the graph pattern.

        Raises:
            errors.BackendError: If *root_type* is not a `bsc.Node`, if *root*
                contains an unsupported expression, or if a user-supplied name
                clashes with the generated-variable prefix.
            errors.ConsistencyError: If *root_type* is not in the schema.
        """
        # check root_type
        if not isinstance(root_type, bsc.Node):
            raise errors.BackendError(f'expected Node, found {root_type}')
        if root_type not in self.schema.nodes():
            raise errors.ConsistencyError(f'node {root_type} is not in the schema')
        # parse root; the root entities are bound to ?ent
        terms, expr = self._parse_fetch_expression(root_type, root, '?ent')
        # assemble query
        return Query(
            root_type=root_type.uri,
            root_head='?ent',
            select=terms,
            where=expr,
            )

    def _parse_fetch_expression(
            self,
            node_type: bsc.Vertex,
            node: ast.fetch.FetchExpression,
            head: str,
            ):
        """Route *node* to the handler of the respective FetchExpression subclass.

        Every handler returns a pair ``(terms, expr)``: *terms* is a set of
        ``(sparql_variable, name)`` tuples to select, and *expr* is the Sparql
        graph pattern rooted at *head* (possibly the empty string).

        Raises:
            errors.BackendError: If *node* is not a supported fetch expression.
        """
        if isinstance(node, ast.fetch.All):
            return self._all(node_type, node, head)
        if isinstance(node, ast.fetch.Fetch):
            return self._fetch(node_type, node, head)
        if isinstance(node, ast.fetch.Node):
            return self._node(node_type, node, head)
        if isinstance(node, ast.fetch.Value):
            return self._value(node_type, node, head)
        if isinstance(node, ast.fetch.This):
            return self._this(node_type, node, head)
        # invalid node
        raise errors.BackendError(f'expected fetch expression, found {node}')

    def _check_name(self, kind: str, name: str) -> None:
        """Reject *name* if, as a Sparql variable, it would start with the generated-name prefix."""
        # User names become '?<name>'; they must not look like generated hop variables.
        if f'?{name}'.startswith(self.ngen.prefix):
            raise errors.BackendError(f'{kind} name must not start with {self.ngen.prefix}')

    def _all(self, node_type: bsc.Vertex, node: ast.fetch.All, head: str):
        """Translate every child of *node* against the same *head* and merge the results.

        An empty *node* yields ``(set(), '')``.
        """
        # child expressions
        parsed = [self._parse_fetch_expression(node_type, expr, head) for expr in node]
        # union of all selected (variable, name) pairs
        terms = {term for sub_terms, _ in parsed for term in sub_terms}
        # Drop empty patterns (e.g. from `This`) and duplicates. dict.fromkeys keeps
        # first-seen order, so the generated query text is deterministic across runs
        # (set iteration order of strings depends on hash randomization).
        exprs = ' .\n'.join(dict.fromkeys(
            expr for _, expr in parsed if len(expr.strip()) > 0))
        return terms, exprs

    def _fetch(self, node_type: bsc.Vertex, node: ast.fetch.Fetch, head: str): # pylint: disable=unused-argument # (node_type)
        """Follow ``node.predicate`` from *head* to a fresh variable and translate ``node.expr`` there.

        The hop is wrapped in OPTIONAL, so an entity lacking the predicate is not
        excluded from the result.
        """
        # the child expression is interpreted relative to the predicate's range
        rng = self.schema.predicate(node.predicate).range
        nexthead = next(self.ngen)
        terms, expr = self._parse_fetch_expression(rng, node.expr, nexthead)
        return terms, f'OPTIONAL{{ {head} <{node.predicate}> {nexthead} .\n {expr} }}'

    def _node(self, node_type: bsc.Vertex, node: ast.fetch.Node, head: str): # pylint: disable=unused-argument # (node_type)
        """Select the object of ``node.predicate`` at *head* under the name ``node.name``."""
        self._check_name('Node', node.name)
        # compose and return statement
        term = next(self.ngen)
        return {(term, node.name)}, f'OPTIONAL{{ {head} <{node.predicate}> {term} }}'

    def _value(self, node_type: bsc.Vertex, node: ast.fetch.Value, head: str): # pylint: disable=unused-argument # (node_type)
        """Select the value of ``node.predicate`` at *head* under the name ``node.name``."""
        self._check_name('Value', node.name)
        # compose and return statement
        term = next(self.ngen)
        return {(term, node.name)}, f'OPTIONAL{{ {head} <{node.predicate}> {term} }}'

    def _this(self, node_type: bsc.Vertex, node: ast.fetch.This, head: str): # pylint: disable=unused-argument # (node_type)
        """Select *head* itself under the name ``node.name``; contributes no graph pattern."""
        self._check_name('This', node.name)
        # compose and return statement
        return {(head, node.name)}, ''

## EOF ##
