Project Euler 44 - Pentagon numbers
Here is my solution, using SageMath.
Naive attempt
var('n')
P = n * (3 * n - 1) / 2
var('x')
solve(P == x, n)
[n == -1/6*sqrt(24*x + 1) + 1/6, n == 1/6*sqrt(24*x + 1) + 1/6]
def is_pentagonal(x):
    return 1/6*sqrt(24*x + 1) + 1/6 in ZZ
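A quick sanity check, using the fact that P(4) = 22 (these calls are illustrative and were not part of the original run):
is_pentagonal(22), is_pentagonal(23)
(True, False)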
def main():
    for i in PositiveIntegers():
        for j in range(1, i):
            if is_pentagonal(P(n=i) - P(n=j)) \
               and is_pentagonal(P(n=i) + P(n=j)):
                print(P(n=i) - P(n=j))
                break
        else:
            continue
        break
%time main()
5482660
CPU times: user 4min 13s, sys: 5.12 ms, total: 4min 13s
Wall time: 4min 13s
The printed value, 5482660, is indeed the answer, but nothing so far guarantees that it is the minimum; at this stage it is only an upper bound. Moreover, this algorithm is too slow to search the whole problem space.
Optimization
Let’s fix the iteration count at 1000 while we benchmark different solutions:
def main():
    for i in range(1, 1000):
        for j in range(1, i):
            if is_pentagonal(P(n=i) - P(n=j)) \
               and is_pentagonal(P(n=i) + P(n=j)):
                print(P(n=i) - P(n=j))
                break
        else:
            continue
        break
%prun main()
1997260 function calls in 55.963 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
499314 27.519 0.000 32.720 0.000 1275694909.py:1(is_pentagonal)
1 23.243 23.243 55.963 55.963 3800642736.py:1(main)
499314 4.088 0.000 4.088 0.000 {method 'sqrt' of 'sage.symbolic.expression.Expression' objects}
499314 0.978 0.000 5.201 0.000 functional.py:1897(sqrt)
499314 0.135 0.000 0.135 0.000 {built-in method builtins.isinstance}
1 0.000 0.000 55.963 55.963 {built-in method builtins.exec}
1 0.000 0.000 55.963 55.963 <string>:1(<module>)
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
Python is not a performant language, but we can do better.
Memoization
Our program spends the majority of its time in is_pentagonal
calls. Will memoization help?
@CachedFunction
def pentagonal(n):
    return P(n=n)

@CachedFunction
def is_pentagonal(x):
    return 1/6*sqrt(24*x + 1) + 1/6 in ZZ

def main():
    for i in range(1, 1000):
        for j in range(1, i):
            if is_pentagonal(pentagonal(n=i) - pentagonal(n=j)) \
               and is_pentagonal(pentagonal(n=i) + pentagonal(n=j)):
                print(P(n=i) - P(n=j))
                break
        else:
            continue
        break
%time main()
CPU times: user 26.5 s, sys: 86.6 ms, total: 26.6 s
Wall time: 26.7 s
Roughly a 2× speed-up.
Simplifying the math
Let’s investigate other ways to check goal conditions.
var('i j x')
solve(P(n=i) - P(n=j) == P(n=x), [x])
[x == -1/6*sqrt(36*i^2 - 36*j^2 - 12*i + 12*j + 1) + 1/6, x == 1/6*sqrt(36*i^2 - 36*j^2 - 12*i + 12*j + 1) + 1/6]
solve(P(n=i) + P(n=j) == P(n=x), [x])
[x == -1/6*sqrt(36*i^2 + 36*j^2 - 12*i - 12*j + 1) + 1/6, x == 1/6*sqrt(36*i^2 + 36*j^2 - 12*i - 12*j + 1) + 1/6]
Now we can condense the logic into one function:
def problem_conditions(i, j):
    return 1/6*sqrt(36*i^2 - 36*j^2 - 12*i + 12*j + 1) + 1/6 in ZZ and \
           1/6*sqrt(36*i^2 + 36*j^2 - 12*i - 12*j + 1) + 1/6 in ZZ

def main():
    for i in range(1, 1000):
        for j in range(1, i):
            if problem_conditions(i, j):
                print(P(n=i) - P(n=j))
                break
        else:
            continue
        break
%time main()
CPU times: user 27.5 s, sys: 0 ns, total: 27.5 s
Wall time: 27.5 s
Not much difference compared to the memoized solution.
Note that memoizing problem_conditions
wouldn’t help since pairs are never repeated.
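Where did the memoization win come from, then? A possible diagnostic, assuming Sage's CachedFunction exposes its cache as the .cache dict:
# Hypothetical probe: compare cache sizes after a run. pentagonal sees
# only ~1000 distinct arguments across hundreds of thousands of calls,
# while the sums and differences fed to is_pentagonal are mostly distinct.
len(pentagonal.cache), len(is_pentagonal.cache)
If so, most of the saving likely comes from avoiding the repeated symbolic substitution P(n=i), not from caching the pentagonality test itself.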
Parallelism
import multiprocessing
cpus = multiprocessing.cpu_count()

@parallel(cpus)
def process_batch(batch):
    result = [problem_conditions(*tup) for tup in batch]
    if any(result):
        ind = result.index(True)
        (i, j) = batch[ind]
        return P(n=i) - P(n=j)
    return None

def prepare_input(ls):
    # Round the batch size up so that no trailing pairs are dropped
    # when len(ls) is not divisible by cpus.
    batch_size = ceil(len(ls) / cpus)
    return [ls[i * batch_size : (i + 1) * batch_size]
            for i in range(cpus)]

def main():
    batches = prepare_input([(i, j)
                             for i in range(1, 1000)
                             for j in range(1, i)])
    results = [r[1] for r in process_batch(batches) if r[1] in ZZ]
    if results:
        print(min(results))
%time main()
CPU times: user 77.8 ms, sys: 123 ms, total: 201 ms
Wall time: 11 s
We have sped up the benchmark by roughly a factor of five (about 56 s down to 11 s), but this still feels too slow for the larger-scale search needed to guarantee minimality.
A new approach
Let’s use plain arrays to store precomputed pentagonal numbers. How big does the array need to be so that every sum P(i) + P(j), for i and j up to n, is still covered?
var('x n')
solve(P(n=x) == 2 * P(n=n), x)[1].rhs().subs(n=1000000).n()
1.41421349333749e6
So the required index grows as about √2 · n for large n; we round up to 1.415 · n to be safe.
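We can also ask Sage for the asymptotic ratio directly; the limit should come out to sqrt(2):
limit(solve(P(n=x) == 2 * P(n=n), x)[1].rhs() / n, n=oo)
sqrt(2)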
max_i = 1000

def main():
    p = [P(n=i) for i in range(ceil(1.415 * max_i))]
    p_set = set(p)
    results = [(p[i], p[j])
               for i in range(max_i)
               for j in range(1, i)
               if p[i] + p[j] in p_set and p[i] - p[j] in p_set]
    if results:
        print(min([abs(item[1] - item[0]) for item in results]))
%time main()
CPU times: user 1.29 s, sys: 6.63 ms, total: 1.3 s
Wall time: 1.3 s
We’re close to the one-second mark. We can also parallelize this solution:
import multiprocessing
cpus = multiprocessing.cpu_count()
max_i = 1000
p = [P(n=i) for i in range(ceil(1.415 * max_i))]
p_set = set(p)

@parallel(cpus)
def process_batch(batch):
    return [abs(p[i] - p[j]) for (i, j) in batch
            if p[i] + p[j] in p_set and p[i] - p[j] in p_set]

def prepare_input(ls):
    # Round up, as before, so no trailing pairs are dropped.
    batch_size = ceil(len(ls) / cpus)
    return [ls[i * batch_size : (i + 1) * batch_size]
            for i in range(cpus)]

def main():
    results = flatten([r[1] for r in
                       process_batch(prepare_input(
                           [(i, j)
                            for i in range(max_i)
                            for j in range(1, i)]))])
    if results:
        print(min(results))
%time main()
CPU times: user 86.3 ms, sys: 136 ms, total: 223 ms
Wall time: 633 ms
Cython
If we want serious performance, we will need Cython to compile our code. To stay entirely in integer arithmetic, note that 1/6*sqrt(s) + 1/6 is a positive integer exactly when s is a perfect square whose square root is congruent to 5 modulo 6:
%%cython
from libc.math cimport sqrt

def P(n):
    return n * (3 * n - 1) // 2

cdef int is_perfect_square(long long d):
    # Floating-point sqrt is exact for d below 2^53, far above our range.
    cdef long long s = <long long> sqrt(<double> d)
    return s * s == d

cdef int problem_conditions(long long i, long long j):
    # s1 = 24*(P(i) - P(j)) + 1 and s2 = 24*(P(i) + P(j)) + 1, expanded;
    # each is pentagonal iff it is a perfect square with root == 5 (mod 6).
    cdef long long s1 = 36*i*i - 36*j*j - 12*i + 12*j + 1
    cdef long long s2 = 36*i*i + 36*j*j - 12*i - 12*j + 1
    return is_perfect_square(s1) and \
           <long long> sqrt(<double> s1) % 6 == 5 and \
           is_perfect_square(s2) and \
           <long long> sqrt(<double> s2) % 6 == 5

def cython_results(max_i):
    results = []
    cdef int i, j
    for i in range(1, max_i):
        for j in range(1, i):
            if problem_conditions(i, j):
                results.append(abs(P(i) - P(j)))
    return results

def main():
    max_i = 1000
    results = cython_results(max_i)
    print(results)
%time main()
[]
CPU times: user 1.2 ms, sys: 0 ns, total: 1.2 ms
Wall time: 1.21 ms
Much better! The list is empty because no pair with i < 1000 satisfies both conditions, so let’s widen the search:
def main():
    max_i = 3000
    results = cython_results(max_i)
    print(results)
%time main()
[5482660]
CPU times: user 10.3 ms, sys: 1 µs, total: 10.3 ms
Wall time: 10.4 ms
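Before moving on, we can cross-check the integer test against the earlier symbolic one. This is a verification sketch (not part of the original notebook), using Sage's isqrt:
def is_pentagonal_int(x):
    # Integer pentagonality test: 24*x + 1 must be a perfect square
    # whose root is congruent to 5 mod 6.
    s = isqrt(24 * x + 1)
    return s * s == 24 * x + 1 and s % 6 == 5

all(is_pentagonal_int(x) == is_pentagonal(x) for x in range(1, 10^4))
True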
Guarantee
To guarantee that our solution is minimal, we need to examine every pentagonal pair whose difference is smaller than our upper bound of 5482660. We modify the second loop so that its starting point gradually shifts closer to its end point.
P = n * (3 * n - 1) / 2
P(n=i) - P(n=i-1)
-1/2*(3*i - 4)*(i - 1) + 1/2*(3*i - 1)*i
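The gap between consecutive pentagonal numbers is linear in the index:
(P(n=i) - P(n=i-1)).expand()
3*i - 2
Setting this gap equal to our upper bound tells us where we can stop: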
solve(-1/2*(3*i - 4)*(i - 1) + 1/2*(3*i - 1)*i == 5482660, i)
[i == 1827554]
So at i = 1827554 we can stop, since from there on even the difference between two consecutive pentagonal numbers exceeds our upper bound. And since the bound itself is pentagonal, 5482660 = P(1912), we can solve for the index offset d at which the difference P(k) − P(k−d) reaches it:
var('k d')
solve(P(n=k) - P(n=k-d) == P(n=1912), [d])
[d == k - 1/6*sqrt(36*k^2 - 12*k - 131583839) - 1/6, d == k + 1/6*sqrt(36*k^2 - 12*k - 131583839) - 1/6]
We will use this formula to calculate a starting point for the second loop based on i.
def starting_point(i):
    return 1 if i <= 1912 else \
        round(1/6*sqrt(36*i^2 - 12*i - 131583839) - 1/6)
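For large i the starting point hugs i itself, so the inner loop shrinks dramatically. A quick probe (illustrative, not in the original run; if my arithmetic is right this leaves only two candidate j values):
# Width of the j-window [starting_point(i), i) at i = 10^6:
10^6 - starting_point(10^6)
2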
import matplotlib as mpl
import matplotlib.pyplot as plt
from IPython.display import SVG, display
mpl.use('svg')
def myplot(plt, scheme, filename):
    plt.clf()
    plt.style.use(scheme)
    x = range(1, 10000, 50)
    y = [starting_point(i) for i in x]
    plt.plot(x, y, label='start')
    plt.plot(x, x, label='end')
    plt.title('Iteration interval of j')
    plt.legend()
    plt.ylim(0, 10000)
    plt.xlim(0, 10000)
    plt.xlabel('i')
    plt.savefig(filename)
myplot(plt, scheme='default', filename='44_1.svg')
myplot(plt, scheme='dark_background', filename='44_1_dark.svg')
display(SVG(filename='44_1_dark.svg'))
We estimate the size of our search space by calculating the area between the curves:
size_estimate = numerical_integral(lambda x : x - starting_point(x), 0, 1827554)
size_estimate[0] + size_estimate[1]
15692239.446603088
The size estimate (we add the integration error bound to stay on the safe side) is used to break our search space into equal-sized chunks to be processed in parallel:
%%cython
from libc.math cimport sqrt, round

cdef long long P(long long n):
    return n * (3 * n - 1) // 2

cdef long starting_point(double i):
    # Smallest j worth checking for a given i; below it, the difference
    # P(i) - P(j) already exceeds the upper bound P(1912) = 5482660.
    return 1 if i <= 1912 else \
        <long> round(1./6*sqrt(36*(i**2) - 12*i - 131583839) - 1./6)

cdef int is_perfect_square(long long d):
    cdef long long s = <long long> sqrt(<double> d)
    return s * s == d

cdef int problem_conditions(long long i, long long j):
    cdef long long s1 = 36*i*i - 36*j*j - 12*i + 12*j + 1
    cdef long long s2 = 36*i*i + 36*j*j - 12*i - 12*j + 1
    return is_perfect_square(s1) and \
           <long long> sqrt(<double> s1) % 6 == 5 and \
           is_perfect_square(s2) and \
           <long long> sqrt(<double> s2) % 6 == 5

def break_points(max_i, search_size, num_batches):
    # Accumulate per-i work (i - starting_point(i)) until a batch is full.
    cdef long batch_size = search_size // num_batches
    ls_break_points = [1]
    cdef int part
    cdef long i
    cdef long long s = 0
    for part in range(num_batches - 1):
        s = 0
        for i in range(ls_break_points[part], max_i):
            s += i - starting_point(i)
            if s >= batch_size:
                ls_break_points.append(i)
                break
    ls_break_points.append(max_i)
    return ls_break_points

def cython_results(start, end):
    results = []
    cdef int i, j
    for i in range(start, end):
        for j in range(starting_point(i), i):
            if problem_conditions(i, j):
                results.append(abs(P(i) - P(j)))
    return results
import multiprocessing
cpus = multiprocessing.cpu_count()
max_i = 1827554
var('n')
P = n * (3 * n - 1) / 2
@parallel(cpus)
def process_batch(start, end):
    return cython_results(start, end)
def main():
    bps = break_points(max_i, search_size=15692239, num_batches=cpus)
    boundaries = list(zip(bps[:-1], bps[1:]))
    results = list(process_batch(boundaries))
    print(min(flatten([r[1] for r in results])))
%time main()
5482660
CPU times: user 16.6 ms, sys: 89.7 ms, total: 106 ms
Wall time: 143 ms
Hence the solution, 5482660, is guaranteed to be minimal.
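As a final sanity check, the answer is itself pentagonal, and its index is exactly the 1912 we used when deriving starting_point:
1/6*sqrt(24*5482660 + 1) + 1/6
1912
That is, 5482660 = P(1912).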