"""

Part of the tagit module.
A copy of the license is provided with the project.
Author: Matthias Baumgartner, 2022
"""
# tagit imports
from tagit.utils.bsfs import ast, matcher, URI
from tagit.utils import errors, ns

# exports
__all__ = ('ToString', )


## code ##

class ToString:
    """Convert a bsfs filter query into tagit's textual query syntax.

    Each ``_<case>`` method recognises one family of filter expressions and
    returns its string form, or None if the expression is not of that family.
    NOTE(review): the reverse direction (string -> AST) presumably lives in a
    sibling ``from_string`` module; keep both sides in sync.
    """

    def __init__(self, schema):
        # structural matcher used to compare query ASTs against patterns
        self.matches = matcher.Filter()

        self.schema = schema
        # only predicates whose domain is (a subtype of) Entity get an abbreviation
        entity = self.schema.node(ns.bsn.Entity)
        predicates = {pred for pred in self.schema.predicates() if pred.domain <= entity}
        # shortcuts: abbreviation (the URI fragment) <-> predicate URI.
        # Predicates are visited in URI order so that the mapping is reproducible
        # when two predicates share a fragment (the one with the larger URI wins).
        # FIXME: proper tie-breaking for duplicates
        self._abb2uri = {
            pred.uri.fragment: pred.uri
            for pred in sorted(predicates, key=lambda pred: pred.uri)
            }
        self._uri2abb = {uri: fragment for fragment, uri in self._abb2uri.items()}

    def __call__(self, query):
        """Return the string representation of *query*.

        A top-level ``And`` is rendered as its children joined by ``' / '``;
        any other expression is rendered on its own.

        Raises:
            errors.BackendError: if (a part of) *query* cannot be rendered.
        """
        # FIXME: test query class type
        if self.matches(query, ast.filter.And(matcher.Rest())):
            return ' / '.join(self._parse(sub) for sub in query)
        return self._parse(query)

    def _parse(self, query):
        """Render a single (non-conjunctive) filter expression.

        The case handlers are tried in order and the first non-None result wins.
        The order matters: e.g. ``_range`` also renders single ``Equals`` values
        and thereby shadows the equivalent branch of ``_categorical``.
        """
        cases = (
            self._has,
            self._entity,
            self._group,
            self._tag,
            self._range,
            self._categorical,
            )
        for clbk in cases:
            result = clbk(query)
            if result is not None:
                return result

        raise errors.BackendError(f'cannot convert query to string: {query!r}')

    ## helpers ##

    @staticmethod
    def _join_values(exprs):
        """Join the ``value`` of each expression in *exprs* with ``'", "'``.

        The caller adds the outermost quotes. Values are passed through ``str``
        so that non-string values do not break the join.
        """
        return '", "'.join(str(sub.value) for sub in exprs)

    @staticmethod
    def _guids(value):
        """Return the set of guids referenced by the value of an ``Is`` node."""
        if isinstance(value, URI):
            return {value}
        # presumably a Nodes instance -- NOTE(review): confirm no other value types occur
        return set(value.guids)

    def _match_is(self, expr):
        """Collect guids from ``Is``, ``Or(Is, ...)``, ``Not(Is)`` or ``Not(Or(Is, ...))``.

        Returns a ``(negated, guids)`` tuple. *guids* is empty if *expr* has
        none of these shapes.
        """
        if self.matches(expr, matcher.Partial(ast.filter.Is)):
            return False, self._guids(expr.value)
        if self.matches(expr, ast.filter.Or(matcher.Rest(matcher.Partial(ast.filter.Is)))):
            return False, {guid for sub in expr for guid in self._guids(sub.value)}
        # negated forms: the Is/Or node is wrapped by Not, hence expr.expr
        if self.matches(expr, ast.filter.Not(matcher.Partial(ast.filter.Is))):
            return True, self._guids(expr.expr.value)
        if self.matches(expr, ast.filter.Not(ast.filter.Or(matcher.Rest(matcher.Partial(ast.filter.Is))))):
            return True, {guid for sub in expr.expr for guid in self._guids(sub.value)}
        return False, set()

    @staticmethod
    def _format_guids(key, negated, guids):
        """Return ``<key> in "g1", "g2"`` (``not in`` if *negated*).

        Guids are sorted so that the output does not depend on set order.
        """
        cmp = 'not in' if negated else 'in'
        values = '", "'.join(sorted(guids))
        return f'{key} {cmp} "{values}"'

    ## cases ##

    def _has(self, query):
        """Render predicate existence constraints.

        Has(<pred>)      <-> has <pred>
        Not(Has(<pred>)) <-> has no <pred>

        Only ``Has`` nodes whose count constraint is ``GreaterThan(1, strict=False)``
        (i.e. at least one) are recognised.
        """
        has = ast.filter.Has(
            matcher.Partial(ast.filter.Predicate),
            ast.filter.GreaterThan(1, strict=False))
        if self.matches(query, has):
            negated, node = False, query
        elif self.matches(query, ast.filter.Not(has)):
            # the Has node is wrapped by Not; the predicate lives on the inner node
            negated, node = True, query.expr
        else:
            return None

        pred = self._uri2abb.get(node.predicate.predicate, None)
        if pred is None:
            # predicate without an abbreviation; let _parse report the failure
            return None
        return f'has no {pred}' if negated else f'has {pred}'

    def _categorical(self, query):
        """Render value constraints on an abbreviated predicate.

        Any(<pred>, Substring(v))                -> pred ~ v
        Any(<pred>, Or(Substring(v), ...))       -> pred ~ ("v", ...)
        Any(<pred>, Equals(v))                   -> pred = v
        Any(<pred>, Or(Equals(v), ...))          -> pred = ("v", ...)
        All(<pred>, Not(Substring(v)))           -> pred !~ "v"
        All(<pred>, Not(Or(Substring(v), ...)))  -> pred !~ ("v", ...)
        All(<pred>, Not(Equals(v)))              -> pred != "v"
        All(<pred>, Not(Or(Equals(v), ...)))     -> pred != ("v", ...)
        """
        if not isinstance(query, ast.filter._Branch):
            return None

        # shortcuts
        expr = query.expr
        pred = self._uri2abb.get(query.predicate.predicate, None)
        if pred is None:
            return None

        # positive constraints
        if isinstance(query, ast.filter.Any):
            # approximate positive constraint
            if self.matches(expr, matcher.Partial(ast.filter.Substring)):
                return f'{pred} ~ {expr.value}'
            if self.matches(expr, ast.filter.Or(matcher.Rest(matcher.Partial(ast.filter.Substring)))):
                return f'{pred} ~ ("{self._join_values(expr)}")'

            # exact positive constraint
            if self.matches(expr, matcher.Partial(ast.filter.Equals)):
                return f'{pred} = {expr.value}'
            if self.matches(expr, ast.filter.Or(matcher.Rest(matcher.Partial(ast.filter.Equals)))):
                return f'{pred} = ("{self._join_values(expr)}")'

        # negative constraints; the value constraint is wrapped by a Not
        if isinstance(query, ast.filter.All) and self.matches(expr, ast.filter.Not(matcher.Any())):
            inner = expr.expr

            # approximate negative constraint
            if self.matches(inner, matcher.Partial(ast.filter.Substring)):
                return f'{pred} !~ "{inner.value}"'
            if self.matches(inner, ast.filter.Or(matcher.Rest(matcher.Partial(ast.filter.Substring)))):
                return f'{pred} !~ ("{self._join_values(inner)}")'

            # exact negative constraint
            if self.matches(inner, matcher.Partial(ast.filter.Equals)):
                return f'{pred} != "{inner.value}"'
            if self.matches(inner, ast.filter.Or(matcher.Rest(matcher.Partial(ast.filter.Equals)))):
                return f'{pred} != ("{self._join_values(inner)}")'

        return None

    def _tag(self, query):
        """Render constraints on tag labels.

        Any(bse.tag, Any(bst.label, Substring(v)))           <-> ~ v
        Any(bse.tag, Any(bst.label, Or(Substring(v), ...)))  <-> ~ "v", ...
        Any(bse.tag, Any(bst.label, Equals(v)))              <-> v
        Any(bse.tag, Any(bst.label, Or(Equals(v), ...)))     <-> "v", ...
        All(bse.tag, Any(bst.label, Not(<same forms>)))      <-> !~ ... / ! ...
        """
        # positive constraint
        if self.matches(query, ast.filter.Any(ns.bse.tag, ast.filter.Any(ns.bst.label, matcher.Any()))):
            expr = query.expr.expr
            # approximate positive constraint
            if self.matches(expr, matcher.Partial(ast.filter.Substring)):
                return f'~ {expr.value}'
            if self.matches(expr, ast.filter.Or(matcher.Rest(matcher.Partial(ast.filter.Substring)))):
                return f'~ "{self._join_values(expr)}"'
            # exact positive constraint
            if self.matches(expr, matcher.Partial(ast.filter.Equals)):
                return f'{expr.value}'
            if self.matches(expr, ast.filter.Or(matcher.Rest(matcher.Partial(ast.filter.Equals)))):
                return f'"{self._join_values(expr)}"'

        # negative constraint
        if self.matches(query, ast.filter.All(ns.bse.tag, ast.filter.Any(ns.bst.label, ast.filter.Not(matcher.Any())))):
            # unwrap All -> Any -> Not to reach the value constraint
            expr = query.expr.expr.expr
            # approximate negative constraint
            if self.matches(expr, matcher.Partial(ast.filter.Substring)):
                return f'!~ {expr.value}'
            if self.matches(expr, ast.filter.Or(matcher.Rest(matcher.Partial(ast.filter.Substring)))):
                return f'!~ "{self._join_values(expr)}"'
            # exact negative constraint
            if self.matches(expr, matcher.Partial(ast.filter.Equals)):
                return f'! {expr.value}'
            if self.matches(expr, ast.filter.Or(matcher.Rest(matcher.Partial(ast.filter.Equals)))):
                return f'! "{self._join_values(expr)}"'

        return None

    def _range(self, query):
        """Render comparisons and ranges on an abbreviated predicate.

        Any(<pred>, Equals(v))                      -> pred = v
        Any(<pred>, GreaterThan(t, strict=True))    -> pred > t   (>= if not strict)
        Any(<pred>, LessThan(t, strict=True))       -> pred < t   (<= if not strict)
        Any(<pred>, And(GreaterThan(lo), LessThan(hi)))
                                                    -> pred = [lo - hi]
            where '(' / ')' replace '[' / ']' on a strict bound.
        """
        # FIXME: handle dates and times!
        # FIXME: use default/configurable separators from from_string
        if not isinstance(query, ast.filter.Any):
            return None

        expr = query.expr
        pred = self._uri2abb.get(query.predicate.predicate, None)
        if pred is None:
            return None

        if self.matches(expr, matcher.Partial(ast.filter.Equals)):
            return f'{pred} = {expr.value}'
        if self.matches(expr, matcher.Partial(ast.filter.GreaterThan, strict=True)):
            return f'{pred} > {expr.threshold}'
        if self.matches(expr, matcher.Partial(ast.filter.GreaterThan, strict=False)):
            return f'{pred} >= {expr.threshold}'
        if self.matches(expr, matcher.Partial(ast.filter.LessThan, strict=True)):
            return f'{pred} < {expr.threshold}'
        if self.matches(expr, matcher.Partial(ast.filter.LessThan, strict=False)):
            return f'{pred} <= {expr.threshold}'
        if self.matches(expr, ast.filter.And(
                matcher.Partial(ast.filter.GreaterThan),
                matcher.Partial(ast.filter.LessThan))):
            lo, hi = list(expr)
            # the And's children may come in either order
            if self.matches(lo, matcher.Partial(ast.filter.LessThan)):
                lo, hi = hi, lo
            b_open = '(' if lo.strict else '['
            b_close = ')' if hi.strict else ']'
            return f'{pred} = {b_open}{lo.threshold} - {hi.threshold}{b_close}'
        # Notes on the corresponding string forms of
        # ast.filter.Any(<pred>, ast.filter.Between(lo, hi, lo_strict, hi_strict)):
        #     pred <? hi
        #     pred >? lo
        #     pred = [lo, hi]
        #     pred = (lo, hi)
        #     pred = [lo, hi)
        #     pred = (lo, hi]
        return None

    def _entity(self, query):
        """Render constraints on the entity itself.

        Is(g) / Or(Is(g), ...)            -> id in "g", ...
        Not(Is(g)) / Not(Or(Is(g), ...))  -> id not in "g", ...
        """
        negated, guids = self._match_is(query)
        if not guids: # no matches
            return None
        return self._format_guids('id', negated, guids)

    def _group(self, query):
        """Render group membership constraints.

        Any(bse.group, Is(g) / Or(Is(g), ...))            -> group in "g", ...
        Any(bse.group, Not(Is(g)) / Not(Or(Is(g), ...)))  -> group not in "g", ...
        """
        if not self.matches(query, ast.filter.Any(ns.bse.group, matcher.Any())):
            return None

        negated, guids = self._match_is(query.expr)
        if not guids: # no matches
            return None
        return self._format_guids('group', negated, guids)

## EOF ##
