Skip to content

Conversation

@LeiWang1999
Copy link
Contributor

  • Added modular set analysis to ConstIntBoundAnalyzer for tighter bounds when min_value equals max_value.
  • Introduced ComputeGCD function to calculate the GCD of two integers.
  • Updated Combine functions in IntervalSet to accept operation nodes for better type handling.
  • Enhanced tests for modular set bounds in both const integer bounds and interval sets.

- Added modular set analysis to ConstIntBoundAnalyzer for tighter bounds when min_value equals max_value.
- Introduced ComputeGCD function to calculate the GCD of two integers.
- Updated Combine functions in IntervalSet to accept operation nodes for better type handling.
- Enhanced tests for modular set bounds in both const integer bounds and interval sets.
if (b.min_value > 0) {
int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
// Try to get tighter bounds using modular set information
if (parent_ && b.min_value == b.max_value) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we have the bound analysis already in IntervalSet, is the const int bound still necessary? just want to get a sense of if we need to introduce tihs bound

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from my understanding, const int bound is faster to analysis, and IntervalSet can build on those constant bounds for further analysis. keeping them separate makes the design clearer in my view.

@tqchen
Copy link
Member

tqchen commented Oct 2, 2025

@LeiWang1999 would be great to followup on the notes and get it in

@LeiWang1999
Copy link
Contributor Author

@tqchen sure, sorry for the delay. I'll work on this tomorrow.

@tqchen
Copy link
Member

tqchen commented Oct 13, 2025

@tvm-bot rerun

@tqchen
Copy link
Member

tqchen commented Oct 14, 2025

we are close, @LeiWang1999 please fix the remaining example, likely just need to update the after case

    TVM_FFI_THROW(ValueError) << oss.str();
[2025-10-13T17:27:54.414Z] E   ValueError: StructuralEqual check failed, caused by lhs at <root>.body.block.body.seq[0].body.body.body.body.block.body.extent:
[2025-10-13T17:27:54.414Z] E   # from tvm.script import tir as T
[2025-10-13T17:27:54.414Z] E   
[2025-10-13T17:27:54.414Z] E   @T.prim_func
[2025-10-13T17:27:54.414Z] E   def main(var_x: T.handle, var_adaptive_pool_avg: T.handle):
[2025-10-13T17:27:54.414Z] E       T.func_attr({"tir.noalias": True})
[2025-10-13T17:27:54.414Z] E       x = T.match_buffer(var_x, (1, 1024, 16, 40))
[2025-10-13T17:27:54.414Z] E       adaptive_pool_avg = T.match_buffer(var_adaptive_pool_avg, (1, 1024, 12, 30))
[2025-10-13T17:27:54.414Z] E       with T.block("root"):
[2025-10-13T17:27:54.414Z] E           T.reads()
[2025-10-13T17:27:54.414Z] E           T.writes()
[2025-10-13T17:27:54.414Z] E           adaptive_pool_sum = T.alloc_buffer((1, 1024, 12, 30))
[2025-10-13T17:27:54.414Z] E           for ax0 in range(1):
[2025-10-13T17:27:54.414Z] E               for ax1 in range(1024):
[2025-10-13T17:27:54.414Z] E                   for ax2 in range(12):
[2025-10-13T17:27:54.414Z] E                       for ax3 in range(30):
[2025-10-13T17:27:54.414Z] E                           with T.block("adaptive_pool_sum_l1"):
[2025-10-13T17:27:54.414Z] E                               v_ax0 = T.axis.spatial(1, ax0)
[2025-10-13T17:27:54.414Z] E                               v_ax1 = T.axis.spatial(1024, ax1)
[2025-10-13T17:27:54.414Z] E                               v_ax2 = T.axis.spatial(12, ax2)
[2025-10-13T17:27:54.414Z] E                               v_ax3 = T.axis.spatial(30, ax3)
[2025-10-13T17:27:54.414Z] E                               T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + ((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 3 * 10 + 40) // 30 + 1)])
[2025-10-13T17:27:54.414Z] E                               T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
[2025-10-13T17:27:54.414Z] E                               for rv0 in range((v_ax2 % 3 * 4 + 16) // 12 + 1):
[2025-10-13T17:27:54.414Z] E                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[2025-10-13T17:27:54.414Z] E                                   for rv1 in range((v_ax3 % 3 * 10 + 40) // 30 + 1):
[2025-10-13T17:27:54.414Z] E                                       with T.block("adaptive_pool_sum"):
[2025-10-13T17:27:54.414Z] E                                           v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0)
[2025-10-13T17:27:54.414Z] E                                           v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1)
[2025-10-13T17:27:54.414Z] E                                           v_ax2_1 = T.axis.spatial((v_ax2, v_ax2 + 1), v_ax2)
[2025-10-13T17:27:54.414Z] E                                           v_ax3_1 = T.axis.spatial((v_ax3, v_ax3 + 1), v_ax3)
[2025-10-13T17:27:54.414Z] E                                           v_rv0 = T.axis.reduce((v_ax2 % 3 * 4 + 16) // 12 + 1, rv0)
[2025-10-13T17:27:54.414Z] E                                           v_rv1 = T.axis.reduce((v_ax3 % 3 * 10 + 40) // 30 + 1, rv1)
[2025-10-13T17:27:54.414Z] E                                           T.reads(x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1])
[2025-10-13T17:27:54.414Z] E                                           T.writes(adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1])
[2025-10-13T17:27:54.414Z] E                                           with T.init():
[2025-10-13T17:27:54.414Z] E                                               adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = T.float32(0.0)
[2025-10-13T17:27:54.414Z] E                                           adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] + x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1]
[2025-10-13T17:27:54.414Z] E           for ax0 in range(1):
[2025-10-13T17:27:54.414Z] E               for ax1 in range(1024):
[2025-10-13T17:27:54.414Z] E                   for ax2 in range(12):
[2025-10-13T17:27:54.414Z] E                       for ax3 in range(30):
[2025-10-13T17:27:54.414Z] E                           with T.block("adaptive_pool_avg"):
[2025-10-13T17:27:54.414Z] E                               v_ax0 = T.axis.spatial(1, ax0)
[2025-10-13T17:27:54.414Z] E                               v_ax1 = T.axis.spatial(1024, ax1)
[2025-10-13T17:27:54.414Z] E                               v_ax2 = T.axis.spatial(12, ax2)
[2025-10-13T17:27:54.414Z] E                               v_ax3 = T.axis.spatial(30, ax3)
[2025-10-13T17:27:54.414Z] E                               T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
[2025-10-13T17:27:54.414Z] E                               T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
[2025-10-13T17:27:54.414Z] E                               T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
[2025-10-13T17:27:54.414Z] E                               adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", (v_ax2 % 3 * 4 + 16) // 12 + 1) * T.Cast("float32", (v_ax3 % 3 * 10 + 40) // 30 + 1))
[2025-10-13T17:27:54.414Z] E   and rhs at <root>.body.block.body.seq[0].body.body.body.body.block.body.extent:
[2025-10-13T17:27:54.414Z] E   # from tvm.script import tir as T
[2025-10-13T17:27:54.414Z] E   
[2025-10-13T17:27:54.414Z] E   @T.prim_func
[2025-10-13T17:27:54.414Z] E   def main(x_handle: T.handle, adaptive_pool_avg_handle: T.handle):
[2025-10-13T17:27:54.414Z] E       T.func_attr({"tir.noalias": True})
[2025-10-13T17:27:54.414Z] E       x = T.match_buffer(x_handle, (1, 1024, 16, 40))
[2025-10-13T17:27:54.414Z] E       adaptive_pool_avg = T.match_buffer(adaptive_pool_avg_handle, (1, 1024, 12, 30))
[2025-10-13T17:27:54.414Z] E       with T.block("root"):
[2025-10-13T17:27:54.414Z] E           T.reads()
[2025-10-13T17:27:54.414Z] E           T.writes()
[2025-10-13T17:27:54.414Z] E           adaptive_pool_sum = T.alloc_buffer((1, 1024, 12, 30))
[2025-10-13T17:27:54.414Z] E           for ax0 in range(1):
[2025-10-13T17:27:54.414Z] E               for ax1 in range(1024):
[2025-10-13T17:27:54.414Z] E                   for ax2 in range(12):
[2025-10-13T17:27:54.414Z] E                       for ax3 in range(30):
[2025-10-13T17:27:54.414Z] E                           with T.block("adaptive_pool_sum_1"):
[2025-10-13T17:27:54.414Z] E                               v_ax0 = T.axis.spatial(1, ax0)
[2025-10-13T17:27:54.414Z] E                               v_ax1 = T.axis.spatial(1024, ax1)
[2025-10-13T17:27:54.414Z] E                               v_ax2 = T.axis.spatial(12, ax2)
[2025-10-13T17:27:54.414Z] E                               v_ax3 = T.axis.spatial(30, ax3)
[2025-10-13T17:27:54.414Z] E                               T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + ((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 3 * 10 + 40) // 30 + 1)])
[2025-10-13T17:27:54.414Z] E                               T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
[2025-10-13T17:27:54.414Z] E                               for rv0 in range(T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12):
[2025-10-13T17:27:54.414Z] E                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[2025-10-13T17:27:54.414Z] E                                   for rv1 in range(T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30):
[2025-10-13T17:27:54.414Z] E                                       with T.block("adaptive_pool_sum"):
[2025-10-13T17:27:54.414Z] E                                           v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0)
[2025-10-13T17:27:54.414Z] E                                           v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1)
[2025-10-13T17:27:54.414Z] E                                           v_ax2_1 = T.axis.spatial((v_ax2, v_ax2 + 1), v_ax2)
[2025-10-13T17:27:54.414Z] E                                           v_ax3_1 = T.axis.spatial((v_ax3, v_ax3 + 1), v_ax3)
[2025-10-13T17:27:54.414Z] E                                           v_rv0 = T.axis.reduce(T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12, rv0)
[2025-10-13T17:27:54.414Z] E                                           v_rv1 = T.axis.reduce(T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30, rv1)
[2025-10-13T17:27:54.414Z] E                                           T.reads(x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1])
[2025-10-13T17:27:54.414Z] E                                           T.writes(adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1])
[2025-10-13T17:27:54.414Z] E                                           with T.init():
[2025-10-13T17:27:54.414Z] E                                               adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = T.float32(0.0)
[2025-10-13T17:27:54.414Z] E                                           adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] + x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1]
[2025-10-13T17:27:54.414Z] E           for ax0 in range(1):
[2025-10-13T17:27:54.414Z] E               for ax1 in range(1024):
[2025-10-13T17:27:54.414Z] E                   for ax2 in range(12):
[2025-10-13T17:27:54.414Z] E                       for ax3 in range(30):
[2025-10-13T17:27:54.414Z] E                           with T.block("adaptive_pool_avg"):
[2025-10-13T17:27:54.414Z] E                               v_ax0 = T.axis.spatial(1, ax0)
[2025-10-13T17:27:54.414Z] E                               v_ax1 = T.axis.spatial(1024, ax1)
[2025-10-13T17:27:54.414Z] E                               v_ax2 = T.axis.spatial(12, ax2)
[2025-10-13T17:27:54.414Z] E                               v_ax3 = T.axis.spatial(30, ax3)
[2025-10-13T17:27:54.414Z] E                               T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
[2025-10-13T17:27:54.414Z] E                               T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
[2025-10-13T17:27:54.415Z] E                               T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
[2025-10-13T17:27:54.415Z] E                               adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12) * T.Cast("float32", T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30))
[2025-10-13T17:27:54.415Z] -- generated xml file: /workspace/build/pytest-results/te,-shard-0-cython.xml --
[2025-10-13T17:27:54.415Z] =========================== short test summary info ============================
[2025-10-13T17:27:54.415Z] FAILED tests/python/te/test_te_create_primfunc.py::test_adaptive_pooling_window - ValueError: StructuralEqual check failed, caused by lhs at <root>.body.block.body.seq[0].body.body.body.body.block.body.extent:
[2025-10-13T17:27:54.415Z] # from tvm.script import tir as T
[2025-10-13T17:27:54.415Z] 

@LeiWang1999
Copy link
Contributor Author

I double checked that the expr T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12 is euqal to the simplified version (v_ax3 % 3 * 10 + 40) // 30 + 1 when v_ax3 is an integer in [0, 12), likely the new rule can be powerful.

@tqchen
Copy link
Member

tqchen commented Oct 17, 2025

tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py can be fixed by relaxing bounds, then we are good to go with lint

@tqchen
Copy link
Member

tqchen commented Oct 17, 2025

you can run clang-format via ./tests/lint/git-clang-format.sh -i HEAD~7

@tqchen tqchen merged commit 70c157d into apache:main Oct 18, 2025
10 checks passed
@tqchen
Copy link
Member

tqchen commented Oct 18, 2025

thanks @LeiWang1999 this is merged

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants