Using Python's `ast` module to enforce validations in pytest
12-05-2025
A refactoring bug caused a SQLAlchemy insert statement to be composed with all of its values inline. Once the number of records grew large, this raised an error because of MSSQL's limit on statement size (specifically, its cap of roughly 2,100 parameters per statement). The fix is trivial; see below.
- Option 1
  - a single, complex insert statement with all values inline
  - hits the parameter limit of approximately 2,100
- Option 2
  - a basic insert statement without values; the values are passed as parameters, so executemany is used under the hood
# Illustrative snippet: `session`, `tbl_analyst_estimates` and `records` are defined
# elsewhere (presumably a SQLAlchemy Session, a Table and a list of row dicts - confirm).
# option 1 (the bug): all values are inlined into one INSERT via .values(records);
# with many records this exceeds MSSQL's ~2100 parameter limit and the insert fails.
insert_stmt = sqlalchemy.insert(tbl_analyst_estimates).values(records)
session.execute(insert_stmt)
session.commit()
# option 2 (the fix): the INSERT carries no values; the records are passed to
# execute() as parameters instead, so executemany is used under the hood.
insert_stmt = sqlalchemy.insert(tbl_analyst_estimates)
session.execute(insert_stmt, records)
session.commit()
What was more interesting was working out where else the bug existed and how to stop it from reappearing. To do so, you can use Python's `ast` module to enforce validations statically, before the code is ever run. This has the benefit of not requiring mock data, fixtures or helpers (which can themselves drift).
Node visitor that detects usage of the SQLAlchemy `insert(...).values(...)` pattern:
class SQLAlchemyInsertVisitor(ast.NodeVisitor):
    """
    AST visitor that detects non-parameterized SQLAlchemy insert patterns
    (i.e. ``sqlalchemy.insert(...).values(...)``).

    Usage: ``visitor.visit(ast.parse(source))``. Every detection is appended to
    ``self.issues`` as a dict with ``file``, ``line``, ``col`` and ``message``
    keys, where ``line``/``col`` locate the ``.values(...)`` call expression.

    Import tracking is purely syntactic. An import only takes effect once the
    visitor has reached it, so tracking follows the traversal (source) order.
    """

    # Absolute modules from which ``insert`` can be imported by name.
    INSERT_SOURCE_MODULES = frozenset(
        {"sqlalchemy", "sqlalchemy.sql", "sqlalchemy.sql.expression"}
    )

    def __init__(self, file_path):
        """
        Args:
            file_path: path reported in each issue's ``file`` field.
        """
        self.file_path = file_path
        self.issues = []
        self.sqlalchemy_imports = {
            # True when the bare name ``insert`` is bound to SQLAlchemy's insert.
            "insert": False,
            # Other local names bound to SQLAlchemy's insert
            # (``from sqlalchemy import insert as X``). Kept in their own set so
            # an alias can never collide with the bookkeeping keys of this dict.
            "insert_aliases": set(),
            # Local names bound to the ``sqlalchemy`` package itself.
            "sqlalchemy_modules": set(),
        }

    def visit_Import(self, node):
        """Track 'import sqlalchemy' or 'import sqlalchemy as X'."""
        for alias in node.names:
            if alias.name == "sqlalchemy":
                self.sqlalchemy_imports["sqlalchemy_modules"].add(
                    alias.asname or "sqlalchemy"
                )
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """Track 'from sqlalchemy import insert' or 'from sqlalchemy import insert as X'."""
        # A non-zero level is a relative import (``from .sqlalchemy import ...``),
        # which refers to a local module rather than the SQLAlchemy package.
        if not node.level and node.module in self.INSERT_SOURCE_MODULES:
            for alias in node.names:
                if alias.name != "insert":
                    continue
                bound_name = alias.asname or alias.name
                if bound_name == "insert":
                    self.sqlalchemy_imports["insert"] = True
                else:
                    # An aliased import does NOT bind the bare name ``insert``.
                    self.sqlalchemy_imports["insert_aliases"].add(bound_name)
        self.generic_visit(node)

    def visit_Call(self, node):
        """
        Record an issue for every ``.values()`` call made directly on a
        SQLAlchemy ``insert(...)`` call.
        """
        if self._is_values_method_call(node) and self._is_sqlalchemy_insert_call(
            node.func.value
        ):
            self.issues.append(
                {
                    "file": self.file_path,
                    "line": node.lineno,
                    "col": node.col_offset,
                    "message": "Non-parameterized SQLAlchemy bulk insert detected.",
                }
            )
        self.generic_visit(node)

    def _is_values_method_call(self, node):
        """Check if node is a .values() method call."""
        return (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == "values"
        )

    def _is_sqlalchemy_insert_call(self, node):
        """
        Check if the node is a SQLAlchemy insert() call.

        Detects these patterns:
            1. sqlalchemy.insert(...)
            2. <alias>.insert(...)   (where <alias> is an alias for sqlalchemy)
            3. insert(...)           (where insert was imported from sqlalchemy)
            4. <alias>(...)          (where <alias> is an alias for insert, imported from sqlalchemy)
        """
        if not isinstance(node, ast.Call):
            return False
        func = node.func
        # patterns 1 & 2
        if isinstance(func, ast.Attribute):
            return (
                func.attr == "insert"
                and self._get_name_from_node(func.value)
                in self.sqlalchemy_imports["sqlalchemy_modules"]
            )
        # patterns 3 & 4
        if isinstance(func, ast.Name):
            return (
                func.id == "insert" and self.sqlalchemy_imports["insert"]
            ) or func.id in self.sqlalchemy_imports["insert_aliases"]
        return False

    def _get_name_from_node(self, node):
        """Return the identifier of a plain ``Name`` node, else None."""
        if isinstance(node, ast.Name):
            return node.id
        return None
