Skip to content

Commit

Permalink
Initial draft
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Jan 7, 2025
1 parent ada7871 commit 3e5894a
Show file tree
Hide file tree
Showing 11 changed files with 764 additions and 0 deletions.
1 change: 1 addition & 0 deletions sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from enum import Enum

from . import scheduler # noqa: F401
from ._version import __version__, __version_tuple__ # noqa: F401

__array_api_version__ = "2022.12"
Expand Down
40 changes: 40 additions & 0 deletions sparse/scheduler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from .finch_logic import (
Aggregate,
Alias,
Deferred,
Field,
Immediate,
MapJoin,
Plan,
Produces,
Query,
Reformat,
Relabel,
Reorder,
Subquery,
Table,
)
from .optimize import optimize, propagate_map_queries
from .rewrite_tools import PostOrderDFS, PostWalk, PreWalk

__all__ = [
"Aggregate",
"Alias",
"Deferred",
"Field",
"Immediate",
"MapJoin",
"Plan",
"Produces",
"Query",
"Reformat",
"Relabel",
"Reorder",
"Subquery",
"Table",
"optimize",
"propagate_map_queries",
"PostOrderDFS",
"PostWalk",
"PreWalk",
]
79 changes: 79 additions & 0 deletions sparse/scheduler/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from textwrap import dedent

from .finch_logic import Alias, Deferred, Field, Immediate, LogicNode, MapJoin, Query, Reformat, Relabel, Reorder, Table


class PointwiseLowerer:
def __init__(self):
self.bound_idxs = []

Check warning on line 8 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L8

Added line #L8 was not covered by tests

def __call__(self, ex):
match ex:
case MapJoin(op, args) if isinstance(op, Immediate):
return f":({op.val}({','.join([self(arg) for arg in args])}))"
case Reorder(Relabel(arg, idxs_1), idxs_2) if isinstance(arg, Alias):
self.bound_idxs.append(idxs_1)
return f":({arg.name}[{','.join([idx.name if idx in idxs_2 else 1 for idx in idxs_1])}])"
case Reorder(arg, _) if isinstance(arg, Immediate):
return arg.val
case Immediate(val):
return val
case _:
raise Exception(f"Unrecognized logic: {ex}")

Check warning on line 22 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L11-L22

Added lines #L11 - L22 were not covered by tests


def compile_pointwise_logic(ex: LogicNode) -> tuple:
ctx = PointwiseLowerer()
code = ctx(ex)
return (code, ctx.bound_idxs)

Check warning on line 28 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L26-L28

Added lines #L26 - L28 were not covered by tests


def compile_logic_constant(ex):
match ex:
case Immediate(val):
return val
case Deferred(ex, type_):
return f":({ex}::{type_})"
case _:
raise Exception(f"Invalid constant: {ex}")

Check warning on line 38 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L32-L38

Added lines #L32 - L38 were not covered by tests


class LogicLowerer:
def __init__(self, mode: str = "fast"):
self.mode = mode

Check warning on line 43 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L43

Added line #L43 was not covered by tests

def __call__(self, ex):
match ex:
case Query(lhs, Table(tns, _)) if isinstance(lhs, Alias):
return f":({lhs.name} = {compile_logic_constant(tns)})"

Check warning on line 48 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L46-L48

Added lines #L46 - L48 were not covered by tests

case Query(lhs, Reformat(tns, Reorder(Relabel(arg, idxs_1), idxs_2))) if isinstance(

Check warning on line 50 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L50

Added line #L50 was not covered by tests
lhs, Alias
) and isinstance(arg, Alias):
loop_idxs = [idx.name for idx in withsubsequence(intersect(idxs_1, idxs_2), idxs_2)] # noqa: F821
lhs_idxs = [idx.name for idx in idxs_2]
(rhs, rhs_idxs) = compile_pointwise_logic(Reorder(Relabel(arg, idxs_1), idxs_2))
body = f":({lhs.name}[{','.join(lhs_idxs)}] = {rhs})"
for idx in loop_idxs:
if Field(idx) in rhs_idxs:
body = f":(for {idx} = _ \n {body} end)"
elif idx in lhs_idxs:
body = f":(for {idx} = 1:1 \n {body} end)"

Check warning on line 61 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L53-L61

Added lines #L53 - L61 were not covered by tests

result = f"""\

Check warning on line 63 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L63

Added line #L63 was not covered by tests
quote
{lhs.name} = {compile_logic_constant(tns)}
@finch mode = {self.mode} begin
{lhs.name} .= {tns.fill_value}
{body}
return {lhs.name}
end
end
"""
return dedent(result)

Check warning on line 73 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L73

Added line #L73 was not covered by tests


class LogicCompiler:
def __call__(self, prgm):
prgm = format_queries(prgm, True) # noqa: F821
return LogicLowerer()(prgm)

Check warning on line 79 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L78-L79

Added lines #L78 - L79 were not covered by tests
27 changes: 27 additions & 0 deletions sparse/scheduler/executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from .compiler import LogicCompiler


class LogicExecutor:
def __init__(self, ctx, verbose=False):
self.ctx: LogicCompiler = ctx
self.codes = {}
self.verbose = verbose

def __call__(self, prgm):
prgm_structure = prgm
if prgm_structure not in self.codes:
thunk = logic_executor_code(self.ctx, prgm)
self.codes[prgm_structure] = eval(thunk), thunk

Check warning on line 14 in sparse/scheduler/executor.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/executor.py#L11-L14

Added lines #L11 - L14 were not covered by tests

f, code = self.codes[prgm_structure]
if self.verbose:
print(code)
return f(prgm)

Check warning on line 19 in sparse/scheduler/executor.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/executor.py#L16-L19

Added lines #L16 - L19 were not covered by tests


def get_structure():
pass

Check warning on line 23 in sparse/scheduler/executor.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/executor.py#L23

Added line #L23 was not covered by tests


def logic_executor_code(ctx, prgm):
pass

Check warning on line 27 in sparse/scheduler/executor.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/executor.py#L27

Added line #L27 was not covered by tests
Loading

0 comments on commit 3e5894a

Please sign in to comment.