1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
import functools
import re
class Token:
    """A single lexical token of a stock-set expression.

    Attributes:
        type: Category name of the token (for example ``'STOCKSET'``,
            ``'PLUS'`` or ``'LPAREN'``).
        value: The text the token was built from.
    """

    # NOTE: the parameter name ``type`` shadows the builtin inside __init__;
    # it is kept as-is because it is part of the constructor's interface.
    def __init__(self, type, value):
        self.type = type
        self.value = value

    def __repr__(self):
        # Readable form for debugging token streams, e.g. Token('PLUS', '+').
        return '{}({!r}, {!r})'.format(
            self.__class__.__name__, self.type, self.value)
# Lexical rules, tried in order at every position of the input.
_TOKEN_SPECIFICATION = (
    # The dotted-code alternative comes first so a code that merely starts
    # with "ALL" (e.g. "ALL.SH") is not cut short, and the word boundary
    # stops the keyword ALL from matching a prefix of a longer word.
    ('STOCKSET', r'[0-9a-zA-Z]+\.[0-9a-zA-Z]+|ALL\b'),
    ('PLUS', r'\+'),
    ('MINUS', r'-'),
    ('INTERSECT', r'&'),
    ('LPAREN', r'\('),
    ('RPAREN', r'\)'),
    ('WS', r'\s+'),
    # Catch-all: must stay last so it only fires when nothing else matches.
    ('MISMATCH', r'.'),
)
_TOKEN_REGEX = re.compile(
    '|'.join('(?P<%s>%s)' % pair for pair in _TOKEN_SPECIFICATION))


def tokenize(expression):
    """Split a stock-set expression string into a list of tokens.

    Recognised tokens are stock-set codes (``ALL`` or ``<alnum>.<alnum>``),
    the operators ``+``, ``-`` and ``&``, and parentheses. Whitespace
    separates tokens and is discarded.

    Args:
        expression: The expression text to tokenize.

    Returns:
        A list of ``Token`` objects in source order (empty for blank input).

    Raises:
        SyntaxError: If the text contains a character that is not part of
            any token. Without this check ``re.finditer`` would silently
            skip the unmatched text and a malformed expression could be
            accepted as a different, valid one.
    """
    token_list = []
    for mo in _TOKEN_REGEX.finditer(expression):
        kind = mo.lastgroup
        if kind == 'WS':
            continue
        if kind == 'MISMATCH':
            raise SyntaxError(
                "Unexpected character {!r} at position {}".format(
                    mo.group(), mo.start()))
        token_list.append(Token(kind, mo.group()))
    return token_list
class ASTNode:
    """Node of a stock-set expression tree.

    Attributes:
        type: Node kind; ``'STOCKSET'`` for a leaf, otherwise the name of
            the operator the node represents.
        value: Payload of the node (``None`` when not supplied).
        left: Left child node, or ``None``.
        right: Right child node, or ``None``.
    """

    # NOTE: the parameter name ``type`` shadows the builtin inside __init__;
    # it is kept as-is because it is part of the constructor's interface.
    def __init__(self, type, value=None, left=None, right=None):
        self.type = type
        self.value = value
        self.left = left
        self.right = right

    def __repr__(self):
        # Nested repr of the whole subtree, intended for debugging only.
        return '{}({!r}, value={!r}, left={!r}, right={!r})'.format(
            self.__class__.__name__,
            self.type, self.value, self.left, self.right)
def parse_expression(tokens):
    """Parse a token sequence into an expression tree.

    Grammar, from lowest to highest precedence (operators are
    left-associative)::

        expression := term (('+' | '-') term)*
        term       := factor ('&' factor)*
        factor     := STOCKSET | '(' expression ')'

    The input is read through a cursor over a private copy, so the caller's
    token list is never modified, even when parsing fails part-way.

    Args:
        tokens: Iterable of token objects exposing ``type`` and ``value``.

    Returns:
        The root ``ASTNode`` of the parsed expression.

    Raises:
        SyntaxError: If the tokens are empty, malformed, have unbalanced
            parentheses, or contain trailing tokens after a complete
            expression.
    """
    tokens = list(tokens)
    count = len(tokens)
    pos = 0  # index of the next unread token

    def _peek_type():
        # Type of the next token, or None once the input is exhausted.
        return tokens[pos].type if pos < count else None

    def _parse_expression():
        nonlocal pos
        node = _parse_term()
        while _peek_type() in ('PLUS', 'MINUS'):
            op = tokens[pos].type
            pos += 1
            right = _parse_term()
            node = ASTNode(op, left=node, right=right)
        return node

    def _parse_term():
        nonlocal pos
        node = _parse_factor()
        while _peek_type() == 'INTERSECT':
            op = tokens[pos].type
            pos += 1
            right = _parse_factor()
            node = ASTNode(op, left=node, right=right)
        return node

    def _parse_factor():
        nonlocal pos
        kind = _peek_type()
        if kind == 'STOCKSET':
            value = tokens[pos].value
            pos += 1
            return ASTNode('STOCKSET', value=value)
        if kind == 'LPAREN':
            pos += 1
            node = _parse_expression()
            if _peek_type() != 'RPAREN':
                raise SyntaxError("Invalid syntax")
            pos += 1
            return node
        # Covers both end of input and an unexpected token.
        raise SyntaxError("Invalid syntax")

    ret = _parse_expression()
    if pos != count:
        # Leftover tokens, e.g. a stray ')' or two adjacent operands.
        raise SyntaxError("Invalid syntax")
    return ret
