0%

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# -*- coding: utf-8 -*-
import traceback
import smtplib
import os
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.mime.application import MIMEApplication
from email.utils import formataddr as _formataddr

def formataddr(addr):
return ','.join([_formataddr([a.split("@")[0].strip(), a.strip()]) for a in addr.split(',')])

"""
smtp_server: SMTP服务器地址
from_addr: 发件人地址
to_addr: 收件人地址(多个地址用英文逗号分割)
subject: 邮件主题
content: 邮件内容
attachmentpaths: 附件的文件路径
html: 邮件内容是否是HTML格式
ssl: 使用SMTP_SLL(465端口)还是SMTL(25端口)
"""
def send_email(smtp_server, from_addr, password, to_addr, subject, content, attachmentpaths=[], html=False, ssl=True):
try:
msg = MIMEMultipart()
msg['From'] = formataddr(from_addr) # 括号里的对应发件人邮箱昵称、发件人邮箱账号
msg['To'] = formataddr(to_addr) # 括号里的对应收件人邮箱昵称、收件人邮箱账号
msg['Subject'] = subject # 邮件标题
if html:
context_part = MIMEText(content, 'html', 'utf-8')
else:
context_part = MIMEText(content, 'plain', 'utf-8')
msg.attach(context_part)

if attachmentpaths:
for path in attachmentpaths:
filedir, filename = os.path.split(path)
part = MIMEApplication(open(path,'rb').read())
part.add_header('Content-Disposition', 'attachment', filename=filename)
msg.attach(part)

if ssl:
server=smtplib.SMTP_SSL(smtp_server, 465)
else:
server=smtplib.SMTP(smtp_server, 25) # 发件人邮箱中的SMTP服务器,端口是25
server.login(from_addr, password) # 括号中对应的是发件人邮箱账号、邮箱密码
server.sendmail(from_addr, to_addr.split(','), msg.as_string())
server.quit() # 关闭连接
return True
except Exception:
traceback.print_exc()
return False

if __name__ == '__main__':
# test

send_email("smtp.163.com", "kyo_86@163.com", "XXXXXXXXX", "kyo_86@163.com",
f"测试HTML内容和附件发送",
"""
<html>
<body>
<img src="http://kyo86.com/images/saber.jpg"/>
</body>
</html>
""",
attachmentpaths=[r"./mailer.py"], html=True, ssl=True)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# -*- coding: utf-8 -*-
import imaplib
imaplib._MAXLINE = 10000000
import email
import email.utils
import re
import os
import datetime
import pytz
import traceback
from dataclasses import dataclass


def decode(s, charset):
if type(s) is str:
return s
try:
return s.decode(charset)
except Exception:
pass
try:
return s.decode('utf-8')
except Exception:
pass
try:
return s.decode('latin1')
except Exception as e:
pass
return s.decode('gbk')


class Attachment:
def __init__(self, part):
self.content_type = part.get_content_type()
raw_filename = part.get_filename() # .strip()
# print(dir(part), raw_filename)
if raw_filename.startswith("=?") and raw_filename.endswith("?="):
dh = email.header.decode_header(raw_filename)
self.filename = decode(dh[0][0], dh[0][1])
else:
h = email.header.Header(raw_filename)
dh = email.header.decode_header(h)
self.filename = decode(dh[0][0], dh[0][1])
self.data = part.get_payload(decode=True) #下载附件

def __repr__(self):
return f"Attachment(content_type='{self.content_type}', filename='{self.filename}', size={len(self.data)})"

def save_to(self, path):
if os.path.exists(path):
if os.path.isdir(path): # 已附件原文件名保存到目录下
path = os.path.join(path, self.filename)
with open(path, 'wb') as fp:
fp.write(self.data)
else: # 覆盖已存在文件
with open(path, 'wb') as fp:
fp.write(self.data)
else: # 新建文件
with open(path, 'wb') as fp:
fp.write(self.data)


class Mail:
def __init__(self, num, msg):
# 这些字段是在读取邮件列表时就解析的
self.num = num
self.subject: str = self._decode_value(msg.get("subject"))
date = email.utils.parsedate_to_datetime(msg.get("date"))
if date:
timezone = pytz.timezone('Asia/Shanghai')
date = date.astimezone(timezone) # 设置时区为+8区
date = date.replace(tzinfo=None) # 移除时区信息
self.date: str = str(date) if date else msg.get("date")
from_name, self.from_addr = email.utils.parseaddr(msg.get("from"))
self.from_name = self._decode_value(from_name)
to_name, self.to_addr = email.utils.parseaddr(msg.get("to"))
self.to_name = self._decode_value(to_name)

# 这些字段是延迟到需要访问时才解析的
self._plain: str = ""
self._html: str = ""
self._attachments: list = []

self._msg: str = msg
self._parsed: bool = False

def _decode_value(self, value):
try:
header = email.header.decode_header(value)
raw_value = email.header.decode_header(value)[0][0]
charset = email.header.decode_header(value)[0][1]
return decode(raw_value, charset)
except Exception as e:
print(f"decode failed [value] {value}")
traceback.print_exc()
return None

@property
def plain(self):
# 为了延迟解析邮件内容
if not self._parsed:
self.parse_content()
return self._plain

@property
def html(self):
if not self._parsed:
self.parse_content()
return self._html

@property
def attachments(self):
if not self._parsed:
self.parse_content()
return self._attachments

# 解析mail的内容
def parse_content(self):
self._attachments = []
for part in self._msg.walk():
if part.is_multipart():
continue
if part.get_content_type() == "text/plain":
charset = part.get_content_charset()
content = decode(part.get_payload(decode=True), charset)
self._plain = content
if part.get_content_type() == "text/html":
charset = part.get_content_charset()
content = decode(part.get_payload(decode=True), charset)
self._html = content

if part.get_content_disposition():
if part.get_content_disposition() == "inline":
# HTML内容引用的图片之类的
pass
elif part.get_content_disposition() == "attachment":
# 附件
self._attachments.append(Attachment(part))
if self._plain:
self._html = ""
self._parsed = True


