1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
|
import math import datetime import numpy as np import pandas as pd import sqlalchemy import decimal from sqlalchemy import text from sqlalchemy.orm import sessionmaker from contextlib import contextmanager
def create_upserter(engine, buffer_size=5000, update_on_duplicate=True, dummy=False):
    """Build the upserter that matches the dialect of *engine*.

    engine: SQLAlchemy Engine.
    buffer_size: number of rows to buffer; the buffer is flushed
        automatically once it is full.
    update_on_duplicate: behaviour when a unique key is duplicated. The
        default is to update; False means do not update, i.e. the failed
        insert is ignored.
    dummy: when true, return a DummyUpserter regardless of the dialect.
    """
    if dummy:
        return DummyUpserter(engine)

    # Substring match on the lower-cased dialect name selects the
    # specialised implementation.
    dialect_name = engine.dialect.name.lower()
    if "mysql" in dialect_name:
        return MySQLUpserter(engine, buffer_size, update_on_duplicate)
    if "postgresql" in dialect_name:
        return PSQLUpserter(engine, buffer_size, update_on_duplicate)

    # No specialised upserter for this dialect: warn and use the generic one.
    print(f"没有为{engine.dialect.name}实现特殊的Upsert,使用默认版本,请确认可以正常工作,建议特化一个专门版本")
    return UpserterBase(engine, buffer_size, update_on_duplicate)
def is_duplicate_key(e):
    """Tell whether *e* is a duplicate-key error according to any upserter.

    Every direct subclass of UpserterBase is asked first; if none of them
    recognises the error, the answer of UpserterBase itself is returned.
    """
    if any(upserter_cls.is_duplicate_key(e) for upserter_cls in UpserterBase.__subclasses__()):
        return True
    return UpserterBase.is_duplicate_key(e)
