From 3e5894afb9af6b07c06abdadbc56c9b9ecc19870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Sun, 5 Jan 2025 17:37:28 +0000 Subject: [PATCH] Initial draft --- sparse/__init__.py | 1 + sparse/scheduler/__init__.py | 40 +++ sparse/scheduler/compiler.py | 79 ++++++ sparse/scheduler/executor.py | 27 ++ sparse/scheduler/finch_logic.py | 332 +++++++++++++++++++++++ sparse/scheduler/lazy.py | 107 ++++++++ sparse/scheduler/optimize.py | 55 ++++ sparse/scheduler/rewrite_tools.py | 103 +++++++ sparse/scheduler/tests/__init__.py | 0 sparse/scheduler/tests/conftest.py | 0 sparse/scheduler/tests/test_scheduler.py | 20 ++ 11 files changed, 764 insertions(+) create mode 100644 sparse/scheduler/__init__.py create mode 100644 sparse/scheduler/compiler.py create mode 100644 sparse/scheduler/executor.py create mode 100644 sparse/scheduler/finch_logic.py create mode 100644 sparse/scheduler/lazy.py create mode 100644 sparse/scheduler/optimize.py create mode 100644 sparse/scheduler/rewrite_tools.py create mode 100644 sparse/scheduler/tests/__init__.py create mode 100644 sparse/scheduler/tests/conftest.py create mode 100644 sparse/scheduler/tests/test_scheduler.py diff --git a/sparse/__init__.py b/sparse/__init__.py index 5def04ac..d5cf5be8 100644 --- a/sparse/__init__.py +++ b/sparse/__init__.py @@ -2,6 +2,7 @@ import warnings from enum import Enum +from . import scheduler # noqa: F401 from ._version import __version__, __version_tuple__ # noqa: F401 __array_api_version__ = "2022.12" diff --git a/sparse/scheduler/__init__.py b/sparse/scheduler/__init__.py new file mode 100644 index 00000000..1836890a --- /dev/null +++ b/sparse/scheduler/__init__.py @@ -0,0 +1,40 @@ +from .finch_logic import ( + Aggregate, + Alias, + Deferred, + Field, + Immediate, + MapJoin, + Plan, + Produces, + Query, + Reformat, + Relabel, + Reorder, + Subquery, + Table, +) +from .optimize import optimize, propagate_map_queries +from .rewrite_tools import PostOrderDFS, PostWalk, PreWalk + +__all__ = [ + "Aggregate", + "Alias", + "Deferred", + "Field", + "Immediate", + "MapJoin", + "Plan", + "Produces", + "Query", + "Reformat", + "Relabel", + "Reorder", + "Subquery", + "Table", + "optimize", + "propagate_map_queries", + "PostOrderDFS", + "PostWalk", + "PreWalk", +] diff --git a/sparse/scheduler/compiler.py b/sparse/scheduler/compiler.py new file mode 100644 index 00000000..ea314bd0 --- /dev/null +++ b/sparse/scheduler/compiler.py @@ -0,0 +1,79 @@ +from textwrap import dedent + +from .finch_logic import Alias, Deferred, Field, Immediate, LogicNode, MapJoin, Query, Reformat, Relabel, Reorder, Table + + +class PointwiseLowerer: + def __init__(self): + self.bound_idxs = [] + + def __call__(self, ex): + match ex: + case MapJoin(op, args) if isinstance(op, Immediate): + return f":({op.val}({','.join([self(arg) for arg in args])}))" + case Reorder(Relabel(arg, idxs_1), idxs_2) if isinstance(arg, Alias): + self.bound_idxs.append(idxs_1) + return f":({arg.name}[{','.join([idx.name if idx in idxs_2 else 1 for idx in idxs_1])}])" + case Reorder(arg, _) if isinstance(arg, Immediate): + return arg.val + case Immediate(val): + return val + case _: + raise Exception(f"Unrecognized logic: {ex}") + + +def compile_pointwise_logic(ex: LogicNode) -> tuple: + ctx = PointwiseLowerer() + code = ctx(ex) + return (code, ctx.bound_idxs) + + +def compile_logic_constant(ex): + match ex: + case Immediate(val): + return val + case Deferred(ex, type_): + return f":({ex}::{type_})" + case _: + raise Exception(f"Invalid constant: {ex}") + + +class 
LogicLowerer: + def __init__(self, mode: str = "fast"): + self.mode = mode + + def __call__(self, ex): + match ex: + case Query(lhs, Table(tns, _)) if isinstance(lhs, Alias): + return f":({lhs.name} = {compile_logic_constant(tns)})" + + case Query(lhs, Reformat(tns, Reorder(Relabel(arg, idxs_1), idxs_2))) if isinstance( + lhs, Alias + ) and isinstance(arg, Alias): + loop_idxs = [idx.name for idx in withsubsequence(intersect(idxs_1, idxs_2), idxs_2)] # noqa: F821 + lhs_idxs = [idx.name for idx in idxs_2] + (rhs, rhs_idxs) = compile_pointwise_logic(Reorder(Relabel(arg, idxs_1), idxs_2)) + body = f":({lhs.name}[{','.join(lhs_idxs)}] = {rhs})" + for idx in loop_idxs: + if Field(idx) in rhs_idxs: + body = f":(for {idx} = _ \n {body} end)" + elif idx in lhs_idxs: + body = f":(for {idx} = 1:1 \n {body} end)" + + result = f"""\ + quote + {lhs.name} = {compile_logic_constant(tns)} + @finch mode = {self.mode} begin + {lhs.name} .= {tns.fill_value} + {body} + return {lhs.name} + end + end + """ + return dedent(result) + + +class LogicCompiler: + def __call__(self, prgm): + prgm = format_queries(prgm, True) # noqa: F821 + return LogicLowerer()(prgm) diff --git a/sparse/scheduler/executor.py b/sparse/scheduler/executor.py new file mode 100644 index 00000000..b2760307 --- /dev/null +++ b/sparse/scheduler/executor.py @@ -0,0 +1,27 @@ +from .compiler import LogicCompiler + + +class LogicExecutor: + def __init__(self, ctx, verbose=False): + self.ctx: LogicCompiler = ctx + self.codes = {} + self.verbose = verbose + + def __call__(self, prgm): + prgm_structure = prgm + if prgm_structure not in self.codes: + thunk = logic_executor_code(self.ctx, prgm) + self.codes[prgm_structure] = eval(thunk), thunk + + f, code = self.codes[prgm_structure] + if self.verbose: + print(code) + return f(prgm) + + +def get_structure(): + pass + + +def logic_executor_code(ctx, prgm): + pass diff --git a/sparse/scheduler/finch_logic.py b/sparse/scheduler/finch_logic.py new file mode 100644 index 00000000..e9b08f38 --- /dev/null +++ b/sparse/scheduler/finch_logic.py @@ -0,0 +1,332 @@ +from abc import abstractmethod +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any + +# IS_TREE = 1 +# IS_STATEFUL = 2 +# ID = 4 +# class LogicNodeKind: +# immediate = 0 * ID +# deferred = 1 * ID +# field = 2 * ID +# alias = 3 * ID +# table = 4 * ID | IS_TREE +# mapjoin = 5 * ID | IS_TREE +# aggregate = 6 * ID | IS_TREE +# reorder = 7 * ID | IS_TREE +# relabel = 8 * ID | IS_TREE +# reformat = 9 * ID | IS_TREE +# subquery = 10 * ID | IS_TREE +# query = 11 * ID | IS_TREE | IS_STATEFUL +# produces = 12 * ID | IS_TREE | IS_STATEFUL +# plan = 13 * ID | IS_TREE | IS_STATEFUL + + +@dataclass(eq=True, frozen=True) +class LogicNode: + @staticmethod + @abstractmethod + def is_tree(): ... + + @staticmethod + @abstractmethod + def is_stateful(): ... + + @abstractmethod + def get_arguments(self): ... + + @classmethod + @abstractmethod + def from_arguments(cls): ... 
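+# The `is_tree`/`get_arguments`/`from_arguments` trio is the traversal protocol that the
+# generic walkers in `rewrite_tools.py` build on: a tree node can be decomposed into its
+# children and rebuilt from (possibly rewritten) children. For example (the variable
+# name `m` here is purely illustrative):
+#
+#     m = MapJoin(Immediate(operator.add), (Alias("A"), Alias("B")))
+#     assert m == MapJoin.from_arguments(*m.get_arguments())
+#
+# Leaf nodes (where `is_tree()` is False) only expose their fields through
+# `get_arguments` and are not recursed into by `PostOrderDFS`.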
+ + +@dataclass(eq=True, frozen=True) +class Immediate(LogicNode): + val: Any + + @staticmethod + def is_tree(): + return False + + @staticmethod + def is_stateful(): + return False + + def get_arguments(self): + return [self.val] + + @classmethod + def from_arguments(cls, val): + return cls(val) + + +@dataclass(eq=True, frozen=True) +class Deferred(LogicNode): + ex: Any + type_: Any + + @staticmethod + def is_tree(): + return False + + @staticmethod + def is_stateful(): + return False + + def get_arguments(self): + return [self.val, self.type_] + + @classmethod + def from_arguments(cls, val, type_): + return cls(val, type_) + + +@dataclass(eq=True, frozen=True) +class Field(LogicNode): + name: str + + @staticmethod + def is_tree(): + return False + + @staticmethod + def is_stateful(): + return False + + def get_arguments(self): + return [self.name] + + @classmethod + def from_arguments(cls, name): + return cls(name) + + +@dataclass(eq=True, frozen=True) +class Alias(LogicNode): + name: str + + @staticmethod + def is_tree(): + return False + + @staticmethod + def is_stateful(): + return False + + def get_arguments(self): + return [self.name] + + @classmethod + def from_arguments(cls, name): + return cls(name) + + +@dataclass(eq=True, frozen=True) +class Table(LogicNode): + tns: Any + idxs: Iterable[Any] + + @staticmethod + def is_tree(): + return True + + @staticmethod + def is_stateful(): + return False + + def get_arguments(self): + return [self.tns, *self.idxs] + + @classmethod + def from_arguments(cls, tns, *idxs): + return cls(tns, idxs) + + +@dataclass(eq=True, frozen=True) +class MapJoin(LogicNode): + op: Any + args: Iterable[Any] + + @staticmethod + def is_tree(): + return True + + @staticmethod + def is_stateful(): + return False + + def get_arguments(self): + return [self.op, *self.args] + + @classmethod + def from_arguments(cls, op, *args): + return cls(op, args) + + +@dataclass(eq=True, frozen=True) +class Aggregate(LogicNode): + op: Any + init: Any + arg: Any + idxs: Iterable[Any] + + @staticmethod + def is_tree(): + return True + + @staticmethod + def is_stateful(): + return False + + def get_arguments(self): + return [self.op, self.init, self.arg, *self.idxs] + + @classmethod + def from_arguments(cls, op, init, arg, *idxs): + return cls(op, init, arg, idxs) + + +@dataclass(eq=True, frozen=True) +class Reorder(LogicNode): + arg: Any + idxs: Iterable[Any] + + @staticmethod + def is_tree(): + return True + + @staticmethod + def is_stateful(): + return False + + def get_arguments(self): + return [self.arg, *self.idxs] + + @classmethod + def from_arguments(cls, arg, *idxs): + return cls(arg, idxs) + + +@dataclass(eq=True, frozen=True) +class Relabel(LogicNode): + arg: Any + idxs: Iterable[Any] + + @staticmethod + def is_tree(): + return True + + @staticmethod + def is_stateful(): + return False + + def get_arguments(self): + return [self.arg, *self.idxs] + + @classmethod + def from_arguments(cls, arg, *idxs): + return cls(arg, idxs) + + +@dataclass(eq=True, frozen=True) +class Reformat(LogicNode): + tns: Any + arg: Any + + @staticmethod + def is_tree(): + return True + + @staticmethod + def is_stateful(): + return False + + def get_arguments(self): + return [self.tns, self.arg] + + @classmethod + def from_arguments(cls, tns, arg): + return cls(tns, arg) + + +@dataclass(eq=True, frozen=True) +class Subquery(LogicNode): + lhs: Any + arg: Any + + @staticmethod + def is_tree(): + return True + + @staticmethod + def is_stateful(): + return False + + def get_arguments(self): + return 
[self.lhs, self.arg]
+
+    @classmethod
+    def from_arguments(cls, lhs, arg):
+        return cls(lhs, arg)
+
+
+@dataclass(eq=True, frozen=True)
+class Query(LogicNode):
+    lhs: Any
+    rhs: Any
+
+    @staticmethod
+    def is_tree():
+        return True
+
+    @staticmethod
+    def is_stateful():
+        return True
+
+    def get_arguments(self):
+        return [self.lhs, self.rhs]
+
+    @classmethod
+    def from_arguments(cls, lhs, rhs):
+        return cls(lhs, rhs)
+
+
+@dataclass(eq=True, frozen=True)
+class Produces(LogicNode):
+    args: Iterable[Any]
+
+    @staticmethod
+    def is_tree():
+        return True
+
+    @staticmethod
+    def is_stateful():
+        return True
+
+    def get_arguments(self):
+        return [*self.args]
+
+    @classmethod
+    def from_arguments(cls, *args):
+        return cls(args)
+
+
+@dataclass(eq=True, frozen=True)
+class Plan(LogicNode):
+    bodies: Iterable[Any] = ()
+
+    @staticmethod
+    def is_tree():
+        return True
+
+    @staticmethod
+    def is_stateful():
+        return True
+
+    def get_arguments(self):
+        return [*self.bodies]
+
+    @classmethod
+    def from_arguments(cls, *args):
+        return cls(args)
diff --git a/sparse/scheduler/lazy.py b/sparse/scheduler/lazy.py
new file mode 100644
index 00000000..ec3c31c3
--- /dev/null
+++ b/sparse/scheduler/lazy.py
@@ -0,0 +1,107 @@
+import operator
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any
+
+from .compiler import LogicCompiler
+from .executor import LogicExecutor
+from .finch_logic import (
+    Aggregate,
+    Alias,
+    Field,
+    Immediate,
+    MapJoin,
+    Plan,
+    Produces,
+    Query,
+    Relabel,
+    Reorder,
+    Subquery,
+    Table,
+)
+from .optimize import DefaultLogicOptimizer
+from .rewrite_tools import gensym
+
+
+@dataclass
+class LazyTensor:
+    data: Any
+    extrude: tuple
+    fill_value: Any
+
+    @property
+    def ndim(self) -> int:
+        return len(self.extrude)
+
+
+def lazy(arr) -> LazyTensor:
+    name = Alias(gensym("A"))
+    idxs = [Field(gensym("i")) for _ in range(arr.ndim)]
+    extrude = tuple(arr.shape[i] == 1 for i in range(arr.ndim))
+    tns = Subquery(name, Table(Immediate(arr), idxs))
+    return LazyTensor(tns, extrude, arr.fill_value)
+
+
+def get_at_idxs(arr, idxs):
+    return [arr[i] for i in idxs]
+
+
+def permute_dims(arg: LazyTensor, perm) -> LazyTensor:
+    idxs = [Field(gensym("i")) for _ in range(arg.ndim)]
+    return LazyTensor(
+        Reorder(Relabel(arg.data, idxs), get_at_idxs(idxs, perm)),
+        tuple(get_at_idxs(arg.extrude, perm)),
+        arg.fill_value,
+    )
+
+
+def identify(data):
+    lhs = Alias(gensym("A"))
+    return Subquery(lhs, data)
+
+
+def reduce(op: Callable, arg: LazyTensor, dims=..., fill_value=0.0) -> LazyTensor:
+    dims = list(range(arg.ndim) if dims is ... else dims)
+    extrude = tuple(arg.extrude[n] for n in range(arg.ndim) if n not in dims)
+    fields = [Field(gensym("i")) for _ in range(arg.ndim)]
+    data = Aggregate(Immediate(op), Immediate(fill_value), Relabel(arg.data, fields), [fields[i] for i in dims])
+    return LazyTensor(identify(data), extrude, fill_value)
+
+
+def map(f: Callable, src: LazyTensor, *args) -> LazyTensor:
+    largs = [src, *args]
+    extrude = next(filter(lambda x: len(x.extrude) > 0, largs), largs[0]).extrude
+    idxs = [Field(gensym("i")) for _ in extrude]
+    ldatas = []
+    for larg in largs:
+        if larg.extrude == extrude:
+            ldatas.append(Relabel(larg.data, idxs))
+        elif larg.extrude == ():
+            ldatas.append(Relabel(larg.data, ()))
+        else:
+            raise Exception("Cannot map across arrays with different sizes.")
+    new_fill_value = f(*[x.fill_value for x in largs])
+    data = MapJoin(Immediate(f), ldatas)
+    return LazyTensor(identify(data), extrude, new_fill_value)
+
+
+def prod(arr: LazyTensor, dims) -> LazyTensor:
+    return reduce(operator.mul, arr, dims, arr.fill_value)
+
+
+def multiply(x1: LazyTensor, x2: LazyTensor) -> LazyTensor:
+    return map(operator.mul, x1, x2)
+
+
+_ds = LogicExecutor(DefaultLogicOptimizer(LogicCompiler()))
+
+
+def get_default_scheduler():
+    return _ds
+
+
+def compute(*args, ctx=_ds):
+    vars = tuple(Alias(gensym("A")) for _ in args)
+    bodies = tuple(Query(var, arg.data) for var, arg in zip(vars, args))
+    prgm = Plan(bodies + (Produces(vars),))
+    return ctx(prgm)
diff --git a/sparse/scheduler/optimize.py b/sparse/scheduler/optimize.py
new file mode 100644
index 00000000..f7390974
--- /dev/null
+++ b/sparse/scheduler/optimize.py
@@ -0,0 +1,55 @@
+from .compiler import LogicCompiler
+from .finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query
+from .rewrite_tools import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite
+
+
+def optimize(prgm: LogicNode) -> LogicNode:
+    # ...
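+    # Placeholder: further rewrite passes would be applied here; for now the
+    # optimizer only propagates map queries.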
+ return propagate_map_queries(prgm) + + +def get_productions(root: LogicNode) -> LogicNode: + for node in PostOrderDFS(root): + if isinstance(node, Produces): + return [arg for arg in PostOrderDFS(node) if isinstance(arg, Alias)] + return [] + + +def propagate_map_queries(root: LogicNode) -> LogicNode: + def rule_agg_to_mapjoin(ex): + match ex: + case Aggregate(op, init, arg, ()): + return MapJoin(op, (init, arg)) + + root = Rewrite(PostWalk(rule_agg_to_mapjoin))(root) + rets = get_productions(root) + props = {} + for node in PostOrderDFS(root): + match node: + case Query(a, MapJoin(op, args)) if a not in rets: + props[a] = MapJoin(op, args) + + def rule_0(ex): + return props.get(ex) + + def rule_1(ex): + match ex: + case Query(a, _) if a in props: + return Plan(()) + + def rule_2(ex): + match ex: + case Plan(args) if Plan(()) in args: + return Plan(tuple(a for a in args if a != Plan(()))) + + root = Rewrite(PreWalk(Chain([rule_0, rule_1])))(root) + return Rewrite(PostWalk(rule_2))(root) + + +class DefaultLogicOptimizer: + def __init__(self, ctx: LogicCompiler): + self.ctx = ctx + + def __call__(self, prgm: LogicNode): + prgm = optimize(prgm) + return self.ctx(prgm) diff --git a/sparse/scheduler/rewrite_tools.py b/sparse/scheduler/rewrite_tools.py new file mode 100644 index 00000000..a01cd22e --- /dev/null +++ b/sparse/scheduler/rewrite_tools.py @@ -0,0 +1,103 @@ +from collections.abc import Callable, Iterable, Iterator + +from .finch_logic import LogicNode + +RwCallable = Callable[[LogicNode], LogicNode | None] + + +class SymbolGenerator: + counter: int = 0 + + @classmethod + def gensym(cls, name: str) -> str: + sym = f"#{name}#{cls.counter}" + cls.counter += 1 + return sym + + +_sg = SymbolGenerator() +gensym: Callable[[str], str] = _sg.gensym + + +def get_or_else(x: LogicNode | None, y: LogicNode) -> LogicNode: + return x if x is not None else y + + +def PostOrderDFS(node: LogicNode) -> Iterator[LogicNode]: + if node.is_tree(): + for arg in node.get_arguments(): + yield from PostOrderDFS(arg) + yield node + + +class Rewrite: + def __init__(self, rw: RwCallable): + self.rw = rw + + def __call__(self, x: LogicNode) -> LogicNode: + return get_or_else(self.rw(x), x) + + +class PreWalk: + def __init__(self, rw: RwCallable): + self.rw = rw + + def __call__(self, x: LogicNode) -> LogicNode | None: + y = self.rw(x) + if y is not None: + if y.is_tree(): + args = y.get_arguments() + return y.from_arguments(*[get_or_else(self(arg), arg) for arg in args]) + return y + if x.is_tree(): + args = x.get_arguments() + new_args = list(map(self, args)) + if not all(arg is None for arg in new_args): + return x.from_arguments(*map(lambda x1, x2: get_or_else(x1, x2), new_args, args)) + return None + return None + + +class PostWalk: + def __init__(self, rw: RwCallable): + self.rw = rw + + def __call__(self, x: LogicNode) -> LogicNode | None: + if x.is_tree(): + args = x.get_arguments() + new_args = list(map(self, args)) + if all(arg is None for arg in new_args): + return self.rw(x) + y = x.from_arguments(*map(lambda x1, x2: get_or_else(x1, x2), new_args, args)) + return get_or_else(self.rw(y), y) + return self.rw(x) + + +class Chain: + def __init__(self, rws: Iterable[RwCallable]): + self.rws = rws + + def __call__(self, x: LogicNode) -> LogicNode | None: + is_success = False + for rw in self.rws: + y = rw(x) + if y is not None: + is_success = True + x = y + if is_success: + return x + return None + + +class Fixpoint: + def __init__(self, rw: RwCallable): + self.rw = rw + + def __call__(self, x: LogicNode) -> 
LogicNode | None:
+        y = self.rw(x)
+        if y is not None:
+            while y is not None and x != y:
+                x = y
+                y = self.rw(x)
+            return x
+        return None
diff --git a/sparse/scheduler/tests/__init__.py b/sparse/scheduler/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/sparse/scheduler/tests/conftest.py b/sparse/scheduler/tests/conftest.py
new file mode 100644
index 00000000..e69de29b
diff --git a/sparse/scheduler/tests/test_scheduler.py b/sparse/scheduler/tests/test_scheduler.py
new file mode 100644
index 00000000..113d5c3e
--- /dev/null
+++ b/sparse/scheduler/tests/test_scheduler.py
@@ -0,0 +1,20 @@
+from sparse.scheduler import Aggregate, Alias, Immediate, MapJoin, Plan, Produces, Query, propagate_map_queries
+
+
+def test_simple():
+    plan = Plan(
+        (
+            Query(Alias("A10"), Aggregate(Immediate("+"), Immediate(0), Immediate("[1,2,3]"), ())),
+            Query(Alias("A11"), Alias("A10")),
+            Produces((Alias("A11"),)),
+        )
+    )
+    expected = Plan(
+        (
+            Query(Alias("A11"), MapJoin(Immediate("+"), (Immediate(0), Immediate("[1,2,3]")))),
+            Produces((Alias("A11"),)),
+        )
+    )
+
+    result = propagate_map_queries(plan)
+    assert result == expected
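
Note for reviewers: below is a small, self-contained sketch of how the pieces added in this draft are meant to be driven. It only exercises what is already implemented here (the logic IR, the walkers, and propagate_map_queries); the variable names and the toy expression are illustrative, not part of the API, and the end-to-end compute()/executor path is still stubbed out.

    from sparse.scheduler import (
        Aggregate,
        Alias,
        Immediate,
        MapJoin,
        Plan,
        PostOrderDFS,
        Produces,
        Query,
        propagate_map_queries,
    )

    # A toy plan: A := reduce over no indices of (x * y), then produce A.
    x = Immediate("x")
    y = Immediate("y")
    plan = Plan(
        (
            Query(
                Alias("A"),
                Aggregate(Immediate("+"), Immediate(0), MapJoin(Immediate("*"), (x, y)), ()),
            ),
            Produces((Alias("A"),)),
        )
    )

    # An Aggregate over zero indices is folded into a MapJoin with its init value.
    optimized = propagate_map_queries(plan)
    print(optimized)

    # PostOrderDFS yields children before their parents, which is the order the
    # rewriters rely on.
    for node in PostOrderDFS(plan):
        print(type(node).__name__)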