1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
| import re
def bind(func, *args, **kw):
    """Return a callable with leading arguments of ``func`` pre-filled.

    The returned callable invokes ``func`` with the bound positional
    arguments first, then the call-time positional arguments, followed by
    the bound and the call-time keyword arguments.

    A keyword given both at bind time and at call time is not overridden:
    the underlying call raises ``TypeError`` for the duplicate keyword.
    """
    def bound(*call_args, **call_kw):
        return func(*args, *call_args, **kw, **call_kw)

    return bound
class Token:
    """A lexical token: a kind name paired with the matched source text.

    Attributes:
        type: Name of the token kind (for example ``'PLUS'``).
        value: The text that was matched for this token.
    """

    def __init__(self, type, value):
        # ``type`` shadows the builtin inside this method only; the parameter
        # name is kept because callers may pass it by keyword.
        self.type = type
        self.value = value

    def __repr__(self):
        """Return a readable form so token lists are easy to inspect."""
        return "Token({!r}, {!r})".format(self.type, self.value)
def tokenize(expression):
    """Split a stock-set expression into a list of ``Token`` objects.

    Recognised tokens are stock-set codes of the form
    ``<alphanumerics>.<alphanumerics>`` (``STOCKSET``), the operators ``+``
    (``PLUS``), ``-`` (``MINUS``) and ``&`` (``INTERSECT``), and the
    parentheses ``(`` (``LPAREN``) and ``)`` (``RPAREN``). Whitespace
    separates tokens and is discarded.

    Args:
        expression: The expression text to scan.

    Returns:
        The tokens in source order; an empty list for empty or
        whitespace-only input.

    Raises:
        SyntaxError: If the text contains a character that is not part of
            any recognised token.
    """
    token_specification = [
        ('STOCKSET', r'[0-9a-zA-Z]+\.[0-9a-zA-Z]+'),
        ('PLUS', r'\+'),
        ('MINUS', r'-'),
        ('INTERSECT', r'\&'),
        ('LPAREN', r'\('),
        ('RPAREN', r'\)'),
        ('WS', r'\s+'),
        # Catch-all; must stay last. Without it ``re.finditer`` silently
        # skips unmatched text, so a malformed expression would be parsed
        # as if the offending characters were not there.
        ('MISMATCH', r'.'),
    ]
    tok_regex = '|'.join('(?P<%s>%s)' % pair for pair in token_specification)
    token_list = []
    for mo in re.finditer(tok_regex, expression):
        kind = mo.lastgroup
        value = mo.group()
        if kind == 'WS':
            continue
        if kind == 'MISMATCH':
            raise SyntaxError(
                "Invalid syntax: unexpected character {!r} at position {}".format(
                    value, mo.start()
                )
            )
        token_list.append(Token(kind, value))
    return token_list
class ASTNode:
    """A node of an expression tree.

    Attributes:
        type: Name of the node kind (for example ``'STOCKSET'`` or
            ``'PLUS'``).
        value: Payload of the node; defaults to ``None``.
        left: Left child node; defaults to ``None``.
        right: Right child node; defaults to ``None``.
    """

    def __init__(self, type, value=None, left=None, right=None):
        # ``type`` shadows the builtin inside this method only; the parameter
        # name is kept because callers may pass it by keyword.
        self.type = type
        self.value = value
        self.left = left
        self.right = right

    def __repr__(self):
        """Return a nested, readable form of the subtree rooted here."""
        return "ASTNode({!r}, value={!r}, left={!r}, right={!r})".format(
            self.type, self.value, self.left, self.right
        )
