計算チャレンジ

きっかけはふと目にしたツィート。

12歳の女の子がPythonで問題を解いたとか。

ちょっと面白そう。やってみた。



とりあえず、そのまんま単純に組んでみました

foo.py

$ cat foo.py
 #!/usr/bin/env python
 
 import time
 
 mul = lambda v, r=1: r if v == 0 else mul(v // 10, v % 10 * r)
 count = lambda v, c=0: (v, c) if v < 10 else count(mul(v), c+1)
 judge = lambda v, mc, r, c: (c > mc, v, r, max(c, mc))
 calc = lambda v, mc: judge(v, mc, *count(v))
 
 if __name__ == "__main__":
         st = time.time()
         (v, c) = (0, 0)
         while True:
                 (f, v, r, c) = calc(v+1, c)
                 if f:
                         tm = time.time() - st
                         print('{}, {}, {} # {} sec'.format(v, r, c, tm))
 # EOF
       

主要なのは2行

mul = lambda v, r=1: r if v == 0 else mul(v // 10, v % 10 * r)
count = lambda v, c=0: (v, c) if v < 10 else count(mul(v), c+1)

例えば mul(432) で 4*3*2 の結果 24 が返ります。

例えば count(432) で mul(432), mul(24) で 8 が求まり、その回数2回との組み(8, 2) が返ります。

では実行。

$ ./foo.py
10, 0, 1 : 2.09808349609375e-05 sec
25, 0, 2 : 0.0001239776611328125 sec
39, 4, 3 : 0.00017905235290527344 sec
77, 8, 4 : 0.0003139972686767578 sec
679, 6, 5 : 0.002644062042236328 sec
6788, 0, 6 : 0.027772903442382812 sec
68889, 0, 7 : 0.3118159770965576 sec
2677889, 0, 8 : 13.455680131912231 sec
26888999, 0, 9 : 140.47247910499573 sec
  :

え。11回どころか、、、

3778888999, 0, 10 : 21682.518013954163 sec
  :

10回が求まったのが、21682/(60*60) = 6.02277... 6時間後 orz

これは、ちゃんとやらねば。

このままでは、うちのMacのスペックだと、どれだけ時間がかかることやら。

それ以前に、すぐに32ビットの整数の範囲を超えそうだし。


まずはオーバーフロー対策

安易にdecimal.Decimalを使ってみます。

foo2.py

$ cat foo2.py
#!/usr/bin/env python

import time
from decimal import Decimal

def mul(s):
        m = Decimal('1')
        for c in s:
                m *= Decimal(c)
        return str(m)

count = lambda s, c=0: (s, c) if len(s) < 2 else count(mul(s), c+1)
judge = lambda s, mc, r, c: (c > mc, s, r, max(c, mc))
calc = lambda s, mc: judge(s, mc, *count(s))
next_str = lambda s: str( Decimal(s) + 1 )

if __name__ == "__main__":
        st = time.time()
        (s, c) = ('0', 0)
        while True:
                (f, s, r, c) = calc( next_str(s), c )
                if f:
                        tm = time.time() - st
                        print('{}, {}, {} # {} sec'.format(s, r, c, tm))
# EOF

$ ./foo2.py
10, 0, 1 # 0.0008771419525146484 sec
25, 0, 2 # 0.0014989376068115234 sec
39, 4, 3 # 0.0016231536865234375 sec
77, 8, 4 # 0.0019459724426269531 sec
679, 6, 5 # 0.007252931594848633 sec
6788, 0, 6 # 0.060159921646118164 sec
68889, 0, 7 # 0.5654449462890625 sec
2677889, 0, 8 # 22.18914294242859 sec
26888999, 0, 9 # 237.04830479621887 sec

遅くなりました。確実に。

でもこれで、ゆくゆくオーバーフローは気にしなくていいはず...


ちょっと無駄を省けるように、考えてみます

これらを考慮してコーディング。

foo3.py

$ cat foo3.py
#!/usr/bin/env python

import time
from decimal import Decimal

def mul(s):
        m = Decimal('1')
        for c in s:
                m *= Decimal(c)
        return str(m)

dic = {}

def count(s, c=0):
        if len(s) < 2:
                return (s, c)
        if '0' in s:
                return ('0', c+1)
        if '1' in s:
                s = ''.join( filter( lambda c: c!='1', s ) )
                if s == '':
                        return ('1', c+1)
                return count(s, c)

        s = ''.join( sorted(s) )
        if s in dic:
                (s, c2) = dic.get(s)
                return (s, c+c2)

        (s2, c2) = count( mul(s), c+1 )
        if c == 0 and s not in dic:
                dic[s] = (s2, c2)
        return (s2, c2)

judge = lambda s, mc, r, c: (c > mc, s, r, max(c, mc))
calc = lambda s, mc: judge(s, mc, *count(s))

def next_str(s):
        ck = lambda t, s: any( map( lambda c: c in s, t ) )
        def up(i):
                if i < -len(s):
                        return '2'* (len(s) + 1)
                if s[i] != '9':
                        c = str(int(s[i]) + 1)
                        return s[:i] + c * (-i)
                return up(i-1)

        while True:
                s = up(-1)
                if ck('01', s):
                        continue
                if '5' in s and ck('2468', s):
                        continue
                break
        return s

if __name__ == "__main__":
        st = time.time()
        (s, c) = ('0', 0)
        while True:
                (f, s, r, c) = calc( next_str(s), c )
                if f:
                        tm = time.time() - st
                        print('{}, {}, {} # {} sec'.format(s, r, c, tm))
# EOF

count() では、'0'や'1'があるときは特別扱い。

辞書dicをキャッシュにして、ソート後の数値をキーに、その結果を保持しておきます。

次に試す値を用意する next_str() では、単純なインクリメントをやめて、スピードアップ。

各桁の値をソートした結果が同じものは省くので、next_str() が返す値は、昇順にソートされた値だけになるように。

さらに、各桁のどこかに'0'があるものは、1回で終わるのでハイスコアは望めないのでスキップ。

各桁のどこかに'1'があるものは、それ以前に'1'を除いた回でお試し済みなのでスキップ。

各桁のどこかに'5'があって、かつ、どこかに偶数'2','4','6','8'がある場合も、次回に'0'が出現してハイスコアは望めないのでスキップ。

$ ./foo3.py
22, 4, 1 # 0.0004620552062988281 sec
38, 8, 2 # 0.000659942626953125 sec
55, 0, 3 # 0.0007491111755371094 sec
268, 0, 4 # 0.0011289119720458984 sec
679, 6, 5 # 0.0019609928131103516 sec
6788, 0, 6 # 0.005012989044189453 sec
68889, 0, 7 # 0.013716936111450195 sec
2677889, 0, 8 # 0.044028282165527344 sec
26888999, 0, 9 # 0.09941506385803223 sec
3778888999, 0, 10 # 0.40203118324279785 sec
277777788888899, 0, 11 # 3.8815369606018066 sec
  :      

ふう。良かった。

4秒弱で11回。

何とか12歳の女の子のレベルに到達できただろうか?


さらなる12回を目指して...

$ ./foo3.py
22, 4, 1 # 0.0004780292510986328 sec
38, 8, 2 # 0.0006809234619140625 sec
55, 0, 3 # 0.0007700920104980469 sec
268, 0, 4 # 0.001157999038696289 sec
679, 6, 5 # 0.002006053924560547 sec
6788, 0, 6 # 0.005734920501708984 sec
68889, 0, 7 # 0.013610124588012695 sec
2677889, 0, 8 # 0.04772615432739258 sec
26888999, 0, 9 # 0.10118722915649414 sec
3778888999, 0, 10 # 0.39716410636901855 sec
277777788888899, 0, 11 # 3.8843460083007812 sec
Traceback (most recent call last):
  File "./foo3.py", line 61, in <module>
    (f, s, r, c) = calc( next_str(s), c )
  File "./foo3.py", line 36, in <lambda>
    calc = lambda s, mc: judge(s, mc, *count(s))
  File "./foo3.py", line 30, in count
    (s2, c2) = count( mul(s), c+1 )
  File "./foo3.py", line 23, in count
    return count(s, c)
  File "./foo3.py", line 30, in count
    (s2, c2) = count( mul(s), c+1 )
  File "./foo3.py", line 9, in mul
    m *= Decimal(c)
decimal.InvalidOperation: [<class 'decimal.ConversionSyntax'>]

エラー! orz

とりあえず安易にデバッグプリント入れて状況をば。

foo4.py

$ cat foo4.py
#!/usr/bin/env python

import time
from decimal import Decimal

def mul(s):
        m = Decimal('1')
        for c in s:
                try:
                        m *= Decimal(c)
                except:
                        print('err s={} c={} m={}'.format(s, c, m))
        return str(m)
  :

$ ./foo4.py
22, 4, 1 # 8.678436279296875e-05 sec
38, 8, 2 # 0.0002827644348144531 sec
55, 0, 3 # 0.0003719329833984375 sec
268, 0, 4 # 0.0007586479187011719 sec
679, 6, 5 # 0.0016100406646728516 sec
6788, 0, 6 # 0.004712820053100586 sec
68889, 0, 7 # 0.012827873229980469 sec
2677889, 0, 8 # 0.04451584815979004 sec
26888999, 0, 9 # 0.10637378692626953 sec
3778888999, 0, 10 # 0.3972287178039551 sec
277777788888899, 0, 11 # 3.8991498947143555 sec
err s=+.2223333444555566778889999E c=+ m=1
  :

なー。大きくなり過ぎると指数形式に。

まぁ、'0'〜'9'以外をスキップする対策でいけるかな?


泥縄な対応でどうか?

foo5.py

$ cat foo5.py
#!/usr/bin/env python

import time
from decimal import Decimal

def mul(s):
        m = Decimal('1')
        for c in s:
                if c in '0123456789':
                        m *= Decimal(c)
        return str(m)
  :

いざ実行。

$ ./foo5.py
22, 4, 1 # 8.296966552734375e-05 sec
38, 8, 2 # 0.0002779960632324219 sec
55, 0, 3 # 0.0003681182861328125 sec
268, 0, 4 # 0.0007522106170654297 sec
679, 6, 5 # 0.0015931129455566406 sec
6788, 0, 6 # 0.004734992980957031 sec
68889, 0, 7 # 0.012942075729370117 sec
2677889, 0, 8 # 0.04680800437927246 sec
26888999, 0, 9 # 0.1014409065246582 sec
3778888999, 0, 10 # 0.3986780643463135 sec
277777788888899, 0, 11 # 3.9795498847961426 sec

と、4秒弱で11回が表示されてから、かれこれ20時間...

これまでの時間をちょっと確認。

bar.py

bar.txt

$ cat bar.py
#!/usr/bin/env python

import sys
from decimal import Decimal

if __name__ == "__main__":
        lst = []
        while True:
                s = sys.stdin.readline()
                if not s:
                        break
                ss = s.strip().split()
                lst.append(ss)

        print( '\n'.join( map(str, lst) ) )
        print( '-' * 30 )

        ts = list( map( lambda ss: Decimal( ss[4] ), lst ) )
        ds = list( map( lambda i: ts[i+1] - ts[i], range( len(ts) - 1 ) ) )

        print( '\n'.join( map(str, ds) ) )
        print( '-' * 30 )

        ms = list( map( lambda i: ds[i+1] / ds[i], range( len(ds) - 1 ) ) )

        print( '\n'.join( map(str, ms) ) )
        print( '-' * 30 )
# EOF

$ cat bar.txt
22, 4, 1 # 8.296966552734375e-05 sec
38, 8, 2 # 0.0002779960632324219 sec
55, 0, 3 # 0.0003681182861328125 sec
268, 0, 4 # 0.0007522106170654297 sec
679, 6, 5 # 0.0015931129455566406 sec
6788, 0, 6 # 0.004734992980957031 sec
68889, 0, 7 # 0.012942075729370117 sec
2677889, 0, 8 # 0.04680800437927246 sec
26888999, 0, 9 # 0.1014409065246582 sec
3778888999, 0, 10 # 0.3986780643463135 sec
277777788888899, 0, 11 # 3.9795498847961426 sec

$ cat bar.txt | ./bar.py
['22,', '4,', '1', '#', '8.296966552734375e-05', 'sec']
['38,', '8,', '2', '#', '0.0002779960632324219', 'sec']
['55,', '0,', '3', '#', '0.0003681182861328125', 'sec']
['268,', '0,', '4', '#', '0.0007522106170654297', 'sec']
['679,', '6,', '5', '#', '0.0015931129455566406', 'sec']
['6788,', '0,', '6', '#', '0.004734992980957031', 'sec']
['68889,', '0,', '7', '#', '0.012942075729370117', 'sec']
['2677889,', '0,', '8', '#', '0.04680800437927246', 'sec']
['26888999,', '0,', '9', '#', '0.1014409065246582', 'sec']
['3778888999,', '0,', '10', '#', '0.3986780643463135', 'sec']
['277777788888899,', '0,', '11', '#', '3.9795498847961426', 'sec']
------------------------------
0.00019502639770507815
0.0000901222229003906
0.0003840923309326172
0.0008409023284912109
0.0031418800354003904
0.008207082748413086
0.033865928649902343
0.05463290214538574
0.2972371578216553
3.5808718204498291
------------------------------
0.4621026894865523798134731380
4.261904761904763225719324767
2.189323401613904238317815792
3.736319818542670724115081800
2.612156624677492997972702261
4.126427098161113086382764279
1.613211399284728646755120024
5.440625450038622200286960463
12.04718766217775897808440818
------------------------------
$ 

前回の表示からの時間と、それが前回の何倍になっているかを表示してみました。

9回から10回の時間に対して、10回から11回の時間が12倍に跳ね上がっていますが...

20時間は 20*60*60 = 72000秒

7200/4 = 1800 で、既に 1800倍

よくわからんですね


さらに無駄を省けるように、考えてみます

例えば3桁以上のときに

next_str() で次に試す値を用意するときに、ある桁を増やしたとして... その増やした桁の値と、その桁よりも上の各桁の値について

これまで通り、'5'に対して偶数が含まれていると、次回0が出現するので、ハイスコアは望めない。

例えば'2'に対して'2'が含まれていると、以前にそれよりも少ない桁で'4'を含むパターンで試しているはず。

同様に'2'に対して'3'なら'6'で、'4'なら'8'で、既に試してるはず。

同様に'3'に対して'3'なら'9'で...

これらをひっくるめて、次のパターンならスキップしても良いはず。

# 文字列sの中に、文字列ts中のどれかの文字を含むか判定
ts_in = lambda ts, s: any( map( lambda c: c in s, ts ) )


# 増やした桁の値 c と、その桁よりも上の各桁の値 s で、スキップしても良いか判定
# 12345999 から 12346666 にカウントアップした場合だと      
# s == 1234 で c == 6

def is_skip(s, c):
        d = {
                '5': '2468',
                '2': '2345',
                '3': '23',                
                '4': '25',
                '6': '5',
                '8': '5',
        }
        return ts_in( d.get(c), s ) if c in d else False

さらに mul() でのかけ算を、ちょっと変更 (大した効果があるかは謎)

def mul(s):
        d = {}
        for c in s:
                if c not in d:
                        d[c] = 1
                else:
                        d[c] += 1
        m = Decimal('1')
        for (k, v) in d.items():
                b = Decimal(k)
                m *= (b if v == 1 else b*b if v == 2 else b**v)

        s = ''.join( list( filter( lambda c: c.isdigit(), str(m) ) ) )
        return s

これらを考慮して

foo6.py