
# imports
from collections import abc
import time
import typing

# bsfs imports
from bsfs import schema as bsc
from bsfs.namespace import ns
from bsfs.query import ast, validate
from bsfs.triple_store import TripleStoreBase
from bsfs.utils import errors, URI, typename

# inner-module imports
from . import ac
from . import result
from . import walk

# exports
__all__: typing.Sequence[str] = (
    'Nodes',
    )


## code ##

class Nodes():
    """Container for graph nodes, provides operations on nodes.

    A `Nodes` instance bundles a set of guids with their common node type,
    the triple store backend, and the access controls. Instances that share
    backend, access controls, and node type can be combined with the
    operators `+`, `|` (union), `-` (difference), and `&` (intersection).
    Iterating over an instance yields one single-guid `Nodes` per guid.

    Attributes that cannot be resolved otherwise are handed to
    `walk.Walk.step` as predicate name and returned wrapped in a `walk.Walk`.

    NOTE: Should not be created directly but only via `bsfs.graph.Graph`.
    NOTE: guids may or may not exist. This is not verified as nodes are created on demand.
    """

    # triple store backend.
    _backend: TripleStoreBase

    # access controls.
    _ac: ac.AccessControlBase

    # node type.
    _node_type: bsc.Node

    # guids of nodes. Can be empty.
    _guids: typing.Set[URI]

    def __init__(
            self,
            backend: TripleStoreBase,
            access_control: ac.AccessControlBase,
            node_type: bsc.Node,
            guids: typing.Iterable[URI],
            ):
        """Create a container for *guids* of type *node_type* that operates
        on *backend* and is guarded by *access_control*.
        """
        # set main members
        self._backend = backend
        self._ac = access_control
        self._node_type = node_type
        # convert to URI since this is not guaranteed by Graph
        self._guids = {URI(guid) for guid in guids}

    def __eq__(self, other: typing.Any) -> bool:
        """Return True if *other* is a `Nodes` instance with identical members."""
        return isinstance(other, Nodes) \
           and self._backend == other._backend \
           and self._ac == other._ac \
           and self._node_type == other._node_type \
           and self._guids == other._guids

    def __hash__(self) -> int:
        """Hash over all members. Guids are sorted to be independent of set order."""
        return hash((type(self), self._backend, self._ac, self._node_type, tuple(sorted(self._guids))))

    def __repr__(self) -> str:
        return f'{typename(self)}({self._backend}, {self._ac}, {self._node_type}, {self._guids})'

    def __str__(self) -> str:
        return f'{typename(self)}({self._node_type}, {self._guids})'

    @property
    def node_type(self) -> bsc.Node:
        """Return the node's type."""
        return self._node_type

    @property
    def guids(self) -> typing.Iterator[URI]:
        """Return all node guids."""
        return iter(self._guids)

    @property
    def schema(self) -> bsc.Schema:
        """Return the store's local schema."""
        return self._backend.schema

    def _check_compatible(self, other: 'Nodes') -> None:
        """Raise a ValueError unless *other* has the same backend, AC, and node type."""
        if self._backend != other._backend \
           or self._ac != other._ac \
           or self.node_type != other.node_type:
            raise ValueError(other)

    def __add__(self, other: typing.Any) -> 'Nodes':
        """Concatenate guids. Backend, AC, and node type must match."""
        if not isinstance(other, type(self)):
            return NotImplemented
        self._check_compatible(other)
        return Nodes(self._backend, self._ac, self.node_type, self._guids | other._guids)

    def __or__(self, other: typing.Any) -> 'Nodes':
        """Concatenate guids. Backend, AC, and node type must match."""
        return self.__add__(other)

    def __sub__(self, other: typing.Any) -> 'Nodes':
        """Subtract guids. Backend, AC, and node type must match."""
        if not isinstance(other, type(self)):
            return NotImplemented
        self._check_compatible(other)
        return Nodes(self._backend, self._ac, self.node_type, self._guids - other._guids)

    def __and__(self, other: typing.Any) -> 'Nodes':
        """Intersect guids. Backend, AC, and node type must match."""
        if not isinstance(other, type(self)):
            return NotImplemented
        self._check_compatible(other)
        return Nodes(self._backend, self._ac, self.node_type, self._guids & other._guids)

    def __len__(self) -> int:
        """Return the number of guids."""
        return len(self._guids)

    def __iter__(self) -> typing.Iterator['Nodes']:
        """Iterate over individual guids. Returns `Nodes` instances."""
        return (
            Nodes(self._backend, self._ac, self.node_type, {guid})
            for guid in self._guids
            )

    def __getattr__(self, name: str):
        """Start a walk along the predicate called *name*.

        Only invoked if the regular attribute lookup fails. Special (dunder)
        names and the instance's own members are never interpreted as
        predicates; they raise an AttributeError instead.
        """
        try:
            return super().__getattr__(name) # type: ignore [misc] # parent has no getattr
        except AttributeError:
            pass
        # Protocol probes (copy, pickle, hasattr on dunder names) and lookups of
        # members that were not assigned yet (e.g. on an instance created via
        # __new__) must fail cleanly. Resolving them through self.schema would
        # re-enter __getattr__ via self._backend and recurse without bound.
        is_member = name in ('_backend', '_ac', '_node_type', '_guids')
        is_special = name.startswith('__') and name.endswith('__')
        if is_member or is_special:
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {name!r}')
        return walk.Walk(self, walk.Walk.step(self.schema, self.node_type, name))

    def set(
            self,
            pred: URI, # FIXME: URI or bsc.Predicate?
            value: typing.Any,
            ) -> 'Nodes':
        """Set predicate *pred* to *value*."""
        return self.set_from_iterable([(pred, value)])

    def set_from_iterable(
            self,
            predicate_values: typing.Iterable[typing.Tuple[URI, typing.Any]], # FIXME: URI or bsc.Predicate?
            ) -> 'Nodes':
        """Set multiple predicate-value pairs at once.

        All pairs are inserted and then committed to the backend. If any step
        fails, the backend is rolled back and the error is re-raised unchanged.
        Returns the instance itself.
        """
        # TODO: Could group predicate_values by predicate to gain some efficiency
        # TODO: ignore errors on some predicates; For now this could leave residual
        #       data (e.g. some nodes were created, some not).
        try:
            # insert triples
            for pred, value in predicate_values:
                self.__set(pred, value)
            # save changes
            self._backend.commit()

        except Exception:
            # Anticipated errors are:
            # * errors.PermissionDeniedError: tried to set a protected predicate
            # * errors.ConsistencyError: node types are not in the schema or don't match the predicate
            # * errors.InstanceError: guids/values don't have the correct type
            # * TypeError: value is supposed to be a Nodes instance
            # * ValueError: multiple values passed to unique predicate
            # * KeyError: the predicate is not in the schema
            # Any other failure is handled alike, otherwise triples that were
            # inserted before the error would remain pending in the backend.
            # revert changes
            self._backend.rollback()
            # notify the client
            raise

        return self

    def get(
            self,
            *paths: typing.Union[URI, typing.Iterable[URI]],
            view: typing.Union[typing.Type[list], typing.Type[dict]] = dict,
            **view_kwargs,
            ) -> typing.Any:
        """Get values or nodes at *paths*.
        Return an iterator (view=list) or a dict (view=dict) over the results.

        Each path is either a single predicate URI or an iterable of predicate
        URIs. Iterables are read once and reported back as tuple.
        Additional keyword arguments are passed to the result view. The
        arguments *node*, *path*, and *value* are set by default: *node* is
        False if there is exactly one guid, *path* is False if there is
        exactly one path, and *value* is False.

        Raises an AttributeError if no path is given, a ValueError if *view*
        is neither dict nor list or if a path is empty, a TypeError if a path
        is not a string or an iterable of strings, and an
        errors.ConsistencyError if a path includes predicates that are not
        in the schema.
        """
        # FIXME: user-provided Fetch query AST?
        # check args
        if len(paths) == 0:
            raise AttributeError('expected at least one path, found none')
        if view not in (dict, list):
            raise ValueError(f'expected dict or list, found {view}')
        # process paths: create fetch ast, build name mapping, and find unique paths
        schema = self.schema
        statements = set()
        name2path = {}
        unique_paths = set() # paths that result in a single (unique) value
        normpath: typing.Tuple[URI, ...]
        for idx, path in enumerate(paths):
            # normalize path
            if isinstance(path, str):
                normpath = (URI(path), )
                key = path
            elif isinstance(path, abc.Iterable):
                # Materialize the path once: one-shot iterators must not be
                # exhausted by the type check, and the path has to be hashable
                # since it identifies the results.
                steps = path if isinstance(path, tuple) else tuple(path)
                if not all(isinstance(step, str) for step in steps):
                    raise TypeError(path)
                if len(steps) == 0:
                    raise ValueError('expected a non-empty path')
                normpath = tuple(URI(step) for step in steps)
                key = steps
            else:
                raise TypeError(path)
            # check path's schema consistency
            if not all(schema.has_predicate(pred) for pred in normpath):
                raise errors.ConsistencyError(f'path is not fully covered by the schema: {path}')
            # check path's uniqueness
            if all(schema.predicate(pred).unique for pred in normpath):
                unique_paths.add(key)
            # fetch tail predicate
            tail = schema.predicate(normpath[-1])
            # determine tail ast node type
            factory = ast.fetch.Node if isinstance(tail.range, bsc.Node) else ast.fetch.Value
            # assign name
            name = f'fetch{idx}'
            name2path[name] = (key, tail)
            # create tail ast node
            curr: ast.fetch.FetchExpression = factory(tail.uri, name)
            # walk towards front
            hop: URI
            for hop in reversed(normpath[:-1]):
                curr = ast.fetch.Fetch(hop, curr)
            # add to fetch query
            statements.add(curr)
        # aggregate fetch statements
        if len(statements) == 1:
            fetch = next(iter(statements))
        else:
            fetch = ast.fetch.All(*statements)
        # add access controls to fetch
        fetch = self._ac.fetch_read(self.node_type, fetch)

        if len(self._guids) == 0:
            # shortcut: no need to query; no triples
            # FIXME: if the Fetch query can given by the user, we might want to check its validity
            def triple_iter():
                return []
        else:
            # compose filter ast
            is_in = ast.filter.IsIn(self.guids)
            # add access controls to filter
            node_filter = self._ac.filter_read(self.node_type, is_in)

            # validate queries
            validate.Filter(self._backend.schema).validate(self.node_type, node_filter)
            validate.Fetch(self._backend.schema).validate(self.node_type, fetch)

            # process results, convert if need be
            def triple_iter():
                # query the backend
                triples = self._backend.fetch(self.node_type, node_filter, fetch)
                # process triples
                for root, name, raw in triples:
                    # get node
                    node = Nodes(self._backend, self._ac, self.node_type, {root})
                    # get path
                    path, tail = name2path[name]
                    # convert raw to value
                    if isinstance(tail.range, bsc.Node):
                        value = Nodes(self._backend, self._ac, tail.range, {raw})
                    else:
                        value = raw
                    # emit triple
                    yield node, path, value

        # simplify by default
        view_kwargs.setdefault('node', len(self._guids) != 1)
        view_kwargs.setdefault('path', len(paths) != 1)
        view_kwargs.setdefault('value', False)

        # return results view
        if view == list:
            return result.to_list_view(
                triple_iter(),
                # aggregation args
                **view_kwargs,
                )

        if view == dict:
            return result.to_dict_view(
                triple_iter(),
                # context
                len(self._guids) == 1,
                len(paths) == 1,
                unique_paths,
                # aggregation args
                **view_kwargs,
                )

        raise errors.UnreachableError() # view was already checked


    def __set(self, predicate: URI, value: typing.Any):
        """Insert *value* for *predicate* at all nodes, without committing.

        Subject nodes (and target nodes, if the predicate's range is a node
        type) are created on demand. The guids are passed through the access
        controls before the triples are handed to the backend.

        Raises an errors.ConsistencyError if the node type is not covered by
        the predicate's domain or the value's node type is not covered by its
        range, an errors.PermissionDeniedError if the predicate is protected,
        and a TypeError if a node is expected as value but *value* is not a
        `Nodes` instance.
        """
        # get normalized predicate. Raises KeyError if *pred* not in the schema.
        pred = self._backend.schema.predicate(predicate)

        # node_type must be a subclass of the predicate's domain
        node_type = self.node_type
        if not node_type <= pred.domain:
            raise errors.ConsistencyError(f'{node_type} must be a subclass of {pred.domain}')

        # check reserved predicates (access controls, metadata, internal structures)
        # FIXME: Needed? Could be integrated into other AC methods (by passing the predicate!)
        #        This could allow more fine-grained predicate control (e.g. based on ownership)
        #        rather than a global approach like this.
        if self._ac.is_protected_predicate(pred):
            raise errors.PermissionDeniedError(pred)

        # set operation affects all nodes (if possible)
        guids = set(self.guids)

        # ensure subject node existence; create nodes if need be
        guids = set(self._ensure_nodes(node_type, guids))

        # check value
        if isinstance(pred.range, bsc.Literal):
            # check write permissions on existing nodes
            # As long as the user has write permissions, we don't restrict
            # the creation or modification of literal values.
            guids = set(self._ac.write_literal(node_type, guids))

            # insert literals
            # TODO: Support passing iterators as values for non-unique predicates
            self._backend.set(
                node_type,
                guids,
                pred,
                [value],
                )

        elif isinstance(pred.range, bsc.Node):
            # check value type
            # FIXME: value could be a set of Nodes
            if not isinstance(value, Nodes):
                raise TypeError(value)
            # value's node_type must be a subclass of the predicate's range
            if not value.node_type <= pred.range:
                raise errors.ConsistencyError(f'{value.node_type} must be a subclass of {pred.range}')

            # check link permissions on source nodes
            # Link permissions cover adding and removing links on the source node.
            # Specifically, link permissions also allow to remove links to other
            # nodes if needed (e.g. for unique predicates).
            guids = set(self._ac.link_from_node(node_type, guids))

            # get link targets
            targets = set(value.guids)
            # ensure existence of value nodes; create nodes if need be
            targets = set(self._ensure_nodes(value.node_type, targets))
            # check link permissions on target nodes
            targets = set(self._ac.link_to_node(value.node_type, targets))

            # insert node links
            self._backend.set(
                node_type,
                guids,
                pred,
                targets,
                )

        else:
            raise errors.UnreachableError()

    def _ensure_nodes(self, node_type: bsc.Node, guids: typing.Iterable[URI]):
        """Create the nodes in *guids* that the backend does not report as existing.

        Only the missing guids returned by the access control's `createable`
        are created. They receive a creation timestamp and are passed to the
        access control's `create`. Returns the set of previously existing and
        newly created guids.
        """
        # check node existence
        guids = set(guids)
        existing = set(self._backend.exists(node_type, guids))
        # get nodes to be created
        missing = guids - existing
        # create nodes if need be
        if len(missing) > 0:
            # check which missing nodes can be created
            missing = set(self._ac.createable(node_type, missing))
            # create nodes
            self._backend.create(node_type, missing)
            # add bookkeeping triples
            self._backend.set(node_type, missing,
                self._backend.schema.predicate(ns.bsn.t_created), [time.time()])
            # add permission triples
            self._ac.create(node_type, missing)
        # return available nodes
        return existing | missing

## EOF ##
