class Solution:
    def nthMagicalNumber(self, n: int, a: int, b: int) -> int:
        """Return the n-th positive integer divisible by a or b, modulo 10**9 + 7.

        Binary-searches the smallest m for which the number of integers in
        [1, m] divisible by a or b (counted by inclusion-exclusion) reaches n.

        Args:
            n: 1-based rank of the wanted magical number.
            a: First divisor (positive).
            b: Second divisor (positive).

        Returns:
            The n-th magical number reduced modulo 10**9 + 7.
        """
        # Local import: this class body runs before any module-level import
        # that appears later in the file.
        from math import gcd

        mod = 10 ** 9 + 7
        # Exact integer LCM; float division here would turn every later
        # floor division into float arithmetic and lose precision for big LCMs.
        common = a // gcd(a, b) * b
        # The n-th multiple of min(a, b) is itself magical, so it bounds the answer.
        lo, hi = 1, min(a, b) * n
        while lo < hi:
            mid = (lo + hi) // 2
            count = mid // a + mid // b - mid // common
            if count < n:
                lo = mid + 1
            else:
                hi = mid
        return lo % mod
if __name__ == '__main__':
    # Demo for the binary-search nthMagicalNumber defined just above.
    # Every multiple of 10 is also a multiple of 5, so this asks for the
    # 8th multiple of 5 and should print 40.
    n, a, b = 8, 10, 5
    x = Solution()
    print(x.nthMagicalNumber(n, a, b))
其次，用数学方法求解：利用神奇数字分布的周期性（周期为 a 与 b 的最小公倍数）。
import math
class Solution:
    def nthMagicalNumber(self, n: int, a: int, b: int) -> int:
        """Return the n-th positive integer divisible by a or b, modulo 10**9 + 7.

        Magical numbers repeat with period lcm(a, b). Each period of length
        lcm(a, b) contains lcm//a + lcm//b - 1 of them. The method therefore
        skips whole periods, then merges the multiples of a and b in
        increasing order to find the remaining rank inside one period.

        Args:
            n: 1-based rank of the wanted magical number.
            a: First divisor (positive).
            b: Second divisor (positive).

        Returns:
            The n-th magical number reduced modulo 10**9 + 7.
        """
        mod = 10 ** 9 + 7
        # Exact integer arithmetic throughout (no float division / int() casts).
        common = a // math.gcd(a, b) * b
        per_period = common // a + common // b - 1
        full_periods, rest = divmod(n, per_period)
        base = full_periods * common
        if rest == 0:
            # The n-th number is exactly the last one of a full period.
            return base % mod
        # Two-pointer merge of a, 2a, ... and b, 2b, ...: after rest-1
        # advances, the smaller head is the rest-th magical number in the period.
        next_a, next_b = a, b
        for _ in range(rest - 1):
            if next_a < next_b:
                next_a += a
            else:
                next_b += b
        return (base + min(next_a, next_b)) % mod
if __name__ == '__main__':
    # Demo for the periodicity-based nthMagicalNumber defined just above.
    # Multiples of 2 or 3 run 2, 3, 4, 6, ... so this should print 6.
    n, a, b = 4, 2, 3
    x = Solution()
    print(x.nthMagicalNumber(n, a, b))
1201. 丑数 III
本题最初沿用了昨天“丑数 II”的三指针法，但提交超时：该方法需要逐个生成前 n 个数，时间复杂度为 O(n)。
class Solution:
    def nthUglyNumber(self, n: int, a: int, b: int, c: int) -> int:
        """Return the n-th positive integer divisible by a, b or c.

        Merges the three multiple sequences a*k, b*k, c*k in increasing
        order, advancing every sequence whose head equals the current
        minimum so duplicates are emitted once. Runs n iterations (O(n)),
        which is the version the notes describe as too slow.
        Returns 0 when n is 0.
        """
        factors = [a, b, c]
        counters = [1] * len(factors)
        current = 0
        for _ in range(n):
            candidates = [k * f for k, f in zip(counters, factors)]
            current = min(candidates)
            for idx, value in enumerate(candidates):
                if value == current:
                    counters[idx] += 1
        return current
解法源自上面那道 hard 题（第 N 个神奇数字）：首先使用二分法，并用容斥原理统计不超过 m 的丑数个数。
import math
class Solution:
    def nthUglyNumber(self, n: int, a: int, b: int, c: int) -> int:
        """Return the n-th positive integer divisible by a, b or c.

        Binary-searches the smallest m whose inclusion-exclusion count of
        such integers in [1, m] is at least n. The search range is
        [1, min(a, b, c) * n]; when n is 0 the range is empty and 1 is returned.
        """
        def lcm(x: int, y: int) -> int:
            return x // math.gcd(x, y) * y

        ab = lcm(a, b)
        bc = lcm(b, c)
        ac = lcm(a, c)
        abc = lcm(a, bc)

        def count(x: int) -> int:
            singles = x // a + x // b + x // c
            pairs = x // ab + x // bc + x // ac
            return singles - pairs + x // abc

        lo, hi = 1, min(a, b, c) * n
        while lo < hi:
            mid = (lo + hi) // 2
            if count(mid) >= n:
                hi = mid
            else:
                lo = mid + 1
        return lo
if __name__ == '__main__':
    # Demo for the binary-search nthUglyNumber defined just above.
    # Multiples of 3, 5 or 7 run 3, 5, 6, 7, ... so this should print 7.
    n, a, b, c = 4, 3, 5, 7
    x = Solution()
    print(x.nthUglyNumber(n, a, b, c))
再结合周期法与二分法：先以 a、b、c 的最小公倍数为周期，把 n 化为整周期数和余数 y，再用二分求出第 y 个丑数，进一步降低复杂度。
import math
class Solution:
    def nthUglyNumber(self, n: int, a: int, b: int, c: int) -> int:
        """Return the n-th positive integer divisible by a, b or c.

        Ugly numbers repeat with period abc = lcm(a, b, c). The method first
        skips whole periods, then binary-searches for the remaining rank
        `rest` within one period. It uses the inclusion-exclusion count of
        ugly numbers in [1, m].

        Args:
            n: 1-based rank of the wanted ugly number.
            a, b, c: Positive divisors.

        Returns:
            The n-th ugly number (0 when n is 0).
        """
        def lcm(x: int, y: int) -> int:
            return x // math.gcd(x, y) * y

        ab, bc, ac = lcm(a, b), lcm(b, c), lcm(a, c)
        abc = lcm(a, bc)

        def num(x: int) -> int:
            # Count of integers in [1, x] divisible by a, b or c.
            return (x // a + x // b + x // c
                    - x // ab - x // bc - x // ac
                    + x // abc)

        per_period = num(abc)
        full_periods, rest = divmod(n, per_period)
        if rest == 0:
            return full_periods * abc
        # The rest-th multiple of min(a, b, c) bounds the rest-th ugly number.
        # Bounding by `rest` (not `n`) is what makes the period reduction pay off.
        lo, hi = 1, min(a, b, c) * rest
        while lo < hi:
            mid = (lo + hi) // 2
            if num(mid) < rest:
                lo = mid + 1
            else:
                hi = mid
        return full_periods * abc + lo
if __name__ == '__main__':
    # Demo for the period + binary-search nthUglyNumber defined just above.
    # Multiples of 3, 4 or 7 run 3, 4, 6, ... so this should print 6.
    n, a, b, c = 3, 3, 4, 7
    x = Solution()
    print(x.nthUglyNumber(n, a, b, c))