
# imports
import typing

# bsfs imports
from bsfs import schema as bsc
from bsfs.namespace import ns
from bsfs.utils import errors, typename

# inner-module imports
from . import ast

# exports
# NOTE(review): only `Filter` is listed here although this module also defines
# `Fetch`; confirm whether leaving `Fetch` out of the star-import surface is intended.
__all__ : typing.Sequence[str] = (
    'Filter',
    )

# FIXME: Split into a submodule and the two classes into their own respective files.

## code ##

class Filter():
    """Check that a `bsfs.query.ast.filter` query is well-formed and complies with a schema.

    The validator walks the query and enforces the following rules:

    * Value conditions (Equals, Substring, StartsWith, EndsWith) and bounded
      conditions (LessThan, GreaterThan) are only accepted on literals;
      bounded conditions additionally require a number type
    * Branches (Any, All), Is, and Has are only accepted on nodes
    * Distance is only accepted on features, with a reference of matching dimension
    * The type at a branch must be the predicate's domain or a subtype thereof
    * Predicates and the types they are applied to must be present in the schema

    """

    # schema to validate against.
    schema: bsc.Schema

    def __init__(self, schema: bsc.Schema):
        """Create a validator that checks filter queries against *schema*."""
        self.schema = schema

    def __call__(self, root_type: bsc.Node, query: ast.filter.FilterExpression) -> bool:
        """Alias for `Filter.validate`."""
        return self.validate(root_type, query)

    def validate(self, root_type: bsc.Node, query: ast.filter.FilterExpression) -> bool:
        """Validate a filter *query* whose subject is of type *root_type*.

        Returns True if no check failed.

        Raises a `TypeError` if *root_type* is not a `bsfs.schema.Node`.
        Raises a `bsfs.utils.errors.ConsistencyError` if the query violates the schema.
        Raises a `bsfs.utils.errors.BackendError` if the query structure is invalid.

        """
        # the subject type has to be a node ...
        if not isinstance(root_type, bsc.Node):
            raise TypeError(f'expected a node, found {typename(root_type)}')
        # ... that the schema knows about
        known_nodes = self.schema.nodes()
        if root_type not in known_nodes:
            raise errors.ConsistencyError(f'{root_type} is not defined in the schema')
        # descend into the query; any violation raises
        self._parse_filter_expression(root_type, query)
        # nothing raised, hence the query is valid
        return True


    ## routing methods

    def _parse_filter_expression(self, type_: bsc.Vertex, node: ast.filter.FilterExpression):
        """Dispatch *node* to the handler of its FilterExpression subclass.

        Raises a `bsfs.utils.errors.BackendError` if no handler matches *node*.

        """
        # NOTE: entries are tried top to bottom and the first match wins.
        handlers: typing.Sequence[typing.Tuple[typing.Any, typing.Callable[..., typing.Any]]] = (
            (ast.filter.Is, self._is),
            (ast.filter.Not, self._not),
            (ast.filter.Has, self._has),
            (ast.filter.Distance, self._distance),
            ((ast.filter.Any, ast.filter.All), self._branch),
            ((ast.filter.And, ast.filter.Or), self._agg),
            ((ast.filter.Equals, ast.filter.Substring, ast.filter.StartsWith, ast.filter.EndsWith), self._value),
            ((ast.filter.LessThan, ast.filter.GreaterThan), self._bounded),
            )
        for classes, handler in handlers:
            if isinstance(node, classes):
                return handler(type_, node)
        # invalid node
        raise errors.BackendError(f'expected filter expression, found {node}')

    def _parse_predicate_expression(self, node: ast.filter.PredicateExpression) -> typing.Tuple[bsc.Vertex, bsc.Vertex]:
        """Dispatch *node* to the handler of its PredicateExpression subclass.

        Returns the (domain, range) pair determined by the handler.
        Raises a `bsfs.utils.errors.BackendError` if no handler matches *node*.

        """
        if isinstance(node, ast.filter.Predicate):
            endpoints = self._predicate(node)
        elif isinstance(node, ast.filter.OneOf):
            endpoints = self._one_of(node)
        else:
            # invalid node
            raise errors.BackendError(f'expected predicate expression, found {node}')
        return endpoints


    ## shared checks

    def _require_node(self, type_: bsc.Vertex):
        """Raise a ConsistencyError unless *type_* is a Node that is part of the schema."""
        # type is a Node
        if not isinstance(type_, bsc.Node):
            raise errors.ConsistencyError(f'expected a Node, found {type_}')
        # type exists in the schema
        # FIXME: Isn't it actually guaranteed that the type (except the root type) is part of the schema?
        # all types can be traced back to (a) root_type, (b) predicate, or (c) manually set (e.g. in _is).
        # For (a), we do (and have to) perform a check. For (c), the code base should be consistent throughout
        # the module, so this is an assumption that has to be ensured in schema.Schema. For (b), we know (and
        # check) that the predicate is in the schema, hence all node/literals derived from it are also in the
        # schema by construction of the schema.Schema class. So, why do we check this every time?
        if type_ not in self.schema.nodes():
            raise errors.ConsistencyError(f'node {type_} is not in the schema')

    def _require_literal(self, type_: bsc.Vertex):
        """Raise a ConsistencyError unless *type_* is a Literal that is part of the schema."""
        # type is a literal
        if not isinstance(type_, bsc.Literal):
            raise errors.ConsistencyError(f'expected a Literal, found {type_}')
        # type exists in the schema
        if type_ not in self.schema.literals():
            raise errors.ConsistencyError(f'literal {type_} is not in the schema')


    ## predicate expressions

    def _predicate(self, node: ast.filter.Predicate) -> typing.Tuple[bsc.Vertex, bsc.Vertex]:
        """Return the (domain, range) of the predicate referenced by *node*.

        Domain and range trade places if *node* is flagged as reverse.

        Raises a ConsistencyError if the predicate is not in the schema.
        Raises a BackendError if the predicate's range is neither a Node nor a Literal.

        """
        pred_id = node.predicate
        # the schema has to know the predicate
        if not self.schema.has_predicate(pred_id):
            raise errors.ConsistencyError(f'predicate {pred_id} is not in the schema')
        pred = self.schema.predicate(pred_id)
        # the range has to be a concrete vertex type
        if not isinstance(pred.range, (bsc.Node, bsc.Literal)):
            raise errors.BackendError(f'the range of predicate {pred} is undefined')
        domain, range_ = pred.domain, pred.range
        # a reversed predicate is walked from its range to its domain
        return (range_, domain) if node.reverse else (domain, range_)

    def _one_of(self, node: ast.filter.OneOf) -> typing.Tuple[bsc.Vertex, bsc.Vertex]:
        """Return the combined (domain, range) over all predicate expressions in *node*.

        While iterating, the most specific domain and the most generic range
        encountered so far are selected. Each child's domain (range) must be
        a sub- or supertype of the selection at that point, otherwise a
        ConsistencyError is raised.

        """
        domain: typing.Optional[bsc.Vertex] = None
        range_: typing.Optional[bsc.Vertex] = None
        for child in node:
            child_dom, child_rng = self._parse_predicate_expression(child)
            # narrow the domain down to the most specific one
            if domain is None or child_dom < domain:
                domain = child_dom
            # an unrelated domain cannot be combined with the others
            if not (child_dom <= domain or child_dom >= domain):
                raise errors.ConsistencyError(f'domains {child_dom} and {domain} are not related')
            # widen the range up to the most generic one
            if range_ is None or child_rng > range_:
                range_ = child_rng
            # an unrelated range cannot be combined with the others
            if not (child_rng <= range_ or child_rng >= range_):
                raise errors.ConsistencyError(f'ranges {child_rng} and {range_} are not related')
        # OneOf guarantees at least one expression, so both values are a bsc.Vertex here.
        # mypy cannot infer this, hence the warning is silenced.
        return domain, range_ # type: ignore [return-value]


    ## intermediates

    def _branch(self, type_: bsc.Vertex, node: ast.filter._Branch):
        """Check a branch: *type_* must fit the predicate and the child must fit its range."""
        # branches start at a node of the schema
        self._require_node(type_)
        # the predicate expression itself has to be valid
        domain, range_ = self._parse_predicate_expression(node.predicate)
        # the predicate has to be applicable to type_
        if not type_ <= domain:
            raise errors.ConsistencyError(f'expected type {domain} or subtype thereof, found {type_}')
        # the child expression is evaluated on the predicate's range
        self._parse_filter_expression(range_, node.expr)

    def _agg(self, type_: bsc.Vertex, node: ast.filter._Agg):
        """Check every child of an aggregate (And, Or) against the same *type_*."""
        for child in node:
            self._parse_filter_expression(type_, child)

    def _not(self, type_: bsc.Vertex, node: ast.filter.Not):
        """Check the negated child expression against *type_*."""
        self._parse_filter_expression(type_, node.expr)

    def _has(self, type_: bsc.Vertex, node: ast.filter.Has):
        """Check a Has: *type_* must fit the predicate and the count must be a numerical expression."""
        # Has starts at a node of the schema
        self._require_node(type_)
        # the predicate expression itself has to be valid; its range is irrelevant here
        domain, _ = self._parse_predicate_expression(node.predicate)
        # the predicate has to be applicable to type_
        if not type_ <= domain:
            raise errors.ConsistencyError(f'expected type {domain}, found {type_}')
        # the count is validated as an expression over the number literal
        number = self.schema.literal(ns.bsl.Number)
        self._parse_filter_expression(number, node.count)

    def _distance(self, type_: bsc.Vertex, node: ast.filter.Distance):
        """Check a Distance: *type_* must be a schema feature matching the reference's dimension."""
        # type is a Feature
        if not isinstance(type_, bsc.Feature):
            raise errors.ConsistencyError(f'expected a Feature, found {type_}')
        # type exists in the schema
        if type_ not in self.schema.literals():
            raise errors.ConsistencyError(f'literal {type_} is not in the schema')
        # the reference needs as many entries as the feature has dimensions
        found = len(node.reference)
        expected = type_.dimension
        if found != expected:
            raise errors.ConsistencyError(f'reference has dimension {found}, expected {expected}')
        # FIXME: test dtype


    ## conditions

    def _is(self, type_: bsc.Vertex, node: ast.filter.Is): # pylint: disable=unused-argument # (node)
        """Check an Is condition: it is only accepted on a node of the schema."""
        self._require_node(type_)

    def _value(self, type_: bsc.Vertex, node: ast.filter._Value): # pylint: disable=unused-argument # (node)
        """Check a value condition: it is only accepted on a literal of the schema."""
        self._require_literal(type_)
        # FIXME: Check if node.value corresponds to type_
        # FIXME: A specific literal might be requested (i.e., a numeric type when used in Has)

    def _bounded(self, type_: bsc.Vertex, node: ast.filter._Bounded): # pylint: disable=unused-argument # (node)
        """Check a bounded condition: it is only accepted on a number literal of the schema."""
        self._require_literal(type_)
        # type must be a numerical
        if not type_ <= self.schema.literal(ns.bsl.Number):
            raise errors.ConsistencyError(f'expected a number type, found {type_}')
        # FIXME: Check if node.value corresponds to type_