class ImapMailBox:
def __init__(self, host, username, password, port=None, ssl=None):
self.host = host
self.port = port
self.ssl = ssl
self.username = username
self.password = password
if ssl is None and port == 993:
ssl = True
if ssl:
# connecting to host via SSL
self.conn = imaplib.IMAP4_SSL(host=host, port=port or 993)
else:
self.conn = imaplib.IMAP4(host=host, port=port or 143)
# logging in to servers
self.conn.login(username, password)

def get_mail_count(self):
# Selecting the inbox of the logged in account
self.conn.select('Inbox')
state, data = self.conn.search(None, 'ALL')
mail_list = []
mails = data[0].split()
return len(mails)

def get_mail_list(self, page=1, page_size=50):
# Selecting the inbox of the logged in account
self.conn.select('Inbox')
state, data = self.conn.search(None, 'ALL')
mail_list = []
mails = data[0].split()[::-1]
if page_size:
mails = mails[(page-1)*page_size: page*page_size]
for num in mails:
state, data = self.conn.fetch(num, '(RFC822)')
raw_email = data[0][1]
try:
msg = email.message_from_bytes(raw_email)
mail = Mail(num, msg)
mail_list.append(mail)
except Exception as e:
print(f"Parse raw data failed. [raw_data] '{raw_email}'")
traceback.print_exc()
return mail_list

def mark_as_seen(self, mail):
self.conn.store(mail.num, '+FLAGS', '\\seen')


if __name__ == '__main__':
mailbox = ImapMailBox(
host='imap.aliyun.com', port=993,
username="******", password="******"
)
count = mailbox.get_mail_count()
# 收件箱里的邮件数
print(count)
# 分页获取邮件
for mail in mailbox.get_mail_list(page=1, page_size=25):
# 打印 日期、发件人、标题、纯文本内容
print(mail.date, mail.from_addr, mail.subject, mail.plain)

# 如果有附件,就下载保存到本地
if mail.attachments:
for attachment in mail.attachments:
attachment.save_to("./")

在LeetCode上提交的C代码并不需要include标准库头文件,判题系统会自动包含,并且在二叉树的题目会额外包含TreeNode结构。我希望有一个简洁的main.cpp可以直接提交到LeetCode,里面包含一些最常用的函数,调试输出的代码放在bits/stdc.h中(故意和标准库头文件同名),通过宏定义只在引入本地自定义的bits/stdc++.h时开启cout输出,这样可以直接把main.cpp的全部内容直接提交到网站条件,会自动屏蔽掉所有调试输出代码而不会报错。

main.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include "bits/stdc++.h"
using namespace std;

#define all(a) (a).begin(), (a).end()
template<typename T, typename F=less<T>>
using pque=priority_queue<T, vector<T>, F>;
template<typename T, typename F=less<T>>
pque<T, F> make_pque(F cmp) { return pque<T, F>(cmp); }
typedef long long ll;
typedef vector<int> vi;
typedef vector<vi> vvi;
typedef vector<ll> vl;
typedef vector<vl> vvl;
typedef vector<bool> vb;
typedef vector<vb> vvb;
typedef vector<string> vs;
typedef vector<vs> vvs;

template <typename T1, typename T2>
auto max(T1 a, T2 b) -> decltype(a + b) { return a>b?a:b; }
template <typename T1, typename T2>
auto min(T1 a, T2 b) -> decltype(a + b) { return a<b?a:b; }
template <typename T1, typename T2, typename T3>
auto max(T1 a, T2 b, T3 c) -> decltype(a + b + c) { return max(max(a, b), c); }
template <typename T1, typename T2, typename T3>
auto min(T1 a, T2 b, T3 c) -> decltype(a + b + c) { return min(min(a, b), c); }
template <typename T>
T max(const vector<T>& a) {T r = a[0]; for(auto e : a) r = max(r, e); return r;}
template <typename T>
T min(const vector<T>& a) {T r = a[0]; for(auto e : a) r = min(r, e); return r;}
template<typename T>
T sum(const vector<T>& a) { T r = 0; for(auto& e : a) r+=e; return r;}
template<typename T>
T gcd(T a, T b) { while(b) { T r = a%b; a = b; b = r;} return a;}
template<typename T>
T lcm(T a, T b) { return a/gcd(a,b)*b; }
template<typename F>
ll lb(ll b, ll e, F f) {if(b>=e) return e; while(b<e-1) {auto m=b+(e-1-b)/2; if(!f(m)) b=m+1; else e=m+1;} return f(b)?b:e;}
template<typename F>
ll ub(ll b, ll e, F f) {return lb(b, e, [&](ll i){return !f(i);});}

template<typename T>
struct cast_helper {T operator() (stringstream& ss) {T r=T{}; ss >> r; return r;}};
template<>
struct cast_helper<string> {string operator() (stringstream& ss) { return ss.str();}};
template<typename R, typename T>
R sstream_cast(const T& o) {stringstream ss; ss << o; return cast_helper<R>()(ss);}
string format(const char* f, ...){va_list a; va_start(a, f); char b[4096]; vsnprintf(b, 4096, f, a); va_end(a); return b;}
template<typename T>
unordered_map<T, int> counter(const vector<T>& a) {unordered_map<T, int> r;for(auto e : a) ++r[e];return r;}
unordered_map<char, int> counter(const string& a) {unordered_map<char, int> r;for(auto e : a) ++r[e];return r;}
template<typename I>
vector<I> range(I b, I e) {vector<I> r(e-b);iota(all(r), b);return r;}