'''
class Upserter:

    # The engine passed in should match the database type supported by the
    # Upserter that is used.
    # buffer_size is how many pending inserts/updates are buffered before a
    # flush (i.e. before they are written to the database); None means flush
    # on destruction, 0 means no buffering.
    # update_on_duplicate is the behaviour when a unique key is duplicated:
    # the default is to update; False means do not update, i.e. the failed
    # insert is ignored.
    def __init__(self, engine, buffer_size=None, update_on_duplicate=True): pass

    # tablename is the database table name.
    # pk is a tuple of key columns; it need not be the real primary key of
    # the table, but it is used to detect duplicates and decide between
    # insert and update, e.g. ('exchange_id', 'trade_id').
    # data is a single row as a dict, e.g.
    # {'exchange_id': 'DCE', 'trade_id': ' 1', 'price': 1.2, 'volume': 1}
    def upsert(self, tablename, pk, data): pass

    # Push the buffered data to the database immediately. Called
    # automatically when buffer_size is reached or on destruction; it can
    # also be called manually.
    def flush(self): pass
'''
class UpserterBase:
    """Generic buffered upserter that builds plain SQL strings.

    Rows are buffered per table and written with one multi-row INSERT. When
    that INSERT fails with a duplicate-key error, the rows are retried one
    by one through ``upsert_one`` (UPDATE first, then INSERT).
    """

    def __init__(self, engine, buffer_size=None, update_on_duplicate=True, field_quote_mark=''):
        """Store the configuration and create the session factory.

        engine: SQLAlchemy Engine the sessions are bound to.
        buffer_size: rows buffered per table before an automatic flush;
            None flushes only on ``flush()``/destruction, 0 disables
            buffering.
        update_on_duplicate: update the existing row on a duplicate key when
            true, otherwise ignore the failed insert.
        field_quote_mark: string placed around column names in generated SQL.
        """
        self.engine = engine
        self.session_maker = sessionmaker(expire_on_commit=False)
        self.session_maker.configure(bind=engine)
        self.tablename2pk = {}      # table name -> key columns last given to upsert()
        self.tablename2datas = {}   # table name -> list of buffered row dicts
        self.buffer_size = buffer_size
        self.update_on_duplicate = update_on_duplicate
        self.field_quote_mark = field_quote_mark

    def __del__(self):
        # __init__ may have raised before the buffers were created; flushing
        # then would only raise AttributeError from the finalizer.
        if hasattr(self, "tablename2datas") and hasattr(self, "tablename2pk"):
            self.flush()

    @contextmanager
    def session_scope(self):
        """Yield a session that is committed on success and rolled back on error."""
        session = self.session_maker()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def flush(self):
        """Write every non-empty table buffer to the database and clear it."""
        # Iterate the data buffers (not tablename2pk) so that rows buffered
        # for a table without a registered pk are flushed as well.
        for tablename, datas in list(self.tablename2datas.items()):
            if not datas:
                continue
            pk = self.tablename2pk.get(tablename, ())
            with self.session_scope() as session:
                self._flush(session, tablename, pk, datas)
            self.tablename2datas[tablename] = []

    def _gen_batch_insert_sql(self, tablename, datas):
        """Return one multi-row INSERT statement for *datas*.

        The column list comes from the first row; the values of every row
        are emitted in that same column order.

        Raises ValueError when a row does not have exactly the columns of
        the first row.
        """
        quote_mark = self.field_quote_mark
        columns = list(datas[0].keys())
        rows = []
        for data in datas:
            if len(data) != len(columns) or any(c not in data for c in columns):
                raise ValueError(
                    f"rows buffered for {tablename} have different columns: "
                    f"{columns} vs {list(data.keys())}"
                )
            # Look values up by column name so that a different key order in
            # a later row cannot put values under the wrong column.
            rows.append(f" ({self._format_values(data[c] for c in columns)})")
        header = (
            f"INSERT INTO {tablename}"
            f"({quote_mark}{f'{quote_mark}, {quote_mark}'.join(columns)}{quote_mark}) VALUES\n"
        )
        return header + ",\n".join(rows) + ";\n"

    def _flush(self, session, tablename, pk, datas):
        """Insert *datas* in one statement, falling back on duplicate keys.

        On a duplicate-key IntegrityError, batches of at most 500 rows are
        retried row by row with ``upsert_one``; larger batches are split in
        half and each half is retried recursively. Any other IntegrityError
        is re-raised.
        """
        sql = self._gen_batch_insert_sql(tablename, datas)
        try:
            session.execute(text(sql))
        except sqlalchemy.exc.IntegrityError as e:
            if not self.is_duplicate_key(e):
                raise
            # NOTE(review): the fallback reuses the session in which the
            # INSERT just failed. Databases that abort the transaction on
            # error (PostgreSQL, presumably) may need a savepoint around the
            # batch INSERT for this to work - confirm.
            if len(datas) <= 500:
                for data in datas:
                    self.upsert_one(session, tablename, pk, data)
            else:
                middle = len(datas) // 2
                self._flush(session, tablename, pk, datas[:middle])
                self._flush(session, tablename, pk, datas[middle:])

    def upsert_one(self, session, tablename, pk, data):
        """Upsert a single row: UPDATE by *pk* first, INSERT if nothing matched.

        The UPDATE is skipped when updates are disabled, when *data* has no
        non-key column, or when *data* contains none of the *pk* columns. A
        duplicate-key error raised by the INSERT is ignored; any other
        IntegrityError propagates.
        """
        quote_mark = self.field_quote_mark
        result = None
        if self.update_on_duplicate:
            update_str = self._format_update_values(pk, data)
            conditions = self._format_update_conditions(pk, data)
            # Without a condition the statement would end in "WHERE " and be
            # invalid SQL, so the row cannot be located for an update.
            if update_str.strip() and conditions:
                result = session.execute(text(
                    f"UPDATE {tablename} SET {update_str} WHERE {conditions}"
                ))
        if result is None or result.rowcount == 0:
            try:
                session.execute(text(
                    f"INSERT INTO {tablename}({quote_mark}{f'{quote_mark}, {quote_mark}'.join(data.keys())}{quote_mark}) VALUES ({self._format_values(data.values())})"
                ))
            except sqlalchemy.exc.IntegrityError as e:
                # The row already exists and was not updated: ignore it.
                if not self.is_duplicate_key(e):
                    raise

    @staticmethod
    def is_duplicate_key(e):
        """Return True when *e* is an IntegrityError whose text mentions "duplicate"."""
        if not isinstance(e, sqlalchemy.exc.IntegrityError):
            return False
        return "duplicate" in str(e.orig).lower()

    def _isinf(self, x):
        """Return True when *x* is at or beyond the signed 64-bit integer limits."""
        return x >= 9223372036854775807 or x <= -9223372036854775808

    def _format_value(self, v):
        """Render *v* as a SQL literal.

        None, pandas NaT/NA, NaN, infinities and numbers at or beyond the
        signed 64-bit limits become ``null``. Python and numpy numbers are
        written unquoted, dates/datetimes as quoted strings, and everything
        else as a single-quoted string with ``%`` doubled.
        """
        if v is None or v is pd.NaT or v is getattr(pd, "NA", None):
            return "null"
        if isinstance(v, (float, decimal.Decimal, np.floating)):
            if isinstance(v, decimal.Decimal):
                # is_finite() also covers signalling NaN, which float() rejects.
                non_finite = not v.is_finite()
            else:
                non_finite = math.isnan(v) or math.isinf(v)
            if non_finite or self._isinf(v):
                return "null"
            return f"{v}"
        # bool is an int subclass but keeps its previous rendering as a
        # quoted string, so it is excluded here.
        if isinstance(v, (int, np.integer)) and not isinstance(v, bool):
            v = int(v)
            return "null" if self._isinf(v) else f"{v}"
        # datetime must be tested before date because it subclasses date;
        # pandas.Timestamp is a datetime subclass and is handled here too.
        if isinstance(v, datetime.datetime):
            return "'" + v.strftime("%Y-%m-%d %H:%M:%S") + "'"
        if isinstance(v, datetime.date):
            return "'" + v.strftime("%Y-%m-%d") + "'"

        # NOTE(review): values are embedded into the SQL text. A ":name"
        # sequence inside a value may be read as a bind parameter by text(),
        # and the "%" doubling may depend on the driver - confirm; bound
        # parameters would avoid both.
        literal = repr(f"{v}")
        body = literal[1:-1]
        if literal[0] == '"':
            # repr() chose double quotes because the value holds an
            # apostrophe; a double-quoted token is an identifier in standard
            # SQL, so re-quote with single quotes and double the apostrophe.
            body = body.replace("'", "''")
        else:
            # repr() escapes apostrophes as \'; use the portable '' form.
            body = body.replace("\\'", "''")
        return f"'{body}'".replace("%", "%%")

    def _format_values(self, data):
        """Return the SQL literals of *data* joined by ``", "``."""
        return ", ".join(self._format_value(e) for e in data)

    def _format_update_values(self, pk, data):
        """Return the ``col=value`` list for all columns of *data* not in *pk*."""
        pk = pk or ()
        quote_mark = self.field_quote_mark
        return ", ".join(
            f"{quote_mark}{k}{quote_mark}={self._format_value(v)}"
            for k, v in data.items() if k not in pk
        )

    def _format_update_conditions(self, pk, data):
        """Return the ``col=value and ...`` condition for the *pk* columns of *data*.

        The result is an empty string when *data* has none of the *pk* columns.
        """
        pk = pk or ()
        quote_mark = self.field_quote_mark
        return " and ".join(
            f"{quote_mark}{k}{quote_mark}={self._format_value(v)}"
            for k, v in data.items() if k in pk
        )

    def upsert(self, tablename, pk, data):
        """Queue one row for *tablename*, or write it at once when unbuffered.

        tablename: database table name.
        pk: tuple of key columns used to tell insert from update; a falsy
            value keeps the pk registered earlier for this table.
        data: the row as a dict of column name to value.
        """
        if self.buffer_size is not None and self.buffer_size == 0:
            with self.session_scope() as session:
                self.upsert_one(session, tablename, pk, data)
            return

        if pk:
            self.tablename2pk[tablename] = pk
        pending = self.tablename2datas.setdefault(tablename, [])
        pending.append(data)

        if self.buffer_size is not None and len(pending) >= self.buffer_size:
            with self.session_scope() as session:
                # .get(): a table may be buffered without any pk registered.
                self._flush(session, tablename, self.tablename2pk.get(tablename, ()), pending)
            self.tablename2datas[tablename] = []

    def upsert_dataframe(self, tablename, pk, df):
        """Upsert every row of the DataFrame *df* through ``upsert``."""
        for _, row in df.iterrows():
            self.upsert(tablename, pk, row.to_dict())