def evaluate_ast(node, code2set):
    """Evaluate an expression tree to a concrete set.

    Args:
        node: Tree node exposing ``type``, ``value``, ``left`` and ``right``.
        code2set: Callable taking a stock-set code (the ``value`` of a
            ``'STOCKSET'`` leaf) and returning its members. The returned
            objects must support ``.union()``, ``-`` and ``&``.

    Returns:
        The result of applying the tree's operators: ``'PLUS'`` is union,
        ``'MINUS'`` is difference and ``'INTERSECT'`` is intersection.

    Raises:
        ValueError: If a node has an unrecognised ``type``. This replaces
            silently returning ``None`` for such nodes.
    """
    kind = node.type
    if kind == 'STOCKSET':
        return code2set(node.value)
    if kind not in ('PLUS', 'MINUS', 'INTERSECT'):
        raise ValueError("Unknown AST node type: {!r}".format(kind))

    # Operands are evaluated left first, then right.
    left = evaluate_ast(node.left, code2set)
    right = evaluate_ast(node.right, code2set)
    if kind == 'PLUS':
        return left.union(right)
    if kind == 'MINUS':
        return left - right
    return left & right
class StockSetExpression:
    """A parsed stock-set expression that can be evaluated repeatedly.

    The expression text is tokenized and parsed once at construction; the
    resulting tree is stored in ``self.ast``.
    """

    def __init__(self, exp_str):
        """Parse ``exp_str`` and keep its expression tree.

        Args:
            exp_str: Expression text, e.g. ``"000300.SH + 000905.SH"``.

        Raises:
            SyntaxError: If the parser rejects the expression.
        """
        tokens = tokenize(exp_str)
        self.ast = parse_expression(tokens)

    def evaluate(self, code2set):
        """Evaluate the expression.

        Args:
            code2set: Callable mapping a stock-set code to its members.

        Returns:
            Whatever ``evaluate_ast`` produces for the stored tree.
        """
        return evaluate_ast(self.ast, code2set)

    def get_stock_set_codes(self):
        """Return the stock-set codes referenced by the expression.

        Returns:
            A list of leaf values in left-to-right order. A code that
            appears several times in the expression is listed each time.
        """
        codes = []
        # Explicit stack instead of recursion: a long operator chain such
        # as "a + b + c + ..." produces a tree as deep as the chain is
        # long, which would exceed the interpreter's recursion limit.
        pending = [self.ast]
        while pending:
            node = pending.pop()
            if node.type == 'STOCKSET':
                codes.append(node.value)
                continue
            # Push right before left so the left subtree is visited first.
            if node.right:
                pending.append(node.right)
            if node.left:
                pending.append(node.left)
        return codes
if __name__ == "__main__":
    # Small self-contained demo: evaluate an expression against hard-coded
    # index constituents and print the result and the referenced codes.
    expression = "000300.SH + 000905.SH & 000852.SH - 399006.SZ"

    # Stub constituent data used only by this demo.
    index_constituents = {
        "000300.SH": {"000001.SZ", "000002.SZ", "000003.SZ"},
        "000905.SH": {"000002.SZ", "000003.SZ", "000004.SZ"},
        "000852.SH": {"000003.SZ", "000004.SZ", "000005.SZ"},
        "399006.SZ": {"000004.SZ", "000005.SZ", "000006.SZ"},
    }

    def get_index_con(index_code, trade_date):
        """Return the demo constituents of ``index_code``.

        ``trade_date`` is accepted to mirror a date-aware lookup but is not
        used by this stub. Unknown codes yield an empty set so that the set
        operations applied to the result keep working.
        """
        # Return a copy so evaluation can never mutate the demo data.
        return set(index_constituents.get(index_code, ()))

    exp_obj = StockSetExpression(expression)
    # Fix the trade date up front; the evaluator only supplies the code.
    code2set = functools.partial(get_index_con, trade_date='2024-12-17')
    print("Result of {}: {}".format(expression, exp_obj.evaluate(code2set)))

    print(exp_obj.get_stock_set_codes())
|