auto_stage_mem bounds arithmetic is wrong #698

Open
gilbo opened this issue Sep 3, 2024 · 2 comments
Labels: C: Prog Analysis (related to formal analysis, SMT, etc.), C: Scheduling (the scheduling language and APIs), T: Bug (something isn't working)

Comments

gilbo (Contributor) commented Sep 3, 2024

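The bounds computed by auto_stage_mem are wrong, and somehow stage_mem is not throwing an error either. In the example below, B_reg is allocated with 4 + 4 * 4 - (0 + 4 * 0) = 20 elements, although the index ji + 4 * jo only ranges over 0..15 (16 elements).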

Scheduling call

```python
neon = auto_stage_mem(neon, neon.find_loop("i"), "B", "B_reg")
```

Before code

```python
def rank_k_reduce_6x16_scheduled(K: size, A: f32[6, K] @ DRAM,
                                 B: f32[K, 16] @ DRAM, C: f32[6, 16] @ DRAM):
    C_reg: f32[6, 4, 4] @ Neon
    for i0 in seq(0, 6):
        for jo in seq(0, 4):
            neon_vld_4xf32(C_reg[i0, jo, 0:4], C[i0, 4 * jo:4 + 4 * jo])
    for k in seq(0, K):
        for i in seq(0, 6):
            for jo in seq(0, 4):
                for ji in seq(0, 4):
                    C_reg[i, jo, ji] += A[i, k] * B[k, ji + 4 * jo]
    for i0 in seq(0, 6):
        for jo in seq(0, 4):
            neon_vst_4xf32(C[i0, 4 * jo:4 + 4 * jo], C_reg[i0, jo, 0:4])
```

After code

```python
def rank_k_reduce_6x16_scheduled(K: size, A: f32[6, K] @ DRAM,
                                 B: f32[K, 16] @ DRAM, C: f32[6, 16] @ DRAM):
    C_reg: f32[6, 4, 4] @ Neon
    for i0 in seq(0, 6):
        for jo in seq(0, 4):
            neon_vld_4xf32(C_reg[i0, jo, 0:4], C[i0, 4 * jo:4 + 4 * jo])
    for k in seq(0, K):
        B_reg: f32[4 + 4 * 4 - (0 + 4 * 0)] @ DRAM
        for i0 in seq(0, 4 + 4 * 4 - (0 + 4 * 0)):
            if i0 + (0 + 4 * 0) < 16:
                B_reg[i0] = B[k, i0 + (0 + 4 * 0)]
        for i in seq(0, 6):
            for jo in seq(0, 4):
                for ji in seq(0, 4):
                    C_reg[i, jo,
                          ji] += A[i, k] * B_reg[ji + 4 * jo - (0 + 4 * 0)]
    for i0 in seq(0, 6):
        for jo in seq(0, 4):
            neon_vst_4xf32(C[i0, 4 * jo:4 + 4 * jo], C_reg[i0, jo, 0:4])
```
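To make the expected result concrete, here is a small standalone check in plain Python (not Exo) of the column range of B that `B[k, ji + 4 * jo]` touches in one iteration of the k loop, compared with the size generated for B_reg. The constant names are illustrative only.

```python
# Loop extents from the "Before" code: ji in seq(0, 4), jo in seq(0, 4).
# seq(lo, hi) visits lo, lo + 1, ..., hi - 1, i.e. the upper bound is exclusive.
JI_EXTENT = 4
JO_EXTENT = 4

# Enumerate every column index of B read inside one iteration of the k loop.
touched = {ji + 4 * jo for jo in range(JO_EXTENT) for ji in range(JI_EXTENT)}

lo, hi = min(touched), max(touched)  # inclusive bounds of the accessed window
needed = hi - lo + 1                 # number of elements B_reg actually needs

# Size emitted for B_reg in the "After" code.
generated = 4 + 4 * 4 - (0 + 4 * 0)

print(f"window [{lo}, {hi}], needed {needed}, generated {generated}")
```

The generated size equals what one gets by plugging the exclusive loop bounds (4 for both ji and jo) into ji + 4 * jo. That hints at, but does not by itself prove, an off-by-one in how upper bounds are handled.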
gilbo (Contributor, Author) commented Sep 3, 2024

Note that this bug is present in the pip release. It may be worth fixing it on a branch so that the fix can ship as a patch release (the Patch component of Major.Minor.Patch).

gilbo (Contributor, Author) commented Sep 3, 2024

Problem 1:
Here is at least one problem with the interval analysis in auto_stage_mem:

```python
rng_l = join(lhs_rng[0], expr.op(), rhs_rng[0])
rng_r = join(lhs_rng[1], expr.op(), rhs_rng[1])
if is_add(proc, expr) or is_sub(proc, expr):
    return (rng_l, rng_r)
```

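This is incorrect for subtraction. One expects (alo, ahi) - (blo, bhi) = (alo - bhi, ahi - blo), not (alo - blo, ahi - bhi): the smallest difference pairs the smallest a with the largest b, and the largest difference pairs the largest a with the smallest b. You can work this out intuitively, and the Wikipedia article on interval arithmetic confirms the rule.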
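A minimal standalone sketch of the correct rules, assuming closed (inclusive) integer intervals represented as (lo, hi) tuples. The helper names iadd, isub, isub_buggy and imul_const are hypothetical and not part of auto_stage_mem; isub_buggy reproduces the pairing in the quoted code for comparison.

```python
from typing import Tuple

# Hypothetical representation: closed integer interval (lo, hi) with lo <= hi.
Interval = Tuple[int, int]


def iadd(a: Interval, b: Interval) -> Interval:
    # Addition is increasing in both operands: lows pair with lows, highs with highs.
    return (a[0] + b[0], a[1] + b[1])


def isub(a: Interval, b: Interval) -> Interval:
    # Subtraction is decreasing in b: the smallest difference is a.lo - b.hi,
    # the largest is a.hi - b.lo.
    return (a[0] - b[1], a[1] - b[0])


def isub_buggy(a: Interval, b: Interval) -> Interval:
    # The pairing from the quoted code: reuses the addition rule for subtraction.
    return (a[0] - b[0], a[1] - b[1])


def imul_const(a: Interval, c: int) -> Interval:
    # A negative constant swaps the endpoints, so take min/max of both products.
    x, y = a[0] * c, a[1] * c
    return (min(x, y), max(x, y))


if __name__ == "__main__":
    a, b = (0, 10), (2, 5)
    print(isub(a, b))        # (-5, 8)
    print(isub_buggy(a, b))  # (-2, 5), which misses e.g. 10 - 2 = 8
    # Index range of ji + 4 * jo with ji, jo in [0, 3]:
    print(iadd((0, 3), imul_const((0, 3), 4)))  # (0, 15)
```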


Problem 2: (less critical, but still a bug)

```python
if is_literal(proc, expr.rhs()): ...
elif is_literal(proc, expr.lhs()): ...
else: assert False, "Unreachable case"
```

The "Unreachable case" assertion is reachable: typechecking only checks that one factor of a multiplication is a constant expression (i.e., an expression of type T.int), not, more strictly, that it is a literal. What is needed is an inspection routine that checks whether an expression is constant rather than merely literal, plus a way to query the value of a constant expression that is not a literal. Both ought to be built in.
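Exo's internal IR classes are not shown in this issue, so the sketch below uses a hypothetical miniature expression AST (Const, Var, BinOp) to illustrate the two routines described above: a check that an expression is constant, and an evaluator that returns its value even when it is not a single literal. None of these names are Exo APIs.

```python
from __future__ import annotations

import operator
from dataclasses import dataclass
from typing import Optional, Union


# Hypothetical mini-IR for illustration only; these are not Exo's classes.
@dataclass(frozen=True)
class Const:
    val: int


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class BinOp:
    op: str  # one of "+", "-", "*"
    lhs: Expr
    rhs: Expr


Expr = Union[Const, Var, BinOp]

_OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul}


def const_value(e: Expr) -> Optional[int]:
    """Return the value of e if it is a constant expression, else None."""
    if isinstance(e, Const):
        return e.val
    if isinstance(e, BinOp):
        lv, rv = const_value(e.lhs), const_value(e.rhs)
        if lv is not None and rv is not None:
            return _OPS[e.op](lv, rv)
    # Variables, and anything built from them, are not constant.
    return None


def is_constant(e: Expr) -> bool:
    return const_value(e) is not None
```

For example, `BinOp("*", Const(4), Const(0))` is constant with value 0 even though it is not a literal, while `BinOp("*", Const(4), Var("jo"))` is not constant. A real implementation would also need to decide how to treat nodes such as config reads.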

Proposed low-effort fix for Problem 2: throw an exception whose message suggests that the user simplify the proc before calling auto_stage_mem, since simplification should collapse all constant sub-expressions into literal nodes.
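A minimal sketch of that fix, assuming a simplified calling convention: the function receives the interval of the non-constant factor plus the literal value of each operand (None when that operand is not a literal node). AutoStageMemError and mul_bounds are hypothetical names; the real code would use Exo's own IR helpers and error type.

```python
from typing import Optional, Tuple

# Closed integer interval (lo, hi).
Interval = Tuple[int, int]


class AutoStageMemError(Exception):
    """Hypothetical error type; Exo's own exception class may differ."""


def mul_bounds(rng: Interval, lhs_lit: Optional[int],
               rhs_lit: Optional[int]) -> Interval:
    """Bound x * c, where rng bounds the non-constant factor x.

    lhs_lit / rhs_lit are the literal values of the operands, or None when
    the operand is not a literal node.
    """
    if rhs_lit is not None:
        c = rhs_lit
    elif lhs_lit is not None:
        c = lhs_lit
    else:
        # Replaces `assert False, "Unreachable case"`: typechecking only
        # guarantees a constant factor, not a literal one.
        raise AutoStageMemError(
            "auto_stage_mem: multiplication has no literal factor; try "
            "simplifying the proc before calling auto_stage_mem so that "
            "constant sub-expressions are folded into literals."
        )
    lo, hi = rng[0] * c, rng[1] * c
    # A negative constant swaps the endpoints.
    return (min(lo, hi), max(lo, hi))
```

Raising a descriptive error keeps the low-effort fix small while pointing the user at a workaround, instead of failing with an internal assertion.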

NOTE: The other "Unreachable case" assertion is similarly reachable, at least via a ReadConfig, if not by other means.