class MySQLUpserter(UpserterBase):
    """Upserter for MySQL that relies on ``INSERT ... ON DUPLICATE KEY UPDATE``."""

    def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
        """Configure the base upserter with backticks as the column quote mark."""
        super().__init__(engine, buffer_size, update_on_duplicate, field_quote_mark='`')

    def __del__(self):
        super().__del__()

    @staticmethod
    def is_duplicate_key(e):
        """Return True for an IntegrityError whose second driver argument starts with "Duplicate entry"."""
        if type(e) is not sqlalchemy.exc.IntegrityError:
            return False
        driver_args = e.orig.args
        return len(driver_args) > 1 and str(driver_args[1]).startswith("Duplicate entry")

    def upsert_one(self, session, tablename, pk, data):
        """Insert one row, resolving a duplicate key inside the same statement.

        With updates enabled and at least one non-key column, the non-key
        columns are overwritten on a duplicate. Otherwise the first pk column
        is assigned to itself, so a duplicate leaves the row as it is.
        """
        update_str = self._format_update_values(pk, data) if self.update_on_duplicate else ""
        column_list = "`, `".join(data.keys())
        value_list = self._format_values(data.values())
        if update_str.strip():
            on_duplicate = update_str
        else:
            on_duplicate = f"`{pk[0]}`=VALUES(`{pk[0]}`)"
        sql = f"INSERT INTO {tablename}(`{column_list}`) VALUES ({value_list}) ON DUPLICATE KEY UPDATE {on_duplicate}\n"
        session.execute(text(sql))
