【Python】AtCoder abc125_c GCD on Blackboard

atcoder.jp

さっぱりわからなかったためググる

drken1215.hatenablog.com

とりあえず問題の言い換えとして N 個の値の中から 1 個だけ取り除いた N−1 個の整数の最大公約数の最大値を求めよ という問題だと思うことができる。

あー、なるほど。 「愚直に実装しても計算量的にNGだよ」とも書いてあるがとりあえず実装してみた。

まずテストコードを書く。

import unittest
import main


class MainTest(unittest.TestCase):
    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_1(self):
        self.assertEqual(2, main.main(3, [7, 6, 8]))

    def test_2(self):
        self.assertEqual(6, main.main(3, [12, 15, 18]))


if __name__ == "__main__":
    unittest.main()

続いて実装。

from fractions import gcd
from itertools import combinations
from functools import reduce


def calc_gcd(nums):
    return reduce(gcd, nums)


def main(input_1, input_2):
    nums_list = list(map(list, list(combinations(input_2, len(input_2)-1))))

    gcd_list = []
    for nums in nums_list:
        calced = calc_gcd(nums)
        gcd_list.append(calced)

    return max(gcd_list)


if __name__ == "__main__":
    input_1 = input()
    input_2 = list(map(int, input().split()))
    print(main(input_1, input_2))

Pythonの文法で下記を参照した。

note.nkmk.me

note.nkmk.me

note.nkmk.me

note.nkmk.me

gcd 関数はPython3.4以前では fractions に属するということに注意(AtCoderPythonは3.4系)。

案の定、このままだと実行時間オーバーなので続きを読んでみる。

これは以下のように累積和ならぬ累積 GCD を前処理で求めておくと、高速にわかるのだ。

このあたりから理解が怪しくなる。

まず累積和がわからないのでリンク先を読む。

qiita.com

累積和についてはなんとなく理解した。

しかしまだ累積GCDの場合のイメージができない……

しばらく考えたが現状の実力では咀嚼できなかったため参考元のC++実装をPythonに書き換えてACすれば良しとした。

from fractions import gcd


def main(input_1, input_2):
    left = [0] * (input_1+1)
    for i in range(input_1):
        left[i+1] = gcd(left[i], input_2[i])

    right = [0] * (input_1 + 1)
    i = input_1 - 1
    while i >= 0:
        right[i] = gcd(right[i+1], input_2[i])
        i -= 1

    while i >= 0:
        right[i] = gcd(right[i+1], input_2[i])

    res = 0
    for i in range(input_1):
        l = left[i]
        r = right[i+1]
        gcd_val = gcd(l, r)
        if gcd_val > res:
            res = gcd_val

    return res


if __name__ == "__main__":
    input_1 = int(input())
    input_2 = list(map(int, input().split()))
    print(main(input_1, input_2))