Files
diffSQL/sql_diff.py
yutou 7ed871da34 更新 sql_diff.py
feat(sql_diff): 增强SQL差异分析功能以支持约束检测

扩展表结构解析功能,支持提取约束信息
修改比较逻辑,新增约束差异检测
更新SQL生成和输出显示,包含约束信息
2026-01-26 17:05:10 +08:00

323 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SQL差异比较工具
用于比较两个由mysqldump生成的SQL文件，找出表结构差异并生成update.sql。
使用场景:
- 同步两个数据库的差异
- 项目A可能是旧数据库，项目B创建了新的字段和表
- 当给项目A更新服务端时，确保数据库结构一致
使用方法:
1. 将两个SQL文件(由 mysqldump -u root -p -d database > dump.sql 生成)放在sql文件夹中，旧数据库的文件名需包含 old
2. 运行脚本: python sql_diff.py
3. 生成的update.sql文件将放在out文件夹中
"""
import os
import re
from datetime import datetime
def get_sql_files(sql_dir="sql"):
    """
    Locate the two SQL dump files that should be compared.

    :param sql_dir: directory to scan; defaults to "sql" relative to the
        current working directory (the original hard-coded location).
    :return: a list of exactly two ``.sql`` file paths, sorted by path so the
        result does not depend on ``os.listdir`` ordering; an empty list
        (after printing an error) when the directory is missing or does not
        contain exactly two ``.sql`` files.
    """
    # isdir rather than exists: a plain file called "sql" would otherwise make
    # os.listdir raise NotADirectoryError instead of printing the error.
    if not os.path.isdir(sql_dir):
        print(f"错误:{sql_dir}文件夹不存在")
        return []
    # Only regular files count (a sub-directory named "x.sql" is not a dump).
    # Sorting makes the old/new pairing deterministic across filesystems.
    sql_files = sorted(
        os.path.join(sql_dir, name)
        for name in os.listdir(sql_dir)
        if name.endswith('.sql') and os.path.isfile(os.path.join(sql_dir, name))
    )
    if len(sql_files) != 2:
        print(f"错误:{sql_dir}文件夹中需要有且只有两个SQL文件当前有{len(sql_files)}")
        return []
    return sql_files
def parse_sql_file(file_path):
    """
    Parse a mysqldump file and extract the structure of every table.

    Only ``CREATE TABLE `name` (...) ENGINE=...`` statements are examined.
    Each non-empty line of a table body is classified as:
      * a column definition, when it starts with a back-quoted name;
      * a key/constraint, when it starts with PRIMARY KEY, UNIQUE, FULLTEXT,
        SPATIAL, KEY, INDEX, CONSTRAINT, FOREIGN KEY or CHECK;
      * otherwise a continuation of the column currently being read
        (ignored when no column is open).

    :param file_path: path of the SQL dump, read as UTF-8
    :return: ``{table_name: {'fields': {column: definition},
        'constraints': [definition, ...]}}``. Trailing commas are stripped;
        column definitions exclude the column name; column order follows
        the dump.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    table_re = re.compile(r'CREATE TABLE `(.*?)`\s*\(([\s\S]*?)\) ENGINE=', re.MULTILINE)
    field_re = re.compile(r'`(.*?)`\s+')
    # Prefix match instead of "KEY in line": it also recognises
    # `CONSTRAINT ... CHECK (...)` lines (which contain no KEY/INDEX and were
    # previously dropped or glued onto the last column's definition).
    constraint_prefixes = (
        'PRIMARY KEY', 'UNIQUE', 'FULLTEXT', 'SPATIAL', 'KEY', 'INDEX',
        'CONSTRAINT', 'FOREIGN KEY', 'CHECK',
    )

    tables = {}
    for table_name, table_body in table_re.findall(content):
        field_parts = {}      # column name -> list of definition fragments
        constraints = []
        current_field = None  # column whose definition may continue on the next line
        for raw_line in table_body.split('\n'):
            line = raw_line.strip()
            if not line:
                continue
            field_match = field_re.match(line)
            if field_match:
                current_field = field_match.group(1)
                field_parts[current_field] = [line[field_match.end():].rstrip(',').strip()]
            elif line.upper().startswith(constraint_prefixes):
                # A key/constraint closes any column definition in progress.
                current_field = None
                constraints.append(line.rstrip(',').strip())
            elif current_field is not None:
                field_parts[current_field].append(line.rstrip(',').strip())
        tables[table_name] = {
            'fields': {name: ' '.join(parts).strip() for name, parts in field_parts.items()},
            'constraints': constraints,
        }
    return tables
