dbsync.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836
# encoding=utf8
"""SnailJob database script conversion tool.

Author: dhb52

Install the dependency with:
    pip install simple-ddl-parser
or install uv and run the script with:
    uv run --with simple-ddl-parser dbsync.py postgre
"""
  8. import argparse
  9. import pathlib
  10. import re
  11. import time
  12. import sys
  13. from abc import ABC, abstractmethod
  14. from typing import Dict, Generator, Optional, Tuple, Union
  15. from simple_ddl_parser import DDLParser
# Header comment printed at the top of the generated script. The `{db_type}`
# and `{date}` placeholders are filled in with str.format() by Convertor.print().
PREAMBLE = """/*
SnailJob Database Transfer Tool
Source Server Type : MySQL
Target Server Type : {db_type}
Date: {date}
*/
"""
  23. def load_and_clean(sql_file: str) -> str:
  24. """加载源 SQL 文件,并清理内容方便下一步 ddl 解析
  25. Args:
  26. sql_file (str): sql文件路径
  27. Returns:
  28. str: 清理后的sql文件内容
  29. """
  30. REPLACE_PAIR_LIST = (
  31. (")\nVALUES ", ") VALUES "),
  32. (" CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci ", " "),
  33. (" KEY `", " INDEX `"),
  34. ("UNIQUE INDEX", "UNIQUE KEY"),
  35. ("b'0'", "'0'"),
  36. ("b'1'", "'1'"),
  37. )
  38. content = open(sql_file, encoding="utf-8").read()
  39. for replace_pair in REPLACE_PAIR_LIST:
  40. content = content.replace(*replace_pair)
  41. content = re.sub(r"ENGINE.*COMMENT", "COMMENT", content)
  42. content = re.sub(r"ENGINE.*;", ";", content)
  43. return content
  44. class Convertor(ABC):
  45. def __init__(self, src: str, db_type) -> None:
  46. self.src = src
  47. self.db_type = db_type
  48. self.content = load_and_clean(self.src)
  49. self.table_script_list = re.findall(r"CREATE TABLE [^;]*;", self.content)
  50. @abstractmethod
  51. def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]) -> str:
  52. """字段类型转换
  53. Args:
  54. type (str): 字段类型
  55. size (Optional[Union[int, Tuple[int]]]): 字段长度描述, 如varchar(255), decimal(10,2)
  56. Returns:
  57. str: 类型定义
  58. """
  59. pass
  60. @abstractmethod
  61. def gen_create(self, table_ddl: Dict) -> str:
  62. """生成 create 脚本
  63. Args:
  64. table_ddl (Dict): 表DDL
  65. Returns:
  66. str: 生成脚本
  67. """
  68. pass
  69. @abstractmethod
  70. def gen_pk(self, table_name: str) -> str:
  71. """生成主键定义
  72. Args:
  73. table_name (str): 表名
  74. Returns:
  75. str: 生成脚本
  76. """
  77. pass
  78. @abstractmethod
  79. def gen_index(self, ddl: Dict) -> str:
  80. """生成索引定义
  81. Args:
  82. table_ddl (Dict): 表DDL
  83. Returns:
  84. str: 生成脚本
  85. """
  86. pass
  87. @abstractmethod
  88. def gen_comment(self, table_ddl: Dict) -> str:
  89. """生成字段/表注释
  90. Args:
  91. table_ddl (Dict): 表DDL
  92. Returns:
  93. str: 生成脚本
  94. """
  95. pass
  96. @abstractmethod
  97. def gen_uk(self, table_ddl: Dict) -> str:
  98. """生成
  99. Args:
  100. table_ddl (Dict): 表DDL
  101. Returns:
  102. str: 生成脚本
  103. """
  104. @abstractmethod
  105. def gen_insert(self, table_name: str) -> str:
  106. """生成 insert 语句块
  107. Args:
  108. table_name (str): 表名
  109. Returns:
  110. str: 生成脚本
  111. """
  112. pass
  113. @staticmethod
  114. def inserts(table_name: str, script_content: str) -> Generator:
  115. PREFIX = f"INSERT INTO `{table_name}`"
  116. # 收集 `table_name` 对应的 insert 语句
  117. for line in script_content.split("\n"):
  118. if line.startswith(PREFIX):
  119. head, tail = line.replace(PREFIX, "").split(" VALUES ", maxsplit=1)
  120. head = head.strip().replace("`", "").lower()
  121. tail = tail.strip().replace(r"\"", '"')
  122. # tail = tail.replace("b'0'", "'0'").replace("b'1'", "'1'")
  123. yield f"INSERT INTO {table_name.lower()} {head} VALUES {tail}"
  124. @staticmethod
  125. def index(ddl: Dict) -> Generator:
  126. """生成索引定义
  127. Args:
  128. ddl (Dict): 表DDL
  129. Yields:
  130. Generator[str]: create index 语句
  131. """
  132. def generate_columns(columns):
  133. keys = [
  134. f"{col['name'].lower()}{' ' + col['order'].lower() if col['order'] != 'ASC' else ''}"
  135. for col in columns[0]
  136. ]
  137. return ", ".join(keys)
  138. for no, index in enumerate(ddl["index"], 1):
  139. columns = generate_columns(index["columns"])
  140. table_name = ddl["table_name"].lower()
  141. yield f"CREATE INDEX idx_{table_name}_{no:02d} ON {table_name} ({columns})"
  142. @staticmethod
  143. def unique_index(ddl: Dict) -> Generator:
  144. if "constraints" in ddl and "uniques" in ddl["constraints"]:
  145. uk_list = ddl["constraints"]["uniques"]
  146. for uk in uk_list:
  147. table_name = ddl["table_name"]
  148. uk_name = uk["constraint_name"]
  149. uk_columns = uk["columns"]
  150. yield table_name, uk_name, uk_columns
  151. @staticmethod
  152. def filed_comments(table_sql: str) -> Generator:
  153. for line in table_sql.split("\n"):
  154. match = re.match(r"^`([^`]+)`.* COMMENT '([^']+)'", line.strip())
  155. if match:
  156. field = match.group(1)
  157. comment_string = match.group(2).replace("\\n", "\n")
  158. yield field, comment_string
  159. def table_comment(self, table_sql: str) -> str:
  160. match = re.search(r"COMMENT \='([^']+)';", table_sql)
  161. return match.group(1) if match else None
  162. def print(self):
  163. """打印转换后的sql脚本到终端"""
  164. print(
  165. PREAMBLE.format(
  166. db_type=self.db_type,
  167. date=time.strftime("%Y-%m-%d %H:%M:%S"),
  168. )
  169. )
  170. error_scripts = []
  171. for table_sql in self.table_script_list:
  172. ddl = DDLParser(table_sql.replace("`", "")).run()
  173. # 如果parse失败, 需要跟进
  174. if len(ddl) == 0:
  175. error_scripts.append(table_sql)
  176. continue
  177. table_ddl = ddl[0]
  178. table_name = table_ddl["table_name"]
  179. # 解析注释
  180. for column in table_ddl["columns"]:
  181. column["comment"] = bytes(column["comment"], "utf-8").decode(
  182. "unicode_escape"
  183. )[1:-1]
  184. table_ddl["comment"] = bytes(table_ddl["comment"], "utf-8").decode(
  185. "unicode_escape"
  186. )[1:-1]
  187. # 为每个表生成个6个基本部分
  188. create = self.gen_create(table_ddl)
  189. pk = self.gen_pk(table_name)
  190. uk = self.gen_uk(table_ddl)
  191. index = self.gen_index(table_ddl)
  192. comment = self.gen_comment(table_ddl)
  193. inserts = self.gen_insert(table_name)
  194. # 组合当前表的DDL脚本
  195. script = f"""{create}
  196. {pk}
  197. {uk}
  198. {index}
  199. {comment}
  200. {inserts}
  201. """
  202. # 清理
  203. script = re.sub("\n{3,}", "\n\n", script).strip() + "\n"
  204. print(script)
  205. # 将parse失败的脚本打印出来
  206. if error_scripts:
  207. print("!!! 注意下面脚本解析失败:\n", file=sys.stderr)
  208. for script in error_scripts:
  209. print(f"{script}\n", file=sys.stderr)
  210. class PostgreSQLConvertor(Convertor):
  211. def __init__(self, src):
  212. super().__init__(src, "PostgreSQL")
  213. def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
  214. """类型转换"""
  215. type = type.lower()
  216. if type == "varchar":
  217. return f"varchar({size})"
  218. if type in ("int", "int unsigned"):
  219. return "int"
  220. if type in ("bigint", "bigint unsigned"):
  221. return "bigint"
  222. if type == "datetime":
  223. return "timestamp"
  224. if type == "timestamp":
  225. return f"timestamp({size})"
  226. if type == "bit":
  227. return "bool"
  228. if type in ("tinyint", "smallint"):
  229. return "smallint"
  230. if type in ("text", "longtext"):
  231. return "text"
  232. if type in ("blob", "mediumblob"):
  233. return "bytea"
  234. if type == "decimal":
  235. return (
  236. f"numeric({','.join(str(s) for s in size)})" if len(size) else "numeric"
  237. )
  238. def gen_create(self, ddl: Dict) -> str:
  239. """生成 create"""
  240. def _generate_column(col: str, table_name: str) -> str:
  241. name = col["name"].lower()
  242. if name == "id":
  243. return "id bigserial PRIMARY KEY"
  244. if table_name == "sj_distributed_lock" and name == "name":
  245. return "name varchar(64) NOT NULL PRIMARY KEY"
  246. type = col["type"].lower()
  247. full_type = self.translate_type(type, col["size"])
  248. nullable = "NULL" if col["nullable"] else "NOT NULL"
  249. default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
  250. return f"{name} {full_type} {nullable} {default}"
  251. table_name = ddl["table_name"].lower()
  252. columns = [
  253. f"{_generate_column(col, table_name).strip()}" for col in ddl["columns"]
  254. ]
  255. filed_def_list = ",\n ".join(columns)
  256. script = f"""-- {table_name}
  257. CREATE TABLE {table_name} (
  258. {filed_def_list}
  259. );"""
  260. return script
  261. def gen_index(self, ddl: Dict) -> str:
  262. return "\n".join(f"{script};" for script in self.index(ddl))
  263. def gen_comment(self, table_ddl: Dict) -> str:
  264. """生成字段及表的注释"""
  265. script = ""
  266. for column in table_ddl["columns"]:
  267. table_comment = column["comment"]
  268. script += (
  269. f"COMMENT ON COLUMN {table_ddl['table_name']}.{column['name']} IS '{table_comment}';"
  270. + "\n"
  271. )
  272. table_comment = table_ddl["comment"]
  273. if table_comment:
  274. script += (
  275. f"COMMENT ON TABLE {table_ddl['table_name']} IS '{table_comment}';\n"
  276. )
  277. return script
  278. def gen_pk(self, table_name) -> str:
  279. """生成主键定义"""
  280. return ""
  281. def gen_uk(self, table_ddl: Dict) -> str:
  282. script = ""
  283. uk_list = list(Convertor.unique_index(table_ddl))
  284. for idx, (table_name, _, uk_columns) in enumerate(uk_list, 1):
  285. uk_name = f"uk_{table_name}_{idx:02d}"
  286. script += f"CREATE UNIQUE INDEX {uk_name} ON {table_name} ({', '.join(uk_columns)});\n"
  287. return script
  288. def gen_insert(self, table_name: str) -> str:
  289. """生成 insert 语句,以及根据最后的 insert id+1 生成 Sequence"""
  290. inserts = list(Convertor.inserts(table_name, self.content))
  291. ## 生成 insert 脚本
  292. script = ""
  293. if inserts:
  294. inserts_lines = "\n".join(inserts)
  295. script += f"""\n
  296. {inserts_lines}"""
  297. return script
  298. class OracleConvertor(Convertor):
  299. def __init__(self, src):
  300. super().__init__(src, "Oracle")
  301. def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
  302. """类型转换"""
  303. type = type.lower()
  304. if type == "varchar":
  305. return f"varchar2({size if size < 4000 else 4000})"
  306. if type in ("int", "int unsigned"):
  307. return "number"
  308. if type == "bigint" or type == "bigint unsigned":
  309. return "number"
  310. if type == "datetime":
  311. return "date"
  312. if type == "timestamp":
  313. return f"timestamp({size})"
  314. if type == "bit":
  315. return "number(1,0)"
  316. if type in ("tinyint", "smallint"):
  317. return "smallint"
  318. if type in ("text", "longtext"):
  319. return "clob"
  320. if type in ("blob", "mediumblob"):
  321. return "blob"
  322. if type == "decimal":
  323. return (
  324. f"number({','.join(str(s) for s in size)})" if len(size) else "number"
  325. )
  326. def gen_create(self, ddl: Dict) -> str:
  327. """生成 CREATE 语句"""
  328. def generate_column(col: str, table_name: str) -> str:
  329. name = col["name"].lower()
  330. if name == "id":
  331. return "id number GENERATED ALWAYS AS IDENTITY"
  332. if table_name == "sj_distributed_lock" and name == "name":
  333. return "name varchar2(64) NOT NULL"
  334. type = col["type"].lower()
  335. full_type = self.translate_type(type, col["size"])
  336. nullable = "NULL" if col["nullable"] else "NOT NULL"
  337. # Oracle的 INSERT '' 不能通过NOT NULL校验,因此对文字类型字段覆写为 NULL
  338. nullable = "NULL" if type in ("varchar", "text", "longtext") else nullable
  339. default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
  340. # Oracle 中 size 不能作为字段名
  341. field_name = '"size"' if name == "size" else name
  342. # Oracle DEFAULT 定义在 NULLABLE 之前
  343. return f"{field_name} {full_type} {default} {nullable}"
  344. table_name = ddl["table_name"].lower()
  345. columns = [
  346. f"{generate_column(col, table_name).strip()}" for col in ddl["columns"]
  347. ]
  348. field_def_list = ",\n ".join(columns)
  349. script = f"""-- {table_name}
  350. CREATE TABLE {table_name} (
  351. {field_def_list}
  352. );"""
  353. # oracle INSERT '' 不能通过 NOT NULL 校验
  354. script = script.replace("DEFAULT '' NOT NULL", "DEFAULT '' NULL")
  355. return script
  356. def gen_index(self, ddl: Dict) -> str:
  357. return "\n".join(f"{script};" for script in self.index(ddl))
  358. def gen_comment(self, table_ddl: Dict) -> str:
  359. script = ""
  360. for column in table_ddl["columns"]:
  361. table_comment = column["comment"]
  362. script += (
  363. f"COMMENT ON COLUMN {table_ddl['table_name']}.{column['name']} IS '{table_comment}';"
  364. + "\n"
  365. )
  366. table_comment = table_ddl["comment"]
  367. if table_comment:
  368. script += (
  369. f"COMMENT ON TABLE {table_ddl['table_name']} IS '{table_comment}';\n"
  370. )
  371. return script
  372. def gen_pk(self, table_name: str) -> str:
  373. """生成主键定义"""
  374. key = "id"
  375. if table_name == "sj_distributed_lock":
  376. key = "name"
  377. return f"ALTER TABLE {table_name} ADD CONSTRAINT pk_{table_name} PRIMARY KEY ({key});\n"
  378. def gen_uk(self, table_ddl: Dict) -> str:
  379. script = ""
  380. uk_list = list(Convertor.unique_index(table_ddl))
  381. for idx, (table_name, _, uk_columns) in enumerate(uk_list, 1):
  382. uk_name = f"uk_{table_name}_{idx:02d}"
  383. script += f"CREATE UNIQUE INDEX {uk_name} ON {table_name} ({', '.join(uk_columns)});\n"
  384. return script
  385. def gen_insert(self, table_name: str) -> str:
  386. """拷贝 INSERT 语句"""
  387. inserts = []
  388. for insert_script in Convertor.inserts(table_name, self.content):
  389. insert_script = (
  390. insert_script.replace("(id,", "(")
  391. .replace("VALUES (1,", "VALUES (")
  392. .replace("now(),", "sysdate,")
  393. )
  394. inserts.append(insert_script)
  395. ## 生成 insert 脚本
  396. script = ""
  397. if inserts:
  398. inserts_lines = "\n".join(inserts)
  399. script += f"""\n\n
  400. {inserts_lines}"""
  401. return script
  402. class SQLServerConvertor(Convertor):
  403. """_summary_
  404. Args:
  405. Convertor (_type_): _description_
  406. """
  407. def __init__(self, src):
  408. super().__init__(src, "Microsoft SQL Server")
  409. def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
  410. """类型转换"""
  411. type = type.lower()
  412. if type == "varchar":
  413. return f"nvarchar({size if size < 4000 else 4000})"
  414. if type in ("int", "int unsigned"):
  415. return "int"
  416. if type in ("bigint", "bigint unsigned"):
  417. return "bigint"
  418. if type in ("datetime", "timestamp"):
  419. return "datetime2"
  420. if type == "bit":
  421. return "varchar(1)"
  422. if type in ("tinyint", "smallint"):
  423. return "tinyint"
  424. if type in ("text", "longtext"):
  425. return "nvarchar(max)"
  426. if type in ("blob", "mediumblob"):
  427. return "varbinary(max)"
  428. if type == "decimal":
  429. return (
  430. f"numeric({','.join(str(s) for s in size)})" if len(size) else "numeric"
  431. )
  432. def gen_create(self, ddl: Dict) -> str:
  433. """生成 create"""
  434. def _generate_column(col: str, table_name: str) -> str:
  435. name = col["name"].lower()
  436. if name == "id":
  437. return "id bigint NOT NULL PRIMARY KEY IDENTITY"
  438. if table_name == "sj_distributed_lock" and name == "name":
  439. return "name nvarchar(64) NOT NULL PRIMARY KEY"
  440. type = col["type"].lower()
  441. full_type = self.translate_type(type, col["size"])
  442. nullable = "NULL" if col["nullable"] else "NOT NULL"
  443. default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
  444. default = re.sub(r"CURRENT_TIMESTAMP\(\d+\)", "CURRENT_TIMESTAMP", default)
  445. return f"{name} {full_type} {nullable} {default}"
  446. table_name = ddl["table_name"].lower()
  447. columns = [
  448. f"{_generate_column(col, table_name).strip()}" for col in ddl["columns"]
  449. ]
  450. filed_def_list = ",\n ".join(columns)
  451. # fmt: off
  452. script = (f"-- {table_name}\n"
  453. f"CREATE TABLE {table_name} (\n"
  454. f" {filed_def_list}\n"
  455. f")\n"
  456. f"GO")
  457. # fmt: on
  458. return script
  459. def gen_comment(self, table_ddl: Dict) -> str:
  460. """生成字段及表的注释"""
  461. script = ""
  462. table_name = table_ddl["table_name"]
  463. for column in table_ddl["columns"]:
  464. column_comment = column["comment"]
  465. field = column["name"]
  466. script += f"""EXEC sp_addextendedproperty
  467. 'MS_Description', N'{column_comment}',
  468. 'SCHEMA', N'dbo',
  469. 'TABLE', N'{table_name}',
  470. 'COLUMN', N'{field}'
  471. GO
  472. """
  473. table_comment = table_ddl["comment"]
  474. if table_comment:
  475. script += f"""EXEC sp_addextendedproperty
  476. 'MS_Description', N'{table_comment}',
  477. 'SCHEMA', N'dbo',
  478. 'TABLE', N'{table_name}'
  479. GO
  480. """
  481. return script
  482. def gen_pk(self, table_name: str) -> str:
  483. """生成主键定义"""
  484. return ""
  485. def gen_uk(self, table_ddl: Dict) -> str:
  486. script = ""
  487. uk_list = list(Convertor.unique_index(table_ddl))
  488. for idx, (table_name, _, uk_columns) in enumerate(uk_list, 1):
  489. uk_name = f"uk_{table_name}_{idx:02d}"
  490. script += f"CREATE UNIQUE INDEX {uk_name} ON {table_name} ({', '.join(uk_columns)})\nGO"
  491. return script
  492. def gen_index(self, ddl: Dict) -> str:
  493. """生成 index"""
  494. return "\n".join(f"{script}\nGO" for script in self.index(ddl))
  495. def gen_insert(self, table_name: str) -> str:
  496. """生成 insert 语句"""
  497. # 收集 `table_name` 对应的 insert 语句
  498. inserts = []
  499. for insert_script in Convertor.inserts(table_name, self.content):
  500. # SQLServer: 字符串前加N,hack,是否存在替换字符串内容的风险
  501. insert_script = (
  502. insert_script.replace(", '", ", N'")
  503. .replace("VALUES ('", "VALUES (N'")
  504. .replace("(id, ", "(")
  505. .replace("VALUES (1, ", "VALUES (")
  506. .replace("now(), ", "getdate(), ")
  507. )
  508. # 删除 insert 的结尾分号
  509. insert_script = re.sub(";$", r"\nGO", insert_script)
  510. inserts.append(insert_script)
  511. ## 生成 insert 脚本
  512. script = ""
  513. if inserts:
  514. inserts_lines = "\n".join(inserts)
  515. script += f"""\n\n
  516. {inserts_lines}"""
  517. return script
  518. class DM8Convertor(Convertor):
  519. def __init__(self, src):
  520. super().__init__(src, "DM8")
  521. def translate_type(self, type: str, size: Optional[Union[int, Tuple[int]]]):
  522. """类型转换"""
  523. type = type.lower()
  524. if type == "varchar":
  525. return f"varchar({size})"
  526. if type in ("int", "int unsigned"):
  527. return "int"
  528. if type in ("bigint", "bigint unsigned"):
  529. return "bigint"
  530. if type == "datetime":
  531. return "datetime"
  532. if type == "timestamp":
  533. return f"timestamp({size})"
  534. if type == "bit":
  535. return "bit"
  536. if type in ("tinyint", "smallint"):
  537. return "smallint"
  538. if type in ("text", "longtext"):
  539. return "text"
  540. if type in ("blob", "mediumblob"):
  541. return "blob"
  542. if type == "decimal":
  543. return (
  544. f"decimal({','.join(str(s) for s in size)})" if len(size) else "decimal"
  545. )
  546. def gen_create(self, ddl) -> str:
  547. """生成 CREATE 语句"""
  548. def generate_column(col: str, table_name: str):
  549. name = col["name"].lower()
  550. if name == "id":
  551. return "id bigint NOT NULL PRIMARY KEY IDENTITY"
  552. if table_name == "sj_distributed_lock" and name == "name":
  553. return "name varchar(64) NOT NULL PRIMARY KEY"
  554. type = col["type"].lower()
  555. full_type = self.translate_type(type, col["size"])
  556. nullable = "NULL" if col["nullable"] else "NOT NULL"
  557. # Oracle的 INSERT '' 不能通过NOT NULL校验,因此对文字类型字段覆写为 NULL
  558. nullable = "NULL" if type in ("varchar", "text", "longtext") else nullable
  559. default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
  560. # Oracle 中 size 不能作为字段名
  561. field_name = '"size"' if name == "size" else name
  562. # Oracle DEFAULT 定义在 NULLABLE 之前
  563. return f"{field_name} {full_type} {default} {nullable}"
  564. table_name = ddl["table_name"].lower()
  565. columns = [
  566. f"{generate_column(col, table_name).strip()}" for col in ddl["columns"]
  567. ]
  568. field_def_list = ",\n ".join(columns)
  569. script = f"""-- {table_name}
  570. CREATE TABLE {table_name} (
  571. {field_def_list}
  572. );"""
  573. # oracle INSERT '' 不能通过 NOT NULL 校验
  574. # script = script.replace("DEFAULT '' NOT NULL", "DEFAULT '' NULL")
  575. return script
  576. def gen_index(self, ddl: Dict) -> str:
  577. return "\n".join(f"{script};" for script in self.index(ddl))
  578. def gen_comment(self, table_ddl: Dict) -> str:
  579. script = ""
  580. for column in table_ddl["columns"]:
  581. table_comment = column["comment"]
  582. script += (
  583. f"COMMENT ON COLUMN {table_ddl['table_name']}.{column['name']} IS '{table_comment}';"
  584. + "\n"
  585. )
  586. table_comment = table_ddl["comment"]
  587. if table_comment:
  588. script += (
  589. f"COMMENT ON TABLE {table_ddl['table_name']} IS '{table_comment}';\n"
  590. )
  591. return script
  592. def gen_pk(self, table_name: str) -> str:
  593. """生成主键定义"""
  594. return ""
  595. def gen_uk(self, table_ddl: Dict) -> str:
  596. script = ""
  597. uk_list = list(Convertor.unique_index(table_ddl))
  598. for idx, (table_name, _, uk_columns) in enumerate(uk_list, 1):
  599. uk_name = f"uk_{table_name}_{idx:02d}"
  600. script += f"CREATE UNIQUE INDEX {uk_name} ON {table_name} ({', '.join(uk_columns)});\n"
  601. return script
  602. def gen_insert(self, table_name: str) -> str:
  603. """拷贝 INSERT 语句"""
  604. inserts = []
  605. for insert_script in Convertor.inserts(table_name, self.content):
  606. insert_script = (
  607. insert_script.replace("(id,", "(")
  608. .replace("VALUES (1,", "VALUES (")
  609. .replace("now(),", "sysdate,")
  610. )
  611. inserts.append(insert_script)
  612. ## 生成 insert 脚本
  613. script = ""
  614. if inserts:
  615. inserts_lines = "\n".join(inserts)
  616. script += f"""\n\n
  617. {inserts_lines}"""
  618. return script
  619. class KingbaseConvertor(PostgreSQLConvertor):
  620. def __init__(self, src):
  621. super().__init__(src)
  622. self.db_type = "KingbaseES"
  623. def gen_create(self, ddl: Dict) -> str:
  624. """生成 create"""
  625. def _generate_column(col: str, table_name: str) -> str:
  626. name = col["name"].lower()
  627. if name == "id":
  628. return "id bigserial PRIMARY KEY"
  629. if table_name == "sj_distributed_lock" and name == "name":
  630. return "name varchar(64) NOT NULL PRIMARY KEY"
  631. type = col["type"].lower()
  632. full_type = self.translate_type(type, col["size"])
  633. nullable = "NULL" if col["nullable"] else "NOT NULL"
  634. if full_type == "text":
  635. nullable = "NULL"
  636. default = f"DEFAULT {col['default']}" if col["default"] is not None else ""
  637. return f"{name} {full_type} {nullable} {default}"
  638. table_name = ddl["table_name"].lower()
  639. columns = [
  640. f"{_generate_column(col, table_name).strip()}" for col in ddl["columns"]
  641. ]
  642. filed_def_list = ",\n ".join(columns)
  643. script = f"""-- {table_name}
  644. CREATE TABLE {table_name} (
  645. {filed_def_list}
  646. );"""
  647. # kingbase INSERT '' 不能通过 NOT NULL 校验
  648. script = script.replace("NOT NULL DEFAULT ''", "NULL DEFAULT ''")
  649. return script
  650. def main():
  651. parser = argparse.ArgumentParser(description="Snail Job Database Transfer Tool")
  652. parser.add_argument(
  653. "type",
  654. type=str,
  655. help="Target database type",
  656. choices=["postgre", "oracle", "sqlserver", "dm8", "kingbase"],
  657. )
  658. args = parser.parse_args()
  659. sql_file = pathlib.Path("../sql/snail_job_mysql.sql").resolve().as_posix()
  660. convertor = None
  661. if args.type == "postgre":
  662. convertor = PostgreSQLConvertor(sql_file)
  663. elif args.type == "oracle":
  664. convertor = OracleConvertor(sql_file)
  665. elif args.type == "sqlserver":
  666. convertor = SQLServerConvertor(sql_file)
  667. elif args.type == "dm8":
  668. convertor = DM8Convertor(sql_file)
  669. elif args.type == "kingbase":
  670. convertor = KingbaseConvertor(sql_file)
  671. else:
  672. raise NotImplementedError(f"Database type not supported: {args.type}")
  673. convertor.print()
# Run the command-line entry point only when the file is executed as a script.
if __name__ == "__main__":
    main()