def parse_expression(tokens):
    """Parse a token list into an ``ASTNode`` tree.

    Grammar::

        expression := term (('+' | '-' | '&') term)*
        term       := STOCKSET | '(' expression ')'

    The three operators share a single precedence level and associate to the
    left, so ``a + b & c`` parses as ``(a + b) & c``; use parentheses to
    group differently.

    Args:
        tokens: Mutable list of objects exposing ``type`` and ``value``.
            Tokens are removed from the front of the list as they are
            parsed; on success the list is empty.

    Returns:
        The root ``ASTNode`` of the expression.

    Raises:
        SyntaxError: If the tokens do not form exactly one complete
            expression. This covers empty input, a missing operand, a
            missing closing parenthesis, and an unmatched ``)`` or any other
            leftover token after the expression.
    """
    operators = ('PLUS', 'MINUS', 'INTERSECT')

    def next_type():
        # ``None`` marks end of input, so no caller ever indexes an empty
        # list (which would surface as IndexError instead of SyntaxError).
        return tokens[0].type if tokens else None

    def parse_term():
        kind = next_type()
        if kind == 'STOCKSET':
            return ASTNode('STOCKSET', value=tokens.pop(0).value)
        if kind == 'LPAREN':
            tokens.pop(0)
            inner = parse_chain()
            if next_type() != 'RPAREN':
                raise SyntaxError("Invalid syntax")
            tokens.pop(0)
            return inner
        raise SyntaxError("Invalid syntax")

    def parse_chain():
        # Left-associative fold over ``term (operator term)*``.
        node = parse_term()
        while next_type() in operators:
            op = tokens.pop(0).type
            node = ASTNode(op, left=node, right=parse_term())
        return node

    root = parse_chain()
    # Anything still unread (e.g. a stray ')') means the input was not a
    # single well-formed expression; reject it rather than ignore the tail.
    if tokens:
        raise SyntaxError("Invalid syntax")
    return root
def evaluate_ast(node, code2set):
    """Evaluate an expression tree to a set.

    Args:
        node: Tree node exposing ``type``; ``'STOCKSET'`` nodes also expose
            ``value``, and operator nodes expose ``left`` and ``right``.
        code2set: Callable invoked with a ``'STOCKSET'`` node's ``value``;
            its result is used as that leaf's set.

    Returns:
        The leaf's set for ``'STOCKSET'``; the union for ``'PLUS'``, the
        difference for ``'MINUS'`` and the intersection for ``'INTERSECT'``
        of the evaluated ``left`` and ``right`` subtrees.

    Raises:
        ValueError: If ``node.type`` is not one of the four known kinds.
    """
    kind = node.type
    if kind == 'STOCKSET':
        return code2set(node.value)
    if kind not in ('PLUS', 'MINUS', 'INTERSECT'):
        # Fail loudly: falling through would return None and only blow up
        # later, far from the malformed node.
        raise ValueError("Unknown AST node type: {!r}".format(kind))

    left = evaluate_ast(node.left, code2set)
    right = evaluate_ast(node.right, code2set)
    if kind == 'PLUS':
        return left.union(right)
    if kind == 'MINUS':
        return left - right
    return left & right
class StockSetExpression:
    """A stock-set expression parsed once and evaluable many times.

    Construction tokenizes and parses the expression text; ``evaluate`` then
    walks the stored tree with a caller-supplied code-to-set lookup.

    Attributes:
        exp_str: The expression text given to the constructor.
        ast: Root node of the parsed expression tree.
    """

    def __init__(self, exp_str):
        """Parse ``exp_str``; parser errors propagate to the caller."""
        # Keep the source text so instances can be inspected and printed.
        self.exp_str = exp_str
        self.ast = parse_expression(tokenize(exp_str))

    def evaluate(self, code2set):
        """Evaluate the parsed expression.

        Args:
            code2set: Callable taking a stock-set code and returning the
                corresponding set.

        Returns:
            The result of ``evaluate_ast`` on the stored tree.
        """
        return evaluate_ast(self.ast, code2set)

    def __repr__(self):
        """Return a readable form showing the expression text."""
        return "StockSetExpression({!r})".format(self.exp_str)
if __name__ == "__main__":
    # Demo: evaluate a sample expression against hard-coded constituents.
    expression = "(000300.SH + 000905.SH) & 000852.SH - 399006.SZ"

    def get_index_con(index_code, trade_date):
        """Return hard-coded demo constituents for ``index_code``.

        ``trade_date`` is accepted but not used; an unknown code yields
        ``None``.
        """
        # Built on every call so each lookup hands out fresh set objects.
        constituents = {
            "000300.SH": {"000001.SZ", "000002.SZ", "000003.SZ"},
            "000905.SH": {"000002.SZ", "000003.SZ", "000004.SZ"},
            "000852.SH": {"000003.SZ", "000004.SZ", "000005.SZ"},
            "399006.SZ": {"000004.SZ", "000005.SZ", "000006.SZ"},
        }
        return constituents.get(index_code)

    exp_obj = StockSetExpression(expression)
    lookup = bind(get_index_con, trade_date='2024-12-17')
    print("Result of {}: {}".format(expression, exp_obj.evaluate(lookup)))
|