def _constraint_key(constraint):
    """
    Return a hashable identity for a named key/constraint line, or None.

    ``CONSTRAINT `x` ...`` (foreign keys, checks)          -> ('CONSTRAINT', 'x')
    ``[UNIQUE|FULLTEXT|SPATIAL] KEY|INDEX `x` (...)``       -> ('INDEX', 'x')
    Lines without a recognisable name (e.g. an unnamed FOREIGN KEY, or
    ``PRIMARY KEY (...)``) -> None, so callers fall back to full-text checks.
    The two kinds are kept apart because MySQL may auto-create an index named
    like a foreign key; that index must not make the FK look present.
    """
    named = re.match(r'CONSTRAINT\s+`([^`]+)`', constraint, re.IGNORECASE)
    if named:
        return ('CONSTRAINT', named.group(1))
    # "(" is excluded from the unquoted alternative so that "FOREIGN KEY (`a`)"
    # does not yield the bogus name "(" shared by every foreign key.
    index = re.search(r'\b(?:KEY|INDEX)\s+(?:`([^`]+)`|([^`\s(]+))', constraint)
    if index:
        return ('INDEX', index.group(1) or index.group(2))
    return None


def compare_tables(old_tables, new_tables):
    """
    Compute the additive schema differences between two parsed dumps.

    Only additions are reported: tables, columns and keys/constraints present
    in *new_tables* but absent from *old_tables*. Dropped or modified objects
    are not reported.

    :param old_tables: parse_sql_file() result for the old dump
    :param new_tables: parse_sql_file() result for the new dump
    :return: dict with
        'new_tables'        -> [table_name, ...]
        'added_fields'      -> [(table_name, field_name, field_definition), ...]
        'added_constraints' -> [(table_name, constraint_definition), ...]
    """
    diff = {
        'new_tables': [],         # tables that only exist in the new dump
        'added_fields': [],       # (table name, column name, column definition)
        'added_constraints': [],  # (table name, constraint definition)
    }
    for table_name, new_table in new_tables.items():
        if table_name not in old_tables:
            diff['new_tables'].append(table_name)
            continue
        old_table = old_tables[table_name]

        old_fields = old_table['fields']
        for field_name, field_def in new_table['fields'].items():
            if field_name not in old_fields:
                diff['added_fields'].append((table_name, field_name, field_def))

        old_constraints = old_table['constraints']
        # A table has at most one primary key, so any existing PK counts as
        # present (a changed PK definition is deliberately not reported).
        old_has_primary = any('PRIMARY KEY' in c.upper() for c in old_constraints)
        old_keys = {key for key in map(_constraint_key, old_constraints) if key is not None}
        for constraint in new_table['constraints']:
            if 'PRIMARY KEY' in constraint.upper():
                exists = old_has_primary
            else:
                key = _constraint_key(constraint)
                # Named objects are matched by name; unnamed ones by exact text.
                exists = key in old_keys if key is not None else constraint in old_constraints
            if not exists:
                diff['added_constraints'].append((table_name, constraint))
    return diff
def generate_update_sql(old_file, new_file, diff):
    """
    Render the update script for a diff produced by compare_tables().

    New tables are copied verbatim from the new dump's CREATE TABLE statement;
    new columns and keys/constraints become ``ALTER TABLE ... ADD`` statements.
    Columns are emitted before constraints because a new index or constraint
    may reference a column added in the same script.

    :param old_file: path of the old dump (only its base name is used, in the header)
    :param new_file: path of the new dump; re-read to copy CREATE TABLE statements
    :param diff: dict with 'new_tables', 'added_fields' and 'added_constraints'
    :return: the complete update.sql text
    """
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    out = [
        "-- SQL差异更新脚本\n",
        f"-- 生成时间: {current_time}\n",
        f"-- 比较文件: {os.path.basename(old_file)} (旧) -> {os.path.basename(new_file)} (新)\n\n",
    ]

    with open(new_file, 'r', encoding='utf-8') as f:
        new_content = f.read()

    if diff['new_tables']:
        out.append("-- 新增表\n\n")
        for table_name in diff['new_tables']:
            # re.escape: MySQL identifiers may contain regex metacharacters such as "$".
            # The statement ends at the first ";" that closes a line, so a ";" inside a
            # COMMENT, or a /*!50100 PARTITION BY ... */ clause after ENGINE=, neither
            # truncates the statement nor lets the match run into the next table.
            statement_re = re.compile(
                r'(CREATE TABLE `' + re.escape(table_name) + r'`[\s\S]*?ENGINE=[\s\S]*?;)(?=[ \t\r]*$)',
                re.MULTILINE,
            )
            match = statement_re.search(new_content)
            if match:
                out.append(match.group(1) + '\n\n')
            else:
                # Make the omission visible instead of silently skipping the table.
                out.append(f"-- 警告: 未能在 {os.path.basename(new_file)} 中找到表 `{table_name}` 的CREATE TABLE语句\n\n")

    if diff['added_fields']:
        out.append("-- 新增字段\n\n")
        out.extend(
            f"ALTER TABLE `{table}` ADD COLUMN `{column}` {definition};\n"
            for table, column, definition in diff['added_fields']
        )
        out.append('\n')

    if diff['added_constraints']:
        out.append("-- 新增约束\n\n")
        # PRIMARY KEY, UNIQUE KEY, KEY/INDEX and CONSTRAINT ... FOREIGN KEY/CHECK
        # definitions are all valid after "ALTER TABLE t ADD", so no per-kind branching.
        out.extend(
            f"ALTER TABLE `{table}` ADD {constraint};\n"
            for table, constraint in diff['added_constraints']
        )
        out.append('\n')

    if not (diff['new_tables'] or diff['added_fields'] or diff['added_constraints']):
        out.append("-- 未发现差异\n")
    return ''.join(out)