vvi make_vvi(int n, int m, int v=0) { return vvi(n, vi(m, v));}
vvb make_vvb(int n, int m, bool v=false) { return vvb(n, vb(m, v));}
vvs make_vvs(int n, int m, const string& v="") { return vvs(n, vs(m, v));}
typedef tuple<int, int> tii;
typedef tuple<ll, ll> tll;
typedef tuple<int, int, int> tiii;
#define _0(o) get<0>(o)
#define _1(o) get<1>(o)
#define _2(o) get<2>(o)
namespace std {
template<>struct hash<tii>{size_t operator()(const tii& a)const {return _0(a)^_1(a);}};
template<>struct hash<tiii>{size_t operator()(const tiii& a)const {return _0(a)^_1(a)^_2(a);}};
};
tii dir4[] = {tii(-1, 0), tii(0, 1), tii(1, 0), tii(0,-1)};
tii dir8[] = {tii(-1, 0), tii(0, 1), tii(1, 0), tii(0,-1), tii(-1,-1), tii(-1,1), tii(1,1), tii(1,-1)};
tii& operator += (tii& a, const tii& b) { _0(a)+=_0(b); _1(a)+=_1(b); return a; }
tii operator + (const tii& a, const tii& b) { return tii(_0(a)+_0(b), _1(a)+_1(b)); }
tii& operator -= (tii& a, const tii& b) { _0(a)-=_0(b); _1(a)-=_1(b); return a; }
tii operator - (const tii& a, const tii& b) { return tii(_0(a)-_0(b), _1(a)-_1(b)); }
tii operator - (const tii& a) { return tii(-_0(a), -_1(a)); }
tii& operator *= (tii& a, int k) { _0(a)*=k; _1(a)*=k; return a; }
tii operator * (int k, const tii& a) { return tii(k*_0(a), k*_1(a)); }
tii operator * (const tii& a, int k) { return tii(_0(a)*k, _1(a)*k); }
tii& operator /= (tii& a, int k) { _0(a)/=k; _1(a)/=k; return a; }
tii operator / (const tii& a, int k) { return tii(_0(a)/k, _1(a)/k); }
bool in_range(tii p, tii e) {return 0<=_0(p)&&_0(p)<_0(e)&&0<=_1(p)&&_1(p)<_1(e);}
bool in_range(tii p, tii b, tii e) {return _0(b)<=_0(p)&&_0(p)<_0(e)&&_1(b)<=_1(p)&&_1(p)<_1(e);}

constexpr int INF = 1e9+7;
constexpr int MOD = 1e9+7;

#ifndef cout
struct _ {
template <typename T>
_& operator << (const T&){ return *this; }
};
#define cout _()
#define endl '\n'
#endif


int _main_()
{

return 0;
}

#undef cout
#undef endl

一定要在结尾undef掉,不然会导致LeetCode误判

bits/stdc++.h

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
// #include <bits/stdc++.h>
#include <iostream>
#include <list>
#include <vector>
#include <stack>
#include <queue>
#include <string>
#include <cstring>
#include <cstdlib>
#include <set>
#include <map>
#include <tuple>
#include <sstream>
#include <fstream>
#include <algorithm>
#include <functional>
#include <numeric>
#include <iomanip>
#include <thread>
#include <chrono>
#include <unordered_set>
#include <unordered_map>
#include <bitset>
#include <cassert>
#include <cmath>
#include <cstdarg>
using namespace std;

template<typename T>
T _s(const T& t) {return t;}
string _s(const string& t) {return '"'+ t + '"';}

template<typename T1, typename T2>
ostream& operator << (ostream& out, const pair<T1, T2>& o) {
return out << "(" << o.first << ", " << o.second << ")";
}

template<typename T>
ostream& operator << (ostream& out, const vector<T>& v) {
out << "[";
for(int i=0;i<v.size();i++) {
out << _s(v[i]) << (i != v.size() -1 ? ", " : "");
}
return out << "]";
}

template<typename T>
ostream& operator << (ostream& out, const vector<vector<T>>& v) {
out << "[";
for(int i=0;i<v.size();i++) {
out << (i!=0?" ":"") << "[";
for(int j=0;j<v[i].size();j++) {
out << setw(5) << _s(v[i][j]) << (j != v[i].size() -1 ? ", " : "");
}
out << "]" << (i != v.size() -1 ? ",\n" : "");
}
return out << "]";
}

template<typename TK, typename TV>
ostream& operator << (ostream& out, const map<TK,TV>& m) {
out << "{";
auto itr=m.begin();
if(itr != m.end()) {
out << *itr;
for(itr++;itr!=m.end();itr++) {
out << ", " << *itr;
}
}
return out << "}";
}

template<typename TK, typename TV>
ostream& operator << (ostream& out, const multimap<TK,TV>& m) {
out << "{";
auto itr=m.begin();
if(itr != m.end()) {
out << *itr;
for(itr++;itr!=m.end();itr++) {
out << ", " << *itr;
}
}
return out << "}";
}

template<typename TK, typename TV>
ostream& operator << (ostream& out, const unordered_map<TK,TV>& m) {
out << "{";
auto itr=m.begin();
if(itr != m.end()) {
out << *itr;
for(itr++;itr!=m.end();itr++) {
out << ", " << *itr;
}
}
return out << "}";
}

template<typename T>
ostream& operator << (ostream& out, const set<T>& m) {
out << "{";
auto itr=m.begin();
if(itr != m.end()) {
out << *itr;
for(itr++;itr!=m.end();itr++) {
out << ", " << *itr;
}
}
return out << "}";
}

template<typename T>
ostream& operator << (ostream& out, const multiset<T>& m) {
out << "{";
auto itr=m.begin();
if(itr != m.end()) {
out << *itr;
for(itr++;itr!=m.end();itr++) {
out << ", " << *itr;
}
}
return out << "}";
}

template<typename T>
ostream& operator << (ostream& out, const unordered_set<T>& m) {
out << "{";
auto itr=m.begin();
if(itr != m.end()) {
out << *itr;
for(itr++;itr!=m.end();itr++) {
out << ", " << *itr;
}
}
return out << "}";
}

template <size_t N>
struct PrintHelper;

template <>
struct PrintHelper<1>
{
template<typename... Args>
static void recursive_print(ostream& out, const tuple<Args...> t)
{
out << "(" << std::get<0>(t) << ", ";
}
};

template <size_t N>
struct PrintHelper
{
template<typename... Args>
static void recursive_print(ostream& out, const tuple<Args...> t)
{
PrintHelper<N - 1>::recursive_print(out, t);
out << std::get<N - 1>(t) << ", ";
}

template<typename... Args>
static void print(ostream& out, const tuple<Args...> t)
{
PrintHelper<N - 1>::recursive_print(out, t);
out << std::get<N - 1>(t) << ")";
}
};

template <typename... Args>
ostream& operator << (ostream& out, const tuple<Args...> t)
{
PrintHelper<tuple_size<decltype(t)>::value >::print(out, t);
return out;
}