class Fetch():
    """Check that a `bsfs.query.ast.fetch` query is well-formed and complies with a schema.

    The validator walks the query and enforces the following rules:

    * Value requires a predicate whose range is a literal
    * Node and Fetch require a predicate whose range is a node
    * Names (of Value, Node, and This) must not be empty or whitespace-only
    * Fetch, Value, Node, and This are only accepted on nodes of the schema
    * The type at a branch must be the predicate's domain or a subtype thereof
    * Predicates must be present in the schema

    """

    # schema to validate against.
    schema: bsc.Schema

    def __init__(self, schema: bsc.Schema):
        """Create a validator that checks fetch queries against *schema*."""
        self.schema = schema

    def __call__(self, root_type: bsc.Node, query: ast.fetch.FetchExpression) -> bool:
        """Alias for `Fetch.validate`."""
        return self.validate(root_type, query)

    def validate(self, root_type: bsc.Node, query: ast.fetch.FetchExpression) -> bool:
        """Validate a fetch *query* whose subject is of type *root_type*.

        Returns True if no check failed.

        Raises a `TypeError` if *root_type* is not a `bsfs.schema.Node` or
        if *query* is not a fetch expression.
        Raises a `bsfs.utils.errors.ConsistencyError` if the query violates the schema.
        Raises a `bsfs.utils.errors.BackendError` if the query structure is invalid.

        """
        # the subject type has to be a node ...
        if not isinstance(root_type, bsc.Node):
            raise TypeError(f'expected a node, found {typename(root_type)}')
        # ... that the schema knows about
        known_nodes = self.schema.nodes()
        if root_type not in known_nodes:
            raise errors.ConsistencyError(f'{root_type} is not defined in the schema')
        # only fetch expressions can be validated
        if not isinstance(query, ast.fetch.FetchExpression):
            raise TypeError(f'expected a fetch expression, found {typename(query)}')
        # descend into the query; any violation raises
        self._parse_fetch_expression(root_type, query)
        # nothing raised, hence the query is valid
        return True

    def _parse_fetch_expression(self, type_: bsc.Vertex, node: ast.fetch.FetchExpression):
        """Dispatch *node* to the handler of its FetchExpression subclass.

        Checks shared by several subclasses run first, then the
        subclass-specific handler is invoked.

        Raises a `bsfs.utils.errors.BackendError` if no handler matches *node*.

        """
        # shared checks; deliberately no return, the specific handler still has to run
        if isinstance(node, (ast.fetch.Fetch, ast.fetch.Value, ast.fetch.Node)):
            self._branch(type_, node)
        if isinstance(node, (ast.fetch.Value, ast.fetch.Node)):
            self._named(type_, node)
        # NOTE: entries are tried top to bottom and the first match wins.
        handlers: typing.Sequence[typing.Tuple[typing.Any, typing.Callable[..., typing.Any]]] = (
            (ast.fetch.All, self._all),
            (ast.fetch.Fetch, self._fetch),
            (ast.fetch.Value, self._value),
            (ast.fetch.Node, self._node),
            (ast.fetch.This, self._this),
            )
        for classes, handler in handlers:
            if isinstance(node, classes):
                return handler(type_, node)
        # invalid node
        raise errors.BackendError(f'expected fetch expression, found {node}')

    def _require_node(self, type_: bsc.Vertex):
        """Raise a ConsistencyError unless *type_* is a Node that is part of the schema."""
        # type is a node
        if not isinstance(type_, bsc.Node):
            raise errors.ConsistencyError(f'expected a Node, found {type_}')
        # node exists in the schema
        if type_ not in self.schema.nodes():
            raise errors.ConsistencyError(f'node {type_} is not in the schema')

    def _require_name(self, name: str):
        """Raise a BackendError if *name* is empty after stripping whitespace."""
        if name.strip() == '':
            raise errors.BackendError('node name cannot be empty')

    def _node_range(self, node: ast.fetch._Branch) -> bsc.Node:
        """Return the range of *node*'s predicate; raise a ConsistencyError unless it is a Node."""
        range_ = self.schema.predicate(node.predicate).range
        if not isinstance(range_, bsc.Node):
            raise errors.ConsistencyError(
                f'expected the predicate\'s range to be a Node, found {range_}')
        return range_

    def _all(self, type_: bsc.Vertex, node: ast.fetch.All):
        """Check every child expression of *node* against the same *type_*."""
        for child in node:
            self._parse_fetch_expression(type_, child)

    def _branch(self, type_: bsc.Vertex, node: ast.fetch._Branch):
        """Check that the predicate of *node* is in the schema and applicable to *type_*."""
        # branches start at a node of the schema
        self._require_node(type_)
        # the schema has to know the predicate
        pred_id = node.predicate
        if not self.schema.has_predicate(pred_id):
            raise errors.ConsistencyError(f'predicate {pred_id} is not in the schema')
        # the predicate has to be applicable to type_
        domain = self.schema.predicate(pred_id).domain
        if not type_ <= domain:
            raise errors.ConsistencyError(
                f'expected type {domain} or subtype thereof, found {type_}')

    def _fetch(self, type_: bsc.Vertex, node: ast.fetch.Fetch): # pylint: disable=unused-argument # type_ was considered in _branch
        """Check a Fetch: the predicate must lead to a node, on which the child is checked."""
        # the child expression is evaluated on the predicate's (node) range
        self._parse_fetch_expression(self._node_range(node), node.expr)

    def _named(self, type_: bsc.Vertex, node: ast.fetch._Named): # pylint: disable=unused-argument # type_ was considered in _branch
        """Check that the name of *node* is set."""
        self._require_name(node.name)
        # FIXME: check for double name use?

    def _node(self, type_: bsc.Vertex, node: ast.fetch.Node): # pylint: disable=unused-argument # type_ was considered in _branch
        """Check a Node: the predicate's range must be a node."""
        self._node_range(node)

    def _value(self, type_: bsc.Vertex, node: ast.fetch.Value): # pylint: disable=unused-argument # type_ was considered in _branch
        """Check a Value: the predicate's range must be a literal."""
        range_ = self.schema.predicate(node.predicate).range
        if not isinstance(range_, bsc.Literal):
            raise errors.ConsistencyError(
                f'expected the predicate\'s range to be a Literal, found {range_}')

    def _this(self, type_: bsc.Vertex, node: ast.fetch.This):
        """Check a This: *type_* must be a node of the schema and the name must be set."""
        self._require_node(type_)
        # name must be set
        self._require_name(node.name)

## EOF ##
