
# imports
import operator
import typing

# external imports
import rdflib

# bsfs imports
from bsfs import schema as bsc
from bsfs.namespace import ns
from bsfs.query import ast
from bsfs.utils import URI, errors

# inner-module imports
from .distance import DISTANCE_FU
from .utils import GenHopName, Query

# exports
__all__: typing.Sequence[str] = ('Filter',)  # public API of this module


## code ##

class Filter():
    """Translate `bsfs.query.ast.filter` structures into Sparql queries.

    The translator walks a filter AST recursively. It validates every step
    against the schema and emits a SPARQL graph-pattern fragment. Each hop
    into a new node or literal gets a fresh variable name drawn from `ngen`.

    """

    # Current schema to validate against.
    schema: bsc.Schema

    # Generator that produces unique symbol names.
    ngen: GenHopName

    def __init__(self, graph, schema):
        # Store that is scanned for candidate values when translating Distance
        # filters (only its `objects()` method is used in this class).
        self.graph = graph
        self.schema = schema
        self.ngen = GenHopName(prefix='?flt')

    def __call__(
            self,
            root_type: bsc.Node,
            root: typing.Optional[ast.filter.FilterExpression] = None,
            ) -> Query:
        """Translate the filter *root* over nodes of type *root_type* into a Query.

        The resulting query uses ``?ent`` as its head variable. If *root* is
        None, the query's where clause is empty.

        Raises:
            errors.BackendError: If *root_type* is not a `bsc.Node`, or if the
                filter contains an unsupported or ill-typed expression.
            errors.ConsistencyError: If *root_type*, or a type or predicate
                reached by the filter, is inconsistent with the schema.
        """
        # check root_type
        if not isinstance(root_type, bsc.Node):
            raise errors.BackendError(f'expected Node, found {root_type}')
        if root_type not in self.schema.nodes():
            raise errors.ConsistencyError(f'node {root_type} is not in the schema')
        # parse root
        if root is None:
            cond = ''
        else:
            cond = self._parse_filter_expression(root_type, root, '?ent')
        # assemble query
        return Query(
            root_type=root_type.uri,
            root_head='?ent',
            where=cond,
            )

    @staticmethod
    def _escape(value: typing.Any) -> str:
        """Return ``str(value)`` escaped for use inside a double-quoted SPARQL string literal.

        Backslash, double quote, LF and CR may not appear verbatim in a ``"..."``
        literal, so they are replaced by their escape sequences. Without this,
        a value containing such characters would produce a broken query, or
        inject query text. Backslashes are escaped first so that the escapes
        added afterwards are not doubled.
        """
        return (str(value)
            .replace('\\', '\\\\')
            .replace('"', '\\"')
            .replace('\n', '\\n')
            .replace('\r', '\\r'))

    def _parse_filter_expression(
            self,
            type_: bsc.Vertex,
            node: ast.filter.FilterExpression,
            head: str,
            ) -> str:
        """Route *node* to the handler of the respective FilterExpression subclass.

        *type_* is the type of the value bound to the variable *head*. The
        handler's SPARQL fragment is returned unchanged. The checks run in a
        fixed order, so the first matching isinstance test wins.

        Raises:
            errors.BackendError: If *node* is not a supported filter expression.
        """
        if isinstance(node, ast.filter.Is):
            return self._is(type_, node, head)
        if isinstance(node, ast.filter.Not):
            return self._not(type_, node, head)
        if isinstance(node, ast.filter.Has):
            return self._has(type_, node, head)
        if isinstance(node, ast.filter.Distance):
            return self._distance(type_, node, head)
        if isinstance(node, ast.filter.Any):
            return self._any(type_, node, head)
        if isinstance(node, ast.filter.All):
            return self._all(type_, node, head)
        if isinstance(node, ast.filter.And):
            return self._and(type_, node, head)
        if isinstance(node, ast.filter.Or):
            return self._or(type_, node, head)
        if isinstance(node, ast.filter.Equals):
            return self._equals(type_, node, head)
        if isinstance(node, ast.filter.Substring):
            return self._substring(type_, node, head)
        if isinstance(node, ast.filter.StartsWith):
            return self._starts_with(type_, node, head)
        if isinstance(node, ast.filter.EndsWith):
            return self._ends_with(type_, node, head)
        if isinstance(node, ast.filter.LessThan):
            return self._less_than(type_, node, head)
        if isinstance(node, ast.filter.GreaterThan):
            return self._greater_than(type_, node, head)
        # invalid node
        raise errors.BackendError(f'expected filter expression, found {node}')

    def _parse_predicate_expression(
            self,
            type_: bsc.Vertex,
            node: ast.filter.PredicateExpression
            ) -> typing.Tuple[str, bsc.Vertex]:
        """Route *node* to the handler of the respective PredicateExpression subclass.

        Returns the SPARQL property path and the type reached by following it
        from *type_*.

        Raises:
            errors.BackendError: If *node* is not a supported predicate expression.
        """
        if isinstance(node, ast.filter.Predicate):
            return self._predicate(type_, node)
        if isinstance(node, ast.filter.OneOf):
            return self._one_of(type_, node)
        # invalid node
        raise errors.BackendError(f'expected predicate expression, found {node}')

    def _one_of(self, node_type: bsc.Vertex, node: ast.filter.OneOf) -> typing.Tuple[str, bsc.Vertex]:
        """Translate a OneOf expression into a SPARQL alternative path ``p1|p2|...``.

        Returns the alternative path and the most generic range among the
        alternatives. Duplicate paths are emitted once, in first-seen order.

        Raises:
            errors.BackendError: If *node_type* is not a `bsc.Node`.
            errors.ConsistencyError: If two alternatives have unrelated ranges.
        """
        if not isinstance(node_type, bsc.Node):
            raise errors.BackendError(f'expected Node, found {node_type}')
        # walk through predicates
        # NOTE: a dict (rather than a set) de-duplicates while keeping the input
        # order, so the generated path does not depend on the string hash seed.
        suburi: typing.Dict[str, None] = {}
        rng = None
        for pred in node: # OneOf guarantees at least one expression
            puri, subrng = self._parse_predicate_expression(node_type, pred)
            # track predicate uris
            suburi[puri] = None
            # check for more generic range
            if rng is None or subrng > rng:
                rng = subrng
            # check range consistency
            if not subrng <= rng and not subrng >= rng:
                raise errors.ConsistencyError(f'ranges {subrng} and {rng} are not related')
        # return joint predicate expression and next range
        # OneOf guarantees at least one expression, rng is always a bsc.Vertex.
        # mypy does not realize this, hence we ignore the warning.
        return '|'.join(suburi), rng # type: ignore [return-value]

    def _predicate(self, node_type: bsc.Vertex, node: ast.filter.Predicate) -> typing.Tuple[str, bsc.Vertex]:
        """Translate a single (possibly reversed) predicate into a SPARQL path.

        Returns ``<uri>``, or ``^<uri>`` if reversed, together with the type
        reached by following the predicate. When reversed, the predicate's
        domain and range swap roles.

        Raises:
            errors.BackendError: If *node_type* is not a `bsc.Node`, or the
                predicate's range is neither a Node nor a Literal.
            errors.ConsistencyError: If the predicate is not in the schema, or
                *node_type* is not a subtype of the (effective) domain.
        """
        # check node_type
        if not isinstance(node_type, bsc.Node):
            raise errors.BackendError(f'expected Node, found {node_type}')
        # fetch predicate and its uri
        puri = node.predicate
        # get and check predicate, domain, and range
        if not self.schema.has_predicate(puri):
            raise errors.ConsistencyError(f'predicate {puri} is not in the schema')
        pred = self.schema.predicate(puri)
        if not isinstance(pred.range, (bsc.Node, bsc.Literal)):
            raise errors.BackendError(f'the range of predicate {pred} is undefined')
        dom, rng = pred.domain, pred.range
        # encapsulate predicate uri
        uri_str = f'<{puri}>'
        # apply reverse flag
        if node.reverse:
            uri_str = '^' + uri_str
            dom, rng = rng, dom # type: ignore [assignment] # variable re-use confuses mypy
        # check path consistency
        if not node_type <= dom:
            raise errors.ConsistencyError(f'expected type {dom} or subtype thereof, found {node_type}')
        # return predicate URI and next node type
        return uri_str, rng

    def _any(self, node_type: bsc.Vertex, node: ast.filter.Any, head: str) -> str:
        """Match *head* if at least one value reached via the predicate satisfies the sub-expression.

        Raises:
            errors.BackendError: If *node_type* is not a `bsc.Node`.
        """
        if not isinstance(node_type, bsc.Node):
            raise errors.BackendError(f'expected Node, found {node_type}')
        # parse predicate
        pred, next_type = self._parse_predicate_expression(node_type, node.predicate)
        # parse expression
        nexthead = next(self.ngen)
        expr = self._parse_filter_expression(next_type, node.expr, nexthead)
        # combine results
        return f'{head} {pred} {nexthead} . {expr}'

    def _all(self, node_type: bsc.Vertex, node: ast.filter.All, head: str) -> str:
        """Match *head* if it has at least one value via the predicate and all such values satisfy the sub-expression.

        Implemented by rewriting the AST (see the NOTE below), plus an
        existence constraint on the predicate.

        Raises:
            errors.BackendError: If *node_type* is not a `bsc.Node`.
        """
        # NOTE: All(P, E) := Not(Any(P, Not(E))) and EXISTS(P, ?)
        if not isinstance(node_type, bsc.Node):
            raise errors.BackendError(f'expected Node, found {node_type}')
        # parse rewritten ast
        expr = self._parse_filter_expression(node_type,
            ast.filter.Not(
                ast.filter.Any(node.predicate,
                    ast.filter.Not(node.expr))), head)
        # parse predicate for existence constraint
        pred, _ = self._parse_predicate_expression(node_type, node.predicate)
        temphead = next(self.ngen)
        # return existence and rewritten expression
        return f'FILTER EXISTS {{ {head} {pred} {temphead} }} . ' + expr

    def _and(self, node_type: bsc.Vertex, node: ast.filter.And, head: str) -> str:
        """Conjunction: join all sub-expressions over the same *head* with ``.``."""
        sub = [self._parse_filter_expression(node_type, expr, head) for expr in node]
        return ' . '.join(sub)

    def _or(self, node_type: bsc.Vertex, node: ast.filter.Or, head: str) -> str:
        """Disjunction: wrap each sub-expression in a group and combine them with ``UNION``."""
        # potential special case optimization:
        # * ast: Or(Equals('foo'), Equals('bar'), ...)
        # * query: VALUES ?head { "value1"^^<...> "value2"^^<...> "value3"^<...> ... }
        sub = [self._parse_filter_expression(node_type, expr, head) for expr in node]
        sub = ['{' + expr + '}' for expr in sub]
        return ' UNION '.join(sub)

    def _not(self, node_type: bsc.Vertex, node: ast.filter.Not, head: str) -> str:
        """Negation: remove solutions that match the sub-expression via ``MINUS``."""
        expr = self._parse_filter_expression(node_type, node.expr, head)
        if isinstance(node_type, bsc.Literal):
            return f'MINUS {{ {expr} }}'
        # NOTE: for bsc.Node types, we must include at least one expression in the body of MINUS,
        # otherwise the connection between the context and body of MINUS is lost.
        # The simplest (and non-interfering) choice is a type statement.
        return f'MINUS {{ {head} <{ns.rdf.type}>/<{ns.rdfs.subClassOf}>* <{node_type.uri}> . {expr} }}'

    def _has(self, node_type: bsc.Vertex, node: ast.filter.Has, head: str) -> str:
        """Constrain the number of distinct values *head* has via the predicate.

        The count is bound to a fresh variable, which is then filtered by
        ``node.count`` as an ``xsd:integer`` literal.

        Raises:
            errors.BackendError: If *node_type* is not a `bsc.Node`.
        """
        if not isinstance(node_type, bsc.Node):
            raise errors.BackendError(f'expected Node, found {node_type}')
        # parse predicate
        pred, _ = self._parse_predicate_expression(node_type, node.predicate)
        # get new heads
        inner = next(self.ngen)
        outer = next(self.ngen)
        # predicate count expression (fetch number of predicates at *head*)
        # NOTE(review): under standard SPARQL scoping a sub-select only exposes its
        # projected variables, and *head* is not projected here. A per-head count
        # therefore seems to rely on the engine binding *head* from the enclosing
        # pattern before evaluating the sub-select -- confirm with the target store.
        num_preds = f'{{ SELECT (COUNT(distinct {inner}) as {outer}) WHERE {{ {head} {pred} {inner} }} }}'
        # count expression
        count_bounds = self._parse_filter_expression(self.schema.literal(ns.xsd.integer), node.count, outer)
        # combine
        return num_preds + ' . ' + count_bounds

    def _distance(self, node_type: bsc.Vertex, node: ast.filter.Distance, head: str) -> str:
        """Restrict *head* to feature values within ``node.threshold`` of ``node.reference``.

        The distance is evaluated in Python over every literal of the feature's
        datatype found in the graph. The matching values are inlined as a
        ``VALUES`` clause. If none match, a placeholder string value is used so
        that nothing can match.

        Raises:
            errors.BackendError: If *node_type* is not a `bsc.Feature`.
            errors.ConsistencyError: If the reference's length differs from the
                feature's dimension.
        """
        if not isinstance(node_type, bsc.Feature):
            raise errors.BackendError(f'expected Feature, found {node_type}')
        if len(node.reference) != node_type.dimension:
            raise errors.ConsistencyError(
                f'reference has dimension {len(node.reference)}, expected {node_type.dimension}')
        # get distance metric
        dist = DISTANCE_FU[node_type.distance]
        # get operator
        cmp = operator.lt if node.strict else operator.le
        # get candidate values (escaped, since they are inlined into string literals)
        candidates = {
            f'"{self._escape(cand)}"^^<{node_type.uri}>'
            for cand
            in self.graph.objects()
            if isinstance(cand, rdflib.Literal)
            and cand.datatype == rdflib.URIRef(node_type.uri)
            and cmp(dist(cand.value, node.reference), node.threshold)
            }
        # combine candidate values; sorted so the query text is deterministic
        values = ' '.join(sorted(candidates)) if candidates else f'"impossible value"^^<{ns.xsd.string}>'
        # return sparql fragment
        return f'VALUES {head} {{ {values} }}'

    def _is(self, node_type: bsc.Vertex, node: ast.filter.Is, head: str) -> str:
        """Bind *head* to the single node URI ``node.value``.

        Raises:
            errors.BackendError: If *node_type* is not a `bsc.Node`.
        """
        if not isinstance(node_type, bsc.Node):
            raise errors.BackendError(f'expected Node, found {node_type}')
        return f'VALUES {head} {{ <{URI(node.value)}> }}'

    def _equals(self, node_type: bsc.Vertex, node: ast.filter.Equals, head: str) -> str:
        """Bind *head* to the literal ``node.value`` typed as *node_type*.

        Raises:
            errors.BackendError: If *node_type* is not a `bsc.Literal`.
        """
        if not isinstance(node_type, bsc.Literal):
            raise errors.BackendError(f'expected Literal, found {node_type}')
        return f'VALUES {head} {{ "{self._escape(node.value)}"^^<{node_type.uri}> }}'

    def _substring(self, node_type: bsc.Vertex, node: ast.filter.Substring, head: str) -> str:
        """Keep *head* if its string form contains ``node.value``.

        Raises:
            errors.BackendError: If *node_type* is not a `bsc.Literal`.
        """
        if not isinstance(node_type, bsc.Literal):
            raise errors.BackendError(f'expected Literal, found {node_type}')
        return f'FILTER contains(str({head}), "{self._escape(node.value)}")'

    def _starts_with(self, node_type: bsc.Vertex, node: ast.filter.StartsWith, head: str) -> str:
        """Keep *head* if its string form starts with ``node.value``.

        Raises:
            errors.BackendError: If *node_type* is not a `bsc.Literal`.
        """
        if not isinstance(node_type, bsc.Literal):
            raise errors.BackendError(f'expected Literal, found {node_type}')
        return f'FILTER strstarts(str({head}), "{self._escape(node.value)}")'

    def _ends_with(self, node_type: bsc.Vertex, node: ast.filter.EndsWith, head: str) -> str:
        """Keep *head* if its string form ends with ``node.value``.

        Raises:
            errors.BackendError: If *node_type* is not a `bsc.Literal`.
        """
        if not isinstance(node_type, bsc.Literal):
            raise errors.BackendError(f'expected Literal, found {node_type}')
        return f'FILTER strends(str({head}), "{self._escape(node.value)}")'

    def _less_than(self, node_type: bsc.Vertex, node: ast.filter.LessThan, head: str) -> str:
        """Keep *head* if it is below ``node.threshold`` (``<=`` unless ``node.strict``).

        Raises:
            errors.BackendError: If *node_type* is not a `bsc.Literal`.
        """
        if not isinstance(node_type, bsc.Literal):
            raise errors.BackendError(f'expected Literal, found {node_type}')
        equality = '=' if not node.strict else ''
        return f'FILTER ({head} <{equality} {float(node.threshold)})'

    def _greater_than(self, node_type: bsc.Vertex, node: ast.filter.GreaterThan, head: str) -> str:
        """Keep *head* if it is above ``node.threshold`` (``>=`` unless ``node.strict``).

        Raises:
            errors.BackendError: If *node_type* is not a `bsc.Literal`.
        """
        if not isinstance(node_type, bsc.Literal):
            raise errors.BackendError(f'expected Literal, found {node_type}')
        equality = '=' if not node.strict else ''
        return f'FILTER ({head} >{equality} {float(node.threshold)})'

## EOF ##