template<typename T>
struct _cast_helper
{
T operator() (stringstream& ss) {
T result;
ss >> result;
return result;
}
};

template<>
struct _cast_helper<string>
{
string operator() (stringstream& ss) {
return ss.str();
}
};

template<typename R, typename T>
R _sstream_cast(const T& o) {
stringstream ss;
ss << o;
return _cast_helper<R>()(ss);
}

template<typename T>
vector<T> split(const string& s, const string& delim, const string& stripchars="", bool drop_empty=false) {
vector<T> result;
int b = 0;
int e = 0;
int i = b;
int state = 0;
do {
bool isspace = (stripchars.find(s[i]) != -1);
bool isdelim = (s[i]=='\0' || delim.find(s[i]) != -1);
if(isdelim) {
if(e != b || !drop_empty) {
result.emplace_back(_sstream_cast<T>(string(&s[b], &s[e])));
}
state = 0;
e = b = i + 1;
} else if(isspace) {
if(state == 0) {
e = b = i + 1;
}
} else {
state = 1;
e = i + 1;
}
if(s[i]=='\0') break;
i++;
} while(true);
return result;
}

typedef vector<int> vi;
typedef vector<vi> vvi;
vi make_vi(const string& s) {return split<int>(s, ",", "[] ", true);}
vvi make_vvi(const string& s) { vvi r; for(auto e : split<string>(s, "[]", ", ", true)) r.emplace_back(make_vi(e)); return r;}

/**
* Definition for singly-linked list.
*/
struct ListNode {
int val;
ListNode *next;
ListNode() : val(0), next(nullptr) {}
ListNode(int x) : val(x), next(nullptr) {}
ListNode(int x, ListNode *next) : val(x), next(next) {}
};

/**
* Definition for a binary tree node.
*/
struct TreeNode {
int val;
TreeNode *left;
TreeNode *right;
TreeNode(int x) : val(x), left(NULL), right(NULL) {}
};


#define cout cout

int _main_();
int main() {
std::thread th1([](){this_thread::sleep_for(chrono::seconds(3));cerr << "!!!Timeout!!!" << endl;exit(1);});
th1.detach();
return _main_();
}

今天在B站看到一个视频《斐波那契数列,全网最优解》,UP主给出了求解斐波那契数列通项公式的推导思路。

因为我早年也对这种数列有过研究,而且记得一个更简单的解法,所以记录一下。

斐波那契数列是这样一种数列:

a(1) = a(2) = 1

a(n) = a(n-1) + a(n-2), n>=2

上面是通过递推公式的形式给出的定义,我们注意到递推公式是前两项的线性组合。而线性变换可以通过矩阵表示,我们不妨转换思路来求向量

(an,an+1)(a_n, a_{n+1})

的通项公式

我们写出根据a(n-1), a(n)推得a(n), a(n+1)的递推公式