Fixture that collects the paths of all Python files in a given list of packages:
@pytest.fixture
def fixt_python_file_paths_in_packages() -> List[pathlib.Path]:
    """
    Collect every ``.py`` file belonging to the packages under validation.

    Packages are resolved with ``importlib.util.find_spec`` and are not
    imported. For a package (regular or namespace), every directory it spans is
    searched recursively. For a plain module, only its own ``.py`` file is used.

    Returns:
        Sorted, de-duplicated list of paths to Python source files.

    Raises:
        ValueError: if one of the packages cannot be found.
    """
    package_names = ["quality", "emcd", "emd", "mase"]
    py_files = set()
    for package_name in package_names:
        package_spec = importlib.util.find_spec(package_name)
        if package_spec is None:
            raise ValueError(f"Package '{package_name}' not found")
        search_locations = package_spec.submodule_search_locations
        if search_locations:
            # Use the package's own directories rather than origin.parent:
            # namespace packages have origin None.
            for location in search_locations:
                py_files.update(pathlib.Path(location).rglob("*.py"))
        elif package_spec.origin:
            # A single-file module: origin.parent would be the directory that
            # contains it (e.g. site-packages), so only take the file itself.
            origin = pathlib.Path(package_spec.origin)
            if origin.suffix == ".py":
                py_files.add(origin)
    # Sorted so test output and failure reports are deterministic.
    return sorted(py_files)
Test case that fails if any of those files uses the non-parameterized insert pattern:
import ast
class TestAST:
    def test_sqlalchemy_inserts_are_parameterized(self, fixt_python_file_paths_in_packages):
        """
        Test that all SQLAlchemy inserts in the package use parameterized execution.

        This test ensures that bulk inserts use:
            insert_stmt = sqlalchemy.insert(table)
            session.execute(insert_stmt, records)
        Instead of:
            insert_stmt = sqlalchemy.insert(table).values(records)
            session.execute(insert_stmt)
        """
        issues = []
        for file_path in fixt_python_file_paths_in_packages:
            issues.extend(_check_file_for_non_parameterized_inserts(str(file_path)))

        if issues:
            details = "".join(
                f"{issue['file']} (Line {issue['line']}) (Column {issue['col']}): "
                f"{issue['message']}\n"
                for issue in issues
            )
            # pytest.fail instead of `assert False`: plain asserts are stripped
            # under `python -O`, which would let violations pass silently.
            pytest.fail(
                f"Found SQLAlchemy non-parameterized inserts:\n\n{details}",
                pytrace=False,
            )
def _check_file_for_non_parameterized_inserts(file_path):
    """
    Scan one Python source file for non-parameterized SQLAlchemy inserts.

    Args:
        file_path: path of the file to scan, as a string.

    Returns:
        List of issue dicts produced by ``SQLAlchemyInsertVisitor``. It is empty
        when the file is clean or could not be read/parsed. Skipped files are
        reported with a ``UserWarning`` rather than being ignored silently.
    """
    import warnings

    try:
        # Read raw bytes and let the parser decode them. This honours PEP 263
        # coding declarations instead of assuming every file is UTF-8.
        with open(file_path, "rb") as file:
            source = file.read()
        tree = ast.parse(source, filename=file_path)
    except (SyntaxError, ValueError, OSError) as exc:
        # ValueError covers UnicodeDecodeError and null bytes; OSError covers
        # FileNotFoundError and other read failures. Stay best-effort, but make
        # the skip visible so violations in an unparsable file are not hidden.
        warnings.warn(
            f"Skipped {file_path} while checking SQLAlchemy inserts: {exc!r}",
            stacklevel=2,
        )
        return []

    visitor = SQLAlchemyInsertVisitor(file_path)
    visitor.visit(tree)
    return list(visitor.issues)