class PSQLUpserter(UpserterBase):
    """Upserter for PostgreSQL that relies on ``INSERT ... ON CONFLICT``."""

    def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
        """Configure the base upserter with double quotes as the column quote mark."""
        super().__init__(engine, buffer_size, update_on_duplicate, field_quote_mark='"')

    def __del__(self):
        super().__del__()

    @staticmethod
    def is_duplicate_key(e):
        """Return True for an IntegrityError whose driver message starts with "duplicate key"."""
        if type(e) is not sqlalchemy.exc.IntegrityError:
            return False
        return str(e.orig).startswith("duplicate key")

    def upsert_one(self, session, tablename, pk, data):
        """Insert one row, resolving a conflict on the *pk* columns in the same statement.

        With updates enabled and at least one non-key column, the non-key
        columns are overwritten on conflict; otherwise the conflicting row is
        left untouched.
        """
        update_str = self._format_update_values(pk, data) if self.update_on_duplicate else ""
        column_list = '", "'.join(data.keys())
        value_list = self._format_values(data.values())
        conflict_columns = '", "'.join(pk)
        if update_str.strip():
            action = f"do update set {update_str}"
        else:
            action = "do nothing"
        sql = f'INSERT INTO {tablename}("{column_list}") VALUES ({value_list}) on conflict ("{conflict_columns}") {action}\n'
        session.execute(text(sql))
class DummyUpserter(UpserterBase):
    """Upserter whose write methods do nothing."""

    def __init__(self, engine):
        """Initialise the base class with its default settings."""
        super().__init__(engine)

    def upsert(self, tablename, pk, data):
        """Ignore the row; nothing is buffered or written."""
        return None

    def upsert_one(self, session, tablename, pk, data):
        """Ignore the row; no statement is executed."""
        return None

    def upsert_dataframe(self, tablename, pk, df):
        """Ignore the DataFrame; nothing is buffered or written."""
        return None
|