Google code jam 2008 round 1 Bの問題A. crop trianglesに挑戦する.
$x$-$y$平面上の点のうち,$x$座標も$y$座標も整数値となっている点を整数格子点という.
この問題を入力と出力で定義すると,
である.
ただし,この問題では3点が異なっていれば(その3点が一直線上にのっていたとしても)三角形と見なす. 一方で,3点のうち重複があったら三角形ではない.
なお,この問題では,入力の$n$点の座標値が直接与えられるわけではなく,点の座標値を$n$個生成するための関数(正確には関数のパラメータ)が与えられる.
予備知識だが,この点の座標値を生成する方法は「線形合同法」とよばれ,擬似乱数を生成する最も単純な方法として用いられるものである.
入力と,それに対する正しい出力の例は,残念ながら陽には与えられていない.
まず,問題のページの説明に従って,入力の点座標を生成する関数を以下にgenarate_pointsとして定義する.
def generate_points(n, a, b, c, d, x_0, y_0, m):
x = x_0
y = y_0
point = [(x, y)]
for i in range(1, n):
x = ((a * x) + b) % m
y = ((c * y) + d) % m
point += [(x, y)]
return point
早速サンプルの1つ目の4点を生成してみる.
point4 = generate_points(4, 10, 7, 1, 2, 0, 1, 20)
point4
[(0, 1), (7, 3), (17, 5), (17, 7)]
座標だけではイメージがわかないかもしれないので,実際に点をプロットして様子を見てみる.
なお,直下のコードは描画用のコードなので気にしないでほしい.
# このコードは描画用
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
plt.figure(figsize=(19/2, 8/2))
plt.axis([-1, 18, 0, 8])
z = plt.plot([0, 7, 17, 0], [1, 3, 5, 1], 'r-')
z = plt.plot([8], [3], 'ro', ms=16)
z = plt.plot(np.array(point4)[:,0], np.array(point4)[:,1], 'o')
plt.grid(True)
z = plt.xticks(np.arange(-1, 19, 1))
上図の青い点が入力の点である.
手計算によると,上記の赤い三角形の重心(赤丸)だけが整数格子点である.
次にサンプルの2つ目の6点を生成してみる.
point6 = generate_points(6, 2, 0, 2, 1, 1, 2, 11)
point6
[(1, 2), (2, 5), (4, 0), (8, 1), (5, 3), (10, 7)]
こちらもプロットしてイメージを掴むことにする.
# このコードは描画用
%matplotlib inline
plt.figure(figsize=(11/2, 9/2))
plt.axis([0, 11, -1, 8])
z = plt.plot([1, 4, 10, 1], [2, 0, 7, 2], 'r-')
z = plt.plot([5], [3], 'ro', ms=16)
z = plt.plot([2, 8, 5, 2], [5, 1, 3, 5], 'r-')
z = plt.plot(np.array(point6)[:,0], np.array(point6)[:,1], 'o')
plt.grid(True)
z = plt.xticks(np.arange(0, 12, 1))
上図の青い点が入力の点である.
手計算によると,今度は2つの三角形(上図の赤い三角形)の重心が整数格子点である. 三角形の1つは,3点が一直線上にのっているため,潰れている. この問題ではこれも三角形とみなす. このサンプルでは,たまたま赤い2つの三角形の重心が同じ点である.
$n$点から三角形の頂点を3つ選ぶ組み合わせは$_n\text{C}_3$通りなので,そのそれぞれに関して重心を計算してその座標値が整数となるものの数を数える.
という方針で出力される答えは正しいはずなので,実際に試してみる.
以下に,関数simple_answerとして定義する.
def is_center_integer(p1, p2, p3):
(x_1, y_1) = p1
(x_2, y_2) = p2
(x_3, y_3) = p3
if (x_1 + x_2 + x_3) % 3 == 0 and (y_1 + y_2 + y_3) % 3 == 0:
return True
else:
return False
def simple_answer(n, a, b, c, d, x_0, y_0, m):
point = generate_points(n, a, b, c, d, x_0, y_0, m) # まず先述の関数generage_pointsを使って,点集合を生成する.
count = 0
for i in range(n - 2):
for j in range(i + 1, n - 1):
for k in range(j + 1, n):
if is_center_integer(point[i], point[j], point[k]) == True:
count += 1
return count
simple_answer(4, 10, 7, 1, 2, 0, 1, 20)
1
simple_answer(6, 2, 0, 2, 1, 1, 2, 11)
2
とりあえず,サンプルの数値例では合っていそうである.
Small datasetでも試してみる.
def all_answer(input_file_name, output_file_name):
input_file = open(input_file_name)
output_file = open(output_file_name, 'w')
N = int(input_file.readline())
for case_number in range(1, N + 1):
n, a, b, c, d, x_0, y_0, m = map(int, input_file.readline().split())
output_file.write(f'Case #{case_number}: {simple_answer(n, a, b, c, d, x_0, y_0, m)}\n')
input_file.close()
output_file.close()
all_answer('A-small-practice.in', 'A-small-practice.out')
良さそうである.
しかし,Large datasetで試してみると,いつまで待っても計算が終わらない.
上記simple_answerおよびall_answerは,正しい計算結果を出力する. しかし,Large datasetを解くには遅すぎる. なお,2018年より前のGoogle code jamのコンテスト本番では,入力ファイルをダウンロード後数分以内に出力ファイルをアップロードしないとtime outとなった. 2018年以降のGoogle code jamではオンラインジャッジが採用されたため,計算時間の評価はより厳しくなった.
実際にどのくらい時間がかかるのか測ってみる. Pythonには計算時間を測る関数がいろいろ用意されているが,ここでは簡便にtimeモジュールのtimeメソッドを使うことにする.
time.time()
でそのコマンドがよばれた瞬間の時刻を(かなり細かい単位で)得られる.
このメソッドを用いて,計算時間も返すように上記simple_answerを再定義する.
import time
def simple_answer(n, a, b, c, d, x_0, y_0, m):
start_time = time.time() # 計算開始時刻をstart_timeに代入する.
point = generate_points(n, a, b, c, d, x_0, y_0, m)
count = 0
for i in range(n - 2):
for j in range(i + 1, n - 1):
for k in range(j + 1, n):
if is_center_integer(point[i], point[j], point[k]) == True:
count += 1
finish_time = time.time() # 計算終了時刻をfinish_timeに代入する.
return count, finish_time - start_time # 解答だけでなく,計算時間も返すようにする.
再度,サンプルデータで計算してみる.
simple_answer(4, 10, 7, 1, 2, 0, 1, 20)
(1, 9.059906005859375e-06)
simple_answer(6, 2, 0, 2, 1, 1, 2, 11)
(2, 3.1948089599609375e-05)
ちなみに,計算時間は秒単位である. 非常に短い時間で計算が完了していることがわかる.
次に,2つ目のサンプルでnだけを10,100,1000と大きくしていって計算時間の推移を見てみる. なお,nが変われば答えも変わる.
simple_answer(10, 2, 0, 2, 1, 1, 2, 11)
(16, 5.602836608886719e-05)
simple_answer(100, 2, 0, 2, 1, 1, 2, 11)
(18100, 0.05512380599975586)
simple_answer(1000, 2, 0, 2, 1, 1, 2, 11)
(18607000, 50.82445311546326)
nが10の場合,100の場合の計算時間は短いが,nが1000の場合には1分くらいかかっている. (この計算時間はあくまでも宮本のパソコンの場合です.)
雰囲気をつかむために,nが100の場合から1000の場合までの計算時間を計測し,表示してみる.
cpu_time = []
for n in range(100, 1001, 100):
sol, t = simple_answer(n, 2, 0, 2, 1, 1, 2, 11)
cpu_time += [(n, t)]
cpu_time
[(100, 0.05331087112426758), (200, 0.38202786445617676), (300, 1.3207190036773682), (400, 3.1812548637390137), (500, 6.216853141784668), (600, 10.762550115585327), (700, 17.645094871520996), (800, 26.440786838531494), (900, 37.66292405128479), (1000, 50.80378079414368)]
# このコードは描画用
%matplotlib inline
plt.title('CPU time of simple method')
plt.xlabel('$n$: # of points')
plt.ylabel('CPU time (seconds)')
data = np.array(cpu_time)
z = plt.plot(data[:,0], data[:,1])
nが増えるに従って,計算時間が急激に増えているのがわかる. nが1000程度ならば1分で終わるが,nがもっと大きくなったらどうなるのだろうか?
ここで,simple_answerを見直してみる. 解答を出すための計算では,forで三重に繰り返している. よって,大雑把に言って,$n^3$に比例した時間がかかりそうである.
実際,nが100のときに比べてnが1000のときはnが10倍になっているわけだが,計算時間は$10^3 = 1000$倍くらいになっている.
Google code jamの問題文によればLarge datasetではnが最大で100000になるようだ. これは,nが1000の場合の何倍かというと,
100000 / 1000
100.0
100倍である.
nが1000の場合の計算時間は約1分なので,大雑把に言って,最悪の場合には$100^3$分くらいかかりそうである. (そしてGoogleは当然その最悪の場合の数値例を用意していそうである.)
$100^3$分は
100 ** 3 / 60
16666.666666666668
時間くらいであり,これは
100 ** 3 / 60 / 24
694.4444444444445
日くらいである. これはおおよそ2年弱である. 計算の終了を待ってはいられない.
というわけで,large datasetも解けるよう,解法を工夫する.
解法の工夫は色々考えられるが,例えば,以下の方針があり得る.
まず,先程のis_center_integerで三角形の重心が整数格子点か否か判定した方法からすぐに分かるが
「$x$座標,$y$座標それぞれに関して,3頂点の座標値を合計して3の倍数になること」
と
「三角形の重心が整数格子点であること」
は同値である.
よって
ということがわかる.
以上より解答は式 $$ \sum_{(i+i'+i'')\%3=0, \ (i+i'+i'')\%3=0} |G(i,j)| \cdot |G(i',j')| \cdot |G(i'',j'')| + \sum_{i=0}^2 \sum_{j=0}^2 {}_{|G(i,j)|}\text{C}_3 $$ で計算できる.
この方針に従って解答を出力する関数answerを以下に定義する. なお,集合$G(i,j)$の個数を格納するためにtuple (i, j)をキーとした辞書を用いる.(ここは別の方法でも良い.)
def answer(n, a, b, c, d, x_0, y_0, m):
start_time = time.time() # 計算開始時刻をstart_timeに代入する.
point = generate_points(n, a, b, c, d, x_0, y_0, m)
'''
ここまでは,先程の関数solve_simplyと同様である.
ここからが,工夫されている.
'''
surplus = {
(0, 0): 0, # キーを(i,j)としてG(i,j)に属する点の数を覚える.点数の初期値は0にしておく.
(0, 1): 0,
(0, 2): 0,
(1, 0): 0,
(1, 1): 0,
(1, 2): 0,
(2, 0): 0,
(2, 1): 0,
(2, 2): 0, # 最後のカンマはいらない.しかし後に要素が増えたときのミスを避けるため,最後にもカンマを付けることを習慣としたい.
}
for p in point: # すべての点に関して以下を繰り返す.
x, y = p # まず,点の座標値をx, yとする.
surplus[(x % 3, y % 3)] += 1 # 該当する集合G(i,j)の個数を1増やす.
'''
ここまでで,G(i,j)に属する点の個数がG[(i, j)]に入っている.
ここからは,余りの値(i, j)ごとに,重心が整数格子点になる組合せを吟味する.
全部ベタに書いても良いが,計算間違いが怖いので,組合せの計算は繰り返し文で行う.
'''
count = 0 # まず,答えの個数を格納する変数の値を0にしておく.
surplus_keys = list(surplus.keys()) # 上で用意したsurplusのkeyをリストとして保存する.
'''
以下ではsurplusのkeyを3つ選ぶやり方のすべてを重複なく吟味するために3重ループを使っている.
'''
for k in range(len(surplus_keys) - 2):
i, j = surplus_keys[k]
for kk in range(k + 1, len(surplus_keys) - 1):
ii, jj = surplus_keys[kk]
for kkk in range(kk + 1, len(surplus_keys)):
iii, jjj = surplus_keys[kkk]
if (i + ii + iii) % 3 == 0 and (j + jj + jjj) % 3 == 0:
count += surplus[(i, j)] * surplus[(ii, jj)] * surplus[(iii, jjj)]
'''
次に,余りが同じ点を3頂点とする三角形の重心は整数格子点になるので,その中で3頂点を選ぶ組合せの数を計算する.
'''
for i, j in surplus_keys:
if surplus[(i, j)] >= 3:
count += surplus[(i, j)] * (surplus[(i, j)] - 1) * (surplus[(i, j)] - 2) // 6 # 3つ選ぶ組合せの数をベタに計算している.
finish_time = time.time() # 計算終了時刻をfinish_timeに代入する.
print(f'n = {n}, CPU time: {finish_time - start_time}') # 計算時間は画面に表示する.
return int(count)
少々長くなったが,関数定義できた. この関数でもforの3重ループがあるが,そのそれぞれはnとは関係なく,3までしか繰り返さないので問題ない.
この関数answerの計算時間は,ざっと見積もって,点の数nに比例しそうである.
まず,サンプルの数値例で試してみる.
answer(4, 10, 7, 1, 2, 0, 1, 20)
n = 4, CPU time: 3.910064697265625e-05
1
answer(6, 2, 0, 2, 1, 1, 2, 11)
n = 6, CPU time: 4.124641418457031e-05
2
解答の数値も計算時間も問題なさそうである. ただし,計算時間はsimple_answerに比べて若干長くなっている.
次に,nが100の場合から1000の場合までの計算時間を測ってみる.
for n in range(100, 1001, 100):
answer(n, 2, 0, 2, 1, 1, 2, 11)
n = 100, CPU time: 8.535385131835938e-05 n = 200, CPU time: 0.0001289844512939453 n = 300, CPU time: 0.00016427040100097656 n = 400, CPU time: 0.00023603439331054688 n = 500, CPU time: 0.00038886070251464844 n = 600, CPU time: 0.0004878044128417969 n = 700, CPU time: 0.00033402442932128906 n = 800, CPU time: 0.0003981590270996094 n = 900, CPU time: 0.0005218982696533203 n = 1000, CPU time: 0.0006368160247802734
nが4や6の場合にはsimple_answerの方が速かったが,nが大きくなるとanswerの方が圧倒的に速い.
もっと大きなnでも計測してみる.
n = 100
while n <= 100000:
answer(n, 2, 0, 2, 1, 1, 2, 11)
n *= 10
n = 100, CPU time: 8.392333984375e-05 n = 1000, CPU time: 0.0004611015319824219 n = 10000, CPU time: 0.006317138671875 n = 100000, CPU time: 0.054682254791259766
Google code jamで想定されている最大のnでも1秒未満で計算が終わりそうである.
最後に,ファイルからデータを読み込んで解答をファイルに書き出す関数all_answerを作ってみる.
def all_answer(input_file_name, output_file_name):
input_file = open(input_file_name)
output_file = open(output_file_name, 'w')
N = int(input_file.readline())
for case_number in range(1, N + 1):
n, a, b, c, d, x_0, y_0, m = map(int, input_file.readline().split())
output_file.write(f'Case #{case_number}: {answer(n, a, b, c, d, x_0, y_0, m)}\n') # この行だけsolve_simplyからsolveに書き換えた.
input_file.close()
output_file.close()
all_answer('A-large-practice.in', 'A-large-practice.out')
n = 30123, CPU time: 0.017490148544311523 n = 50000, CPU time: 0.031896114349365234 n = 20, CPU time: 4.601478576660156e-05 n = 4, CPU time: 4.410743713378906e-05 n = 49999, CPU time: 0.026580810546875 n = 100000, CPU time: 0.05723428726196289 n = 100000, CPU time: 0.05668520927429199 n = 30012, CPU time: 0.01844310760498047 n = 49999, CPU time: 0.03319287300109863 n = 6, CPU time: 4.601478576660156e-05
計算時間は画面に表示され,計算結果はファイルに書き込まれる.
ここまでで,Google code jamのlarge datasetに対しても,(制限時間以内に)正解を出力できるようになった.
しかし,関数answerにはまだまだ改善の余地がある. これ以上書くと長くなるので別のページに改めて書くことにする.