{an=anan+1=an1+an\begin{cases} a_n = a_n \\ a_{n+1} = a_{n-1} + a_n \end{cases}

写成矩阵的形式

(anan+1)=(an1an)[0111]\begin{pmatrix}a_n & a_{n+1} \\ \end{pmatrix} = \begin{pmatrix}a_{n-1} & a_n \\ \end{pmatrix} \begin{bmatrix}0 & 1 \\ 1 & 1\\ \end{bmatrix}

我们可以看到这就类似等比数列的递推公式,只不过公比q是个矩阵,等比数列通项公式是

an=a1qn1a_n = a_1 * q^{n-1}

类比得到,上面递推公式的通项公式

(anan+1)=(a1a2)[0111]n1=(11)[0111]n1\begin{pmatrix}a_n & a_{n+1} \\ \end{pmatrix} = \begin{pmatrix}a_1 & a_2 \\ \end{pmatrix} \begin{bmatrix}0 & 1 \\ 1 & 1\\ \end{bmatrix}^{n-1} = \begin{pmatrix}1 & 1 \\ \end{pmatrix} \begin{bmatrix}0 & 1 \\ 1 & 1\\ \end{bmatrix}^{n-1}

BTW: 其实这里矩阵和数还是有点区别的,要利用矩阵乘法有结合律(本来是先做向量和矩阵乘法的,通项公式是先做了后面的矩阵乘法最后再让向量左乘矩阵),而且是方阵才能求幂,而这里都是满足的。

对于斐波那契数列的变形也特别容易推导,无论是改变首项还是改变递推关系,包括把两项和变成前n项的线性组合,只要还是线性的,就可以这么推导。

之前用在SQLAlchemy的ORM模型的类名(驼峰风格)和数据库表名(下划线风格)的转换。

Python类名驼峰风格这个不用解释,数据库表名使用下划线风格主要是因为一些数据库系统如果使用了带大写字母的表名,那么在select、insert、update、delete语句中都要用特殊分割符包住表名才能使用,很麻烦。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 驼峰转下划线
def hump_to_underline(name, drop_first_underline=True):
result = re.sub(r'([A-Z])', r'_\1', name).lower()
if drop_first_underline and result[0] == '_':
result = result[1:]
return result

# 下划线转驼峰
def underline_to_hump(name, capitalize_first_letter=True):
ret = ""
i = 0
while i < len(name):
if name[i] == '_' and i+1 < len(name) and name[i+1] != '_':
i += 1
ret += name[i].upper()
else:
ret += name[i]
i += 1
if capitalize_first_letter:
return ret[0].upper()+ret[1:]
else:
return ret

基于SQLAlchemy的Upserter,当时是基于SQLAlchemy写的,不过最后似乎没怎么用到SQLAlchemy的特性,只是取了一下数据库的类型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
# -*- coding: utf-8 -*-

import math
import datetime
import numpy as np
import pandas as pd
import sqlalchemy
import decimal
from sqlalchemy import text
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager

'''
engine: SQLAlchemy Engine
buffer_size: 缓存条目数,当缓存满时自动flush
update_on_duplicate: 当唯一键重复时的行为,默认是update,设置为False表示不更新,即忽略插入失败。
'''
def create_upserter(engine, buffer_size=5000, update_on_duplicate=True, dummy=False):
if dummy:
return DummyUpserter(engine)
if engine.dialect.name.lower().find("mysql") != -1:
return MySQLUpserter(engine, buffer_size, update_on_duplicate)
elif engine.dialect.name.lower().find("postgresql") != -1:
return PSQLUpserter(engine, buffer_size, update_on_duplicate)
else:
print(f"没有为{engine.dialect.name}实现特殊的Upsert,使用默认版本,请确认可以正常工作,建议特化一个专门版本")
return UpserterBase(engine, buffer_size, update_on_duplicate)


def is_duplicate_key(e):
for T in UpserterBase.__subclasses__():
if T.is_duplicate_key(e):
return True
return UpserterBase.is_duplicate_key(e)


'''
class Upserter:

# 传入的engine类型应该和使用的Upserter支持的数据库类型相匹配
# buffer_size表示插入或更新数据缓存到多少才flush(即向数据库插入或更新),None表示在析构时flush,0表示不缓存
# update_on_duplicate当唯一键重复时的行为,默认是update,设置为False表示不更新,即忽略插入失败。
def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
pass

# tablename为数据库表名
# pk为主键的元组,可以不是真正的表主键,但是可以用来判重决定insert还是update,例如('exchange_id', 'trade_id')
# data为单条数据,dict的形式,例如{'exchange_id': 'DCE', 'trade_id': ' 1', 'price': 1.2, 'volume': 1}
def upsert(self, tablename, pk, data):
pass

# 立即把缓冲器的数据推到数据库,会在buffer_size满了或者析构时自动调用,也可以手动调用
def flush(self):
pass
'''
class UpserterBase:
def __init__(self, engine, buffer_size=None, update_on_duplicate=True, field_quote_mark=''):
self.engine = engine
self.session_maker = sessionmaker(expire_on_commit=False)
self.session_maker.configure(bind=engine)
self.tablename2pk = {}
self.tablename2datas = {}
self.buffer_size = buffer_size
self.update_on_duplicate = update_on_duplicate
self.field_quote_mark = field_quote_mark

def __del__(self):
self.flush()

@contextmanager
def session_scope(self):
session = self.session_maker()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()

def flush(self):
for (tablename, pk) in self.tablename2pk.items():
datas = self.tablename2datas[tablename]
if len(datas) > 0:
with self.session_scope() as session:
self._flush(session, tablename, pk, datas)
self.tablename2datas[tablename] = []

def _gen_batch_insert_sql(self, tablename, datas):
quote_mark = self.field_quote_mark
columns = datas[0].keys()
# ex: INSERT INTO tablename(`field1`, `field2`, `field3`) VALUES
sql = f"""INSERT INTO {tablename}({quote_mark}{f"{quote_mark}, {quote_mark}".join(columns)}{quote_mark}) VALUES\n"""
for i, data in enumerate(datas):
if i != len(datas) - 1:
sql += f""" ({self._format_values(data.values())}),\n"""
else:
sql += f""" ({self._format_values(data.values())});\n"""
return sql

def _flush(self, session, tablename, pk, datas):
sql = self._gen_batch_insert_sql(tablename, datas)
try:
session.execute(text(sql))
except sqlalchemy.exc.IntegrityError as e:
if self.is_duplicate_key(e):
# 插入遇到重复KEY
if len(datas) <= 500:
for data in datas:
self.upsert_one(session, tablename, pk, data)
else:
l = len(datas)
p = int(l // 2)
self._flush(session, tablename, pk, datas[:p])
self._flush(session, tablename, pk, datas[p:])
else:
raise e

def upsert_one(self, session, tablename, pk, data):
quote_mark = self.field_quote_mark
r = None
if self.update_on_duplicate:
update_str = self._format_update_values(pk, data)
if self.update_on_duplicate and update_str.strip():
r = session.execute(text(f"UPDATE {tablename} SET {update_str} WHERE {self._format_update_conditions(pk, data)}"))
if not r or r.rowcount == 0:
try:
r = session.execute(text(
f"INSERT INTO {tablename}({quote_mark}{f'{quote_mark}, {quote_mark}'.join(data.keys())}{quote_mark}) VALUES ({self._format_values(data.values())})"
))
except sqlalchemy.exc.IntegrityError as e:
if self.is_duplicate_key(e):
pass
else:
raise e

@staticmethod
def is_duplicate_key(e):
if type(e) != sqlalchemy.exc.IntegrityError:
return False
return (str(e.orig).lower().find("duplicate") != -1)

def _isinf(self, x):
return x>=9223372036854775807 or x<=-9223372036854775808

def _format_value(self, v):
if v is None:
return "null"
elif type(v) == float or type(v) == decimal.Decimal:
if math.isnan(v) or math.isinf(v) or self._isinf(v):
return "null"
else:
return f"{v}"
elif type(v) == int:
if self._isinf(v):
return "null"
else:
return f"{v}"
elif type(v) == datetime.datetime:
return "'"+v.strftime("%Y-%m-%d %H:%M:%S")+"'"
elif type(v) == datetime.date:
return "'"+v.strftime("%Y-%m-%d")+"'"
elif type(v) == pd.Timestamp:
return "'"+v.strftime("%Y-%m-%d %H:%M:%S")+"'"
elif type(v) == str:
return repr(v).replace(r'%',r"%%") # 可以解决字符串包含'"%
else:
return repr(f'{v}').replace(r'%',r"%%")

def _format_values(self, data):
s = ''
for i, e in enumerate(data):
s += self._format_value(e)
s += ', '
return s[:-2]

def _format_update_values(self, pk, data):
s = ''
for i, (k, v) in enumerate(data.items()):
if k not in pk:
s += f"{self.field_quote_mark}{k}{self.field_quote_mark}={self._format_value(v)}, "
return s[:-2]

def _format_update_conditions(self, pk, data):
s = ''
for i, (k, v) in enumerate(data.items()):
if k in pk:
s += f"{self.field_quote_mark}{k}{self.field_quote_mark}={self._format_value(v)} and "
return s[:-4]

def upsert(self, tablename, pk, data):
if self.buffer_size is not None and self.buffer_size == 0:
with self.session_scope() as session:
self.upsert_one(session, tablename, pk, data)
else:
if pk:
self.tablename2pk[tablename] = pk
if tablename not in self.tablename2datas:
self.tablename2datas[tablename] = []
self.tablename2datas[tablename].append(data)

if self.buffer_size is not None and len(self.tablename2datas[tablename]) >= self.buffer_size:
with self.session_scope() as session:
self._flush(session, tablename, self.tablename2pk[tablename], self.tablename2datas[tablename])
self.tablename2datas[tablename] = []

def upsert_dataframe(self, tablename, pk, df):
for index, row in df.iterrows():
self.upsert(tablename, pk, row.to_dict())


class MySQLUpserter(UpserterBase):
def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
super().__init__(engine, buffer_size, update_on_duplicate, field_quote_mark='`')

def __del__(self):
super().__del__()

@staticmethod
def is_duplicate_key(e):
if type(e) != sqlalchemy.exc.IntegrityError:
return False
if len(e.orig.args) > 1 and str(e.orig.args[1]).startswith("Duplicate entry"):
return True
return False

def upsert_one(self, session, tablename, pk, data):
if self.update_on_duplicate:
update_str = self._format_update_values(pk, data)
if self.update_on_duplicate and update_str.strip():
sql = f"""INSERT INTO {tablename}(`{"`, `".join(data.keys())}`) VALUES
({self._format_values(data.values())})
ON DUPLICATE KEY UPDATE {update_str}\n"""
else:
sql = f"""INSERT INTO {tablename}(`{"`, `".join(data.keys())}`) VALUES
({self._format_values(data.values())})
ON DUPLICATE KEY UPDATE `{pk[0]}`=VALUES(`{pk[0]}`)\n"""
session.execute(text(sql))


class PSQLUpserter(UpserterBase):
def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
super().__init__(engine, buffer_size, update_on_duplicate, field_quote_mark='"')

def __del__(self):
super().__del__()

@staticmethod
def is_duplicate_key(e):
if type(e) != sqlalchemy.exc.IntegrityError:
return False
if str(e.orig).startswith("duplicate key"):
return True
return False

def upsert_one(self, session, tablename, pk, data):
if self.update_on_duplicate:
update_str = self._format_update_values(pk, data)
if self.update_on_duplicate and update_str.strip():
sql = f"""INSERT INTO {tablename}("{'", "'.join(data.keys())}") VALUES
({self._format_values(data.values())})
on conflict ("{'", "'.join(pk)}")
do update set {update_str}\n"""
else:
sql = f"""INSERT INTO {tablename}("{'", "'.join(data.keys())}") VALUES
({self._format_values(data.values())})
on conflict ("{'", "'.join(pk)}")
do nothing\n"""
session.execute(text(sql))


class DummyUpserter(UpserterBase):
def __init__(self, engine):
super().__init__(engine)

def upsert(self, tablename, pk, data):
pass

def upsert_one(self, session, tablename, pk, data):
pass

def upsert_dataframe(self, tablename, pk, df):
pass

std::priority_queue

std::priority_queue 是C++标准库提供的优先队列(最大堆)实现,位于头文件

默认情况下要求元素有“小于”运算,取堆顶,返回最大值。

可以通过模板参数调整排序方式让其返回最小值,或者为自定义类型定义排序方式。

1
2
template <class T, class Container = vector<T>, class Compare = less<typename Container::value_type> >
class priority_queue;

模板参数:

T是数据类型

Container是维护最大(小)堆使用的容器类型

Compare是一个function object的类型,定义了排序方式

什么是function object?

function object是一种对象,这个对象的类重载了括号运算符,也就是 operator() ,所以这个对象可以使用 obj(…),看上去就像在调用一个function一样。

使用比较器类定义优先队列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#include <queue>
#include <vector>
using namespace std;

struct Student { // 学生
int id; // 学号
int height; // 身高
};

int main()
{
// 定义比较器
struct Cmp {
bool operator() (const Node& a, const Node& b) {
return a.height < b.height;
}
};
// 定义优先队列
priority_queue<Student, vector<Student>, Cmp> PQ;
}

使用lamda表达式排序

1
2
3
4
5
6
7
vector<Student> students = {...};
std::sort(students.begin(), students.end(), [](const Student& a, const Student& b){
return a.height < b.height;
});

// 等价的
std::sort(students.begin(), students.end(), Cmp());

使用lambda表达式的好处是让“比较方法的描述”接近sort的调用,无论从编写还是阅读都是更好的。

使用lambda表达式的坏处是,不方便复用比较方法。

使用lamda表达式定义优先队列

实际上priority_queue有一个构造函数,可以传递一个比较对象,如果不传递就会用模板参数定义默认的比较对象。

1
explicit priority_queue (const Compare& comp = Compare(), Container&& ctnr = Container());

我们可以通过构造函数参数传递一个lambda表达式定义比较方式,我们期望的定义优先队列的方式是

1
2
3
priority_queue<Node> PQ([](const Node& a, const Node& b) {
return a.height < b.height;
});

但是很遗憾,我们并不能这样定义,这会导致编译错误,原因是我们在模板参数仅传递了数据类型T,而没有传递Compare,因此Compare使用了默认的less,而我们传递的lambda表达式显然不是less类型,因此不符合构造函数的参数要求。

使用decltype获取lambda表达式类型

因此我们不得不传递Compare为我们定义的lambda表达式的类型,这里可以使用 decltype 关键字,这个关键字直到C++11才被引入。

1
2
3
4
5
// 通过lambda表达式定义序
auto cmp = [](const Node& a, const Node& b) {
return a.height < b.height;
};
priority_queue<Node, vector<Node>, decltype(cmp)> PQ(cmp);

看上去和通过定义比较器定义优先队列似乎差不多,实际上lambda表达式的魅力在于可以访问当前上下文中的其他变量。

例如:假设我们有一个 vector<Student>存储着学生信息,我们想定义一个存储学号的优先队列priority_queue,依然按照身高对其中学号排序

1
2
3
4
5
6
7
8
9
10
11
vector<Student> students = {...};
unordered_map<int, Student> id2stu;
for(auto& stu: students) {
id2stu[stu.id] = stu;
}

// 我们可以很方便地把id2stu绑定到lambda表达式中用来排序
auto cmp = [&](int a, int b) {
return id2stu[a].height < id2stu[b].height;
};
priority_queue<int, vector<int>, decltype(cmp)> PQ(cmp);

如果用定义比较器类的方式则需要通过构造函数传递id2stu的引用,然后绑定给成员变量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
vector<Student> students = {...};
unordered_map<int, Student> id2stu;
for(auto& stu: students) {
id2stu[stu.id] = stu;
}

struct Cmp {
const unordered_map<int, Student>& id2stu;
Cmp(const unordered_map<int, Student>& id2stu) : id2stu(id2stu) {}
bool operator() (int a, int b) const {
return id2stu[a].height < id2stu[b].height;
}
};
priority_queue<int, vector<int>, Cmp> PQ;

本质是一样的,但是写法有些累赘。

问题:在字符串s中查找字符串p首次出现的位置。

正常情况下对s和p进行匹配的最坏时间复杂度是O(len(s)*len(p)),我们用i,j分别从s,p的头部进行匹配,每次匹配失败我们回退j到0,i+=1,进行下一轮匹配。

1
2
3
4
5
6
7
8
9
10
11
int find(const string& s, const string& p) {
for(int i=0;i<=s.size()-p.size();i++) {
int j=0; // 每次j都从0开始
for(;j<p.size();j++) {
if(s[i+j] != p[j]) break;
}
if(j == p.size())return i;
// 如果当前以i开始的字串不匹配,则从i+1继续尝试
}
return -1;
}

KMP的思想就是预处理p得到next数组,保证i不回退,next就是预先算出i不回退的情况下j应该回退到哪,这样算法复杂度就降到了O(len(s)+len(p)) 也就是 O(len(s))。

有时候模式串是固定的,需要重复在不同的串中查找模式串,所以next数组也可以预先算好一直复用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// 计算next数组,即j不匹配时的回退位置
vector<int> compute_next(const string& p){
vector<int> next(p.size(), 0);
for(int i=2; i<p.size(); i++){
int j = next[i-1];
while(p[i-1] != p[j] && j>0)
j = next[j];
if(p[i-1]==p[j])
next[i] = j+1;
else
next[i] = 0;
}
return move(next);
}

//串匹配的KMP算法
//返回s中第一个与p匹配的子串的起始下标,若找不到则返回-1
int find(const string& s, const string& p, const vector<int>* pnext=NULL) {
vector<int> _next;
if(pnext == NULL) {
_next = compute_next(p);
pnext = &_next;
}
const vector<int>& next(*pnext);
int i = 0, j = 0;
while(s[i] && p[j]){
if(s[i] == p[j]){
++i;
++j;
}
else {
if(j==0) ++i; // 第一个字符就不匹配,直接后移i
else if(p[j]==0) break; // p[j] 表示找到了匹配,跳出循环
else j = next[j]; // 这里就是利用预处理好的next来回退j,而i不用变
}
}
if(p[j]==0) return i-j;
else return -1;
}

有了next数组后的匹配就想前面说的,只要根据next进行回退就可以了,没有过多技巧。

那么主要讲一下next数组的生成思路,根据next的定义,其实next[i]表示的是p[i]前面最长能有多少字符和p的开头匹配

例如:我们生成 "aabaaab"的next数组,考察next[4]和next[5],p[4]的前面最长有“a”和p的开头匹配,所以next[4]=1,

p[5]的前面最长有“aa”和p的开头匹配,所以next[5]=2。

1
2
3
4
5
0123456
aabaaab
- -↑
aabaaab
-- --↑

总有next[0] = next[1] = 0,我们只要从下标2开始计算next。

对于next[i],我们可以采用数学归纳法的思维,我们找到i-1回退的位置,取j = next[i-1],如果p[i-1]==p[j],那么显然next[i] = next[i-1] + 1,如果p[i-1]!=p[j]呢,next[i]=0吗?并不是

我们还是以 aabaaab 为例,考察next[6],首先我们算出了next[0…5]=[0, 0, 1, 0, 1, 2],而p[5] !=p[2] (‘a’ != ‘b’)

1
2
3
4
5
0123456
aabaaab
-- --↑
aabaaab
-- --↑

这里就有个技巧了,对于p[i-1]和p[j]不匹配时,我们想知道让j回退多少,我们可以利用next数组的含义,尝试让j回退到next[j],再看看p[i-1]和p[j]是否相等,我们在生成next的时候就用到了规模更小的next,还是数学归纳法的思维,j=next[5]=2, 因为p[5]!=p[2] ,令j=next[j]=next[2]=1,而p[5]==p[1],所以next[6] = next[1] + 1 = 2,大致的理解思路就是这样,严格的证明见:前缀函数与 KMP 算法 - OI Wiki (oi-wiki.org)

实现比较简单直接看代码,说几点:

  1. (i%MAX_SIZE+MAX_SIZE)%MAX_SIZE 是为了支持负数下标,如果不需要负下标可以直接i%MAX_SIZE
  2. 循环队列中,因为begin、end一直增加,所以不需要full标志仍然可以把空间用足,不存在队满和队空条件相同。
  3. 如果限定MAX_SIZE是2的幂,可以用 i&(MAX_SIZE-1) 来代替取模,而且同样支持负数下标,真是又快又好 _
  4. CircularArray的主要用途是在DP循环的时候如果状态仅依赖前N项,那么可以简单地把空间节省到N。
  5. Queue的用途是为了替代std::queue,但实际上std::queue已经相当快了。

总的来说这两个数据结构在比赛中基本用不到。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
template<typename T, int MAX_SIZE>
class CircularArray {
T data[MAX_SIZE];
public:
CircularArray(const T& t=T()) {fill(data, data+MAX_SIZE, t);}
T& operator [] (int i) {
return data[(i%MAX_SIZE+MAX_SIZE)%MAX_SIZE];
}
};

template<typename T, int MAX_SIZE>
class Queue {
CircularArray<T, MAX_SIZE> data;
int begin = 0;
int end = 0;
public:
void push(const T& t) {
assert(!full());
data[end++] = t;
}

T pop() {
assert(!empty());
return data[begin++];
}

T& peek() {
assert(!empty());
return data[begin];
}

const T& peek() const {
return const_cast<Queue*>(this)->peek();
}

bool empty() const {
return end == begin;
}

size_t size() const {
return end-begin;
}

bool full() const {
return size() == MAX_SIZE;
}
};

什么是树状数组?

树状数组又称二叉索引树(Binary Indexed Tree),又以其发明者命名为Fenwick树

是一种支持以O(logn)时间计算区间和同时以O(logn)时间修改元素值的数据结构。

它的功能可以被线段树替代,而且线段树提供了更多功能,树状数组的优势是实现简单。

树状数组的的实现

树状数组提供两种操作:

1)对单点赋值。

2)查询区间和。