def _describe_constraint(constraint):
    """Return a short display name for a constraint definition, or None if it has none."""
    named = re.match(r'CONSTRAINT\s+`([^`]+)`', constraint)
    if named:
        return named.group(1)
    index = re.search(r'(?:KEY|INDEX)\s+`?(.*?)`?\s*\(', constraint)
    # "PRIMARY KEY (" matches the pattern above with an empty name, so the
    # group must be non-empty before it is used; otherwise fall through.
    if index and index.group(1):
        return index.group(1)
    if 'PRIMARY KEY' in constraint:
        return "PRIMARY KEY"
    return None


def _print_diff_summary(diff):
    """Print the counts and the per-item lists of a compare_tables() diff."""
    print(f"新增表: {len(diff['new_tables'])}")
    print(f"新增字段: {len(diff['added_fields'])}")
    print(f"新增约束: {len(diff['added_constraints'])}")
    if diff['new_tables']:
        print("新增表列表:")
        for table in diff['new_tables']:
            print(f" - {table}")
    if diff['added_fields']:
        print("新增字段列表:")
        for table, field, _ in diff['added_fields']:
            print(f" - {table}.{field}")
    if diff['added_constraints']:
        print("新增约束列表:")
        for table, constraint in diff['added_constraints']:
            label = _describe_constraint(constraint)
            if label:
                print(f" - {table}.{label}")
            else:
                # Only mark truncation when something was actually cut off.
                text = constraint if len(constraint) <= 50 else constraint[:50] + "..."
                print(f" - {table}: {text}")
    if not diff['new_tables'] and not diff['added_fields'] and not diff['added_constraints']:
        print("未发现差异")


def main():
    """
    Compare the two dumps found in ./sql and write ./out/update.sql.

    The dump whose file name contains "old" (case-insensitive) is treated as
    the old schema and the other one as the new schema. When neither or both
    names contain "old", the first file is old only if its name contains
    "old" (otherwise the second is), and a warning is printed.
    """
    sql_files = get_sql_files()
    if len(sql_files) != 2:
        return

    # Decide by base name only, so the directory part of the path cannot interfere.
    first_is_old = 'old' in os.path.basename(sql_files[0]).lower()
    second_is_old = 'old' in os.path.basename(sql_files[1]).lower()
    if first_is_old:
        old_file, new_file = sql_files[0], sql_files[1]
    else:
        old_file, new_file = sql_files[1], sql_files[0]
    if first_is_old == second_is_old:
        print(f"警告: 无法根据文件名中的old关键字唯一确定旧文件, 默认将 {os.path.basename(old_file)} 视为旧文件")
    print(f"比较文件: {os.path.basename(old_file)} (旧) -> {os.path.basename(new_file)} (新)")

    print("解析SQL文件...")
    old_tables = parse_sql_file(old_file)
    new_tables = parse_sql_file(new_file)
    print(f"旧文件包含 {len(old_tables)} 个表")
    print(f"新文件包含 {len(new_tables)} 个表")

    print("比较表结构...")
    diff = compare_tables(old_tables, new_tables)

    print("生成update.sql...")
    update_sql_content = generate_update_sql(old_file, new_file, diff)

    out_dir = "out"
    os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(out_dir, "update.sql")
    with open(out_file, 'w', encoding='utf-8') as f:
        f.write(update_sql_content)
    print(f"update.sql已生成: {out_file}")

    _print_diff_summary(diff)
if __name__ == "__main__":
    # Run the workflow only when executed as a script, not when imported.
    main()