(事实上也可以扩展出区间修改和单点查询,我们暂不考虑)

假设我们有一个数组

1
arr = {3, 8, 3, 3, 5, 6, 8, 7};

通常情况比如我们求区间[2,7)的和需要遍历区间上的元素,O(区间长度),如何减少运算次数呢,最简单的思路是我们预处理前缀和sum[],令

1
sum[i]= arr[0]+arr[1]+...+arr[i-1]

当我们需要算[2,7)的区间和,其实就算[0,7)前缀和减去[0,2)的前缀和,即sum[7]-sum[2],这样就可以用O(1)的时间算出任意区间的和,如果数组元素不会发生动态变化这样是可以的,但如果需要交替修改数组元素和查询区间和,这样处理会导致前缀和维护的成本很高,原本的arr[i]=x,我们不得不修改所有k>i的sum[k]来维护前缀和,这样修改数组元素的时间就从原本的O(1)变成了O(n)。

有没有方法可以在修改单点值的便利性和查询区间和的便利性上做个折中呢,肯定是有的,我们可以预处理一些子段和而不是所有前缀和,让修改单点值和查询区间和都只需要访问O(logn)的元素,树状数组线段树都是类似这个思想。

这是在讲线段树时的图,如果我们只考虑前缀和,即从0开始的区间和,这中间很多子段和的存储是不必要的,我们来看[0, 7)的和,如果我们有了下图维护的子段和信息,[0, 7)的和最快可以通过 [0, 4)的和 + [4, 6)的和 + [6,7)的和 = 17 + 11 + 8 计算得到。

segtree

我们看图时是很容易想到的,那么这个[0, 4)、[4, 6)、[6,7)划分是怎么得出来的呢,我们可以把前缀和[0, 7)的右端点7转换成2进制,即111,如果仅保留最左侧的1其他位置0,得到100,就是十进制的4,保留最左侧的两个1,得到110,就是十进制的6,保留最最侧的三个1,得到111,就是十进制的7。4,6,7 正好和我们的划分是一样的。7并不是特殊的,可以选择其他数字也都有这个规律。那么给定一个i,我们就可以通过不断把最右侧1变成0,记录这个过程中所有的数,就可以得到需要用到的子段和的划分。具体操作方式我们可以通过位运算。

1
2
3
4
// 得到n二进制最右侧的1表示的数
int lowbit(int n) {
return n & (-n);
}

还有一些等价的写法

1
n & ~(n-1) 或 n ^ (n & (n-1))

划分子段和的方式有了,刚才提到如果我们只计算前缀和,线段树维护的这些子段很多是多余的,再结合我们的划分方式,其实只要把子段和存储在子段的右端点即可。需要用到的子段右端点是不会重复的,因为任何一个右端点i 对应唯一子段就是 [k , i),其中k是i的二进制去掉最右的1。也就是下面这样,虚线上方圆圈内的值就是存储在右端点的子段和。

bitree

那么我们就可以仅用一个数组来存储上面的这棵树了。

1
2
bitree[] = {0, 3, 11, 3, 17, 5, 11, 8, 43};
0 1 2 3 4 5 6 7 8

实际上这里 len(bitree) = len(arr) + 1,不过bitree[0] 代表空的子段和,总是0,想省下多出来的1个单位空间也是可以的,不过没有必要。

所以求任意的区间和,我们先转成前缀和相减,再划分为子段和去bitree[]里取值就可以了

例如:

1
2
3
4
5
  [2, 7)的和
= [0, 7)的和 - [0, 2)的和
= (bitree[4] + bitree[6] + bitree[7]) - bitree[2]
= (17 + 11 + 8) - 11
= 25

用代码实现就是

1
2
3
4
5
6
7
8
9
// 查询前缀和
int query(int i) {
int result = 0;
while(i) {
result += bitree[i];
i -= lowbit(i); // 等价的 i &= i-1
}
return result;
}

剩下的问题是如何修改元素值呢?我们看上面的树状图,当要修改元素4的时候,会影响到[4,5)、[4,6),[0,8) 三个区间,即一直要沿着父节点修改到根,那么就是我们用bitree的下标表示就是5-6-8,如何得到这一串数呢,是否也和二进制存在某种关系呢,直接说答案,从要修改的元素编号+1开始,每次令i+=lowbit(i) 得到下个序号(正好是把查询里的减法变成加法),直到下标超出bitree的长度,事实上最后一次的下标总是bitree的根,也就是最后一个的元素。

1
2
3
4
5
6
7
8
// 修改元素i的值
void add(int i, int x) {
++i;
while(i<=8) {
bitree[i] += x;
i += lowbit(i);
}
}

这里实现的是add方法,如果我们想设置元素i为新值,可以

1
2
add(i, -query(i,i+1)); // 减去旧值
add(i, x); // 加上新值

完整实现

和线段树比一下是不是简单很多

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
template <typename T, int N=(1<<17)>
class BITree{
T C[N+1];
int lowbit(int i){return i&(-i);}
public:
BITree(){clear();}

void clear(){memset(C,0,sizeof(C));}

// 增量修改元素i的值
void add(int i, T d){
for(i++; i<=N; i+=lowbit(i)) C[i]+=d;
}

// [0, i)元素的和
T sum(int i){
T r=0;
for(; i; i-=lowbit(i)) r+=C[i];
return r;
}

// [b, e)元素的和
T sum(int b, int e) {
return sum(e) - sum(b);
}

// 修改元素i的值
void set(int i, T d) {
add(i, -sum(i, i+1));
add(i, d);
}
};