포스트

[ 백준 ] 1067 이동 ( Python )

BOJ 1067

FFT(Fast Fourier Transform)란?

  • FFT는 이산 푸리에 변환(DFT, Discrete Fourier Transform)을 빠르게 수행하는 알고리즘입니다.
  • DFT는 시간 도메인(time domain)의 신호를 주파수 도메인(frequency domain)으로 변환하는 작업인데, 이 과정에서 O(N^2)의 시간 복잡도가 소요됩니다.
  • FFT는 분할 정복(divide and conquer) 기법을 활용하여 DFT를 O(NlogN) 시간에 수행할 수 있도록 최적화한 알고리즘입니다.
  • FFT는 신호 처리, 이미지 처리, 데이터 압축 등 다양한 분야에서 활용되고 있습니다.

DFT와 FFT의 수식:

  • DFT는 다음과 같은 수식으로 표현됩니다. X[k] = Σ_{n=0}^{N-1} x[n] * e^{-j*2πkn/N} (k = 0, 1, …, N-1) 여기서 x[n]은 시간 도메인의 신호, X[k]는 주파수 도메인의 신호, N은 신호의 길이, j는 허수 단위를 나타냅니다.
  • FFT는 DFT를 분할 정복 방식으로 재귀적으로 계산합니다. X[k] = Even[k] + e^{-j2πk/N} * Odd[k] (k = 0, 1, …, N/2-1) X[k+N/2] = Even[k] - e^{-j2πk/N} * Odd[k] (k = 0, 1, …, N/2-1) 여기서 Even[k]와 Odd[k]는 각각 짝수 번째, 홀수 번째 요소들의 DFT 결과를 의미합니다.

컨볼루션(Convolution) 연산이란?

  • 컨볼루션 연산은 두 개의 함수 f, g가 주어졌을 때, 다음과 같이 정의됩니다. (f _ g)[n] = Σ_{m=-∞}^{∞} f[m] _ g[n-m]
  • 컨볼루션 연산은 신호 처리, 이미지 처리 등에서 널리 사용되는 연산으로, 두 신호를 겹치면서 더하는 작업입니다.
  • 컨볼루션 연산은 시간 도메인에서 수행하면 O(N^2)의 시간 복잡도가 소요되지만, FFT를 활용하면 O(NlogN)에 계산할 수 있습니다.

문제 해결 과정:

  1. 입력으로 주어진 수열 X와 Y의 길이를 N이라 할 때, FFT를 적용하기 위해 N을 2의 제곱수로 만들어줍니다. 이는 X와 Y의 길이를 2의 제곱수보다 크지 않은 다음 2의 제곱수로 늘리고, 추가된 부분은 0으로 채우는 것을 의미합니다. 이 과정을 패딩(padding)이라고 합니다.

  2. 수열 Y를 뒤집어서 Y_rev를 만듭니다. 이는 컨볼루션 연산 결과가 최대 곱 점수가 되도록 하기 위함입니다.

  3. FFT를 사용하여 수열 X와 Y_rev를 주파수 도메인으로 변환합니다. 변환된 결과를 X_fft와 Y_fft라고 하겠습니다. X_fft = fft(X), Y_fft = fft(Y_rev)

  4. 주파수 도메인에서는 컨볼루션 연산이 요소별 곱셈으로 계산될 수 있습니다. 따라서 X_fft와 Y_fft를 요소별로 곱해줍니다. Z_fft[k] = X_fft[k] * Y_fft[k] (k = 0, 1, …, N-1)

  5. Z_fft를 IFFT(Inverse FFT)를 사용하여 다시 시간 도메인으로 변환합니다. 변환된 결과를 Z라고 하겠습니다. Z = ifft(Z_fft)

  6. Z에는 X와 Y의 컨볼루션 연산 결과가 저장되어 있습니다. 이 중 최댓값을 찾아 출력하면 원하는 최대 곱 점수를 얻을 수 있습니다. result = max(Z)

시행착오 및 풀이 과정: 처음에는 단순히 이중 for문을 사용하여 모든 경우를 계산하려고 했습니다. 수열 X를 고정하고, Y를 한 칸씩 밀어가며 매번 곱 점수를 계산하는 방식입니다. 그러나 이 방법은 O(N^2)의 시간 복잡도를 가지므로, N이 최대 60,000인 이 문제에서는 시간 초과가 발생할 수밖에 없었습니다.

이에 컨볼루션 연산을 활용해 O(NlogN)에 문제를 해결하는 방법을 떠올렸습니다. 컨볼루션 연산은 FFT와 IFFT를 활용하면 효율적으로 계산할 수 있습니다. 그러나 FFT 알고리즘 자체가 생소하고 복잡하게 느껴져서 구현에 어려움을 겪었습니다.

특히 재귀 함수를 활용해 분할 정복 방식으로 FFT를 구현하는 부분이 쉽지 않았습니다. 짝수 번째 요소와 홀수 번째 요소를 분리하여 재귀적으로 FFT를 수행한 뒤, 결과를 조합하는 과정에서 많은 고민이 필요했습니다. 또한 입력 크기에 따라 적절히 패딩을 해주는 부분도 헷갈려서 여러 번 수정이 필요했습니다.

하지만 FFT 알고리즘을 차근차근 이해하고 구현해 나가면서, 동일한 연산을 반복하는 과정을 재귀 호출로 표현할 수 있다는 것을 깨달았습니다. 또한 복소수 연산과 분할 정복 과정에 익숙해지면서 FFT의 동작 원리를 명확히 이해할 수 있었습니다.

결과적으로, FFT와 IFFT를 활용한 컨볼루션 연산으로 문제를 해결할 수 있었습니다. 단순한 방법으로는 시간 초과가 발생하던 문제를 효율적인 알고리즘을 통해 해결하는 과정이 인상 깊었고, 알고리즘의 중요성을 다시 한번 실감할 수 있었습니다.

해답 보기

전체코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import sys
from cmath import exp, pi
input = sys.stdin.readline

def fft(x):
    n = len(x)
    if n == 1:
        return x
    even_part = fft(x[0::2])
    odd_part = fft(x[1::2])
    w = [exp(2j * pi * k / n) for k in range(n // 2)]
    return [even_part[k] + w[k] * odd_part[k] for k in range(n // 2)] + \
           [even_part[k] - w[k] * odd_part[k] for k in range(n // 2)]

def ifft(x):
    n = len(x)
    if n == 1:
        return x
    even_part = ifft(x[0::2])
    odd_part = ifft(x[1::2])
    w = [exp(-2j * pi * k / n) for k in range(n // 2)]
    return [even_part[k] + w[k] * odd_part[k] for k in range(n // 2)] + \
           [even_part[k] - w[k] * odd_part[k] for k in range(n // 2)]

m = int(input())
n = 2 * m
flag = 0
for i in range(18):
    if m == 2**i:
        flag = -100
        break
    elif n < 2**i:
        flag = i
        break

a = list(map(int, input().split()))
b = list(map(int, input().split()))

if flag == -100:
    a = a + a
    b = b[::-1] + [0] * m
    c = [0] * n
    a_fft = fft(a)
    b_fft = fft(b)
    for i in range(n):
        c[i] = a_fft[i] * b_fft[i]
    c_ifft = ifft(c)
    result = max(round(c_ifft[k].real / n) for k in range(n))
else:
    n_prime = 2**i
    n, n_prime = n_prime, n
    a = a + [0] * (n - n_prime // 2)
    b = b[::-1] + [0] * (n - n_prime) + b[::-1]
    c = [0] * n
    a_fft = fft(a)
    b_fft = fft(b)
    for i in range(n):
        c[i] = a_fft[i] * b_fft[i]
    c_ifft = ifft(c)
    result = max(round(c_ifft[k].real / n) for k in range(n))

print(result)

이 문제를 통해 FFT와 컨볼루션 연산에 대해 깊이 있게 공부할 수 있었습니다. 수학적 개념을 코드로 구현하는 과정이 쉽지는 않았지만, 차근차근 이해하고 적용해 나가면서 알고리즘의 동작 원리를 파악할 수 있었습니다.

또한 단순한 방법으로는 해결하기 어려운 문제도 적절한 알고리즘을 활용한다면 효과적으로 풀어낼 수 있다는 점을 배웠습니다. 앞으로도 다양한 알고리즘을 학습하고 응용하여 문제 해결 능력을 꾸준히 키워나가고 싶습니다.

긴 설명을 읽어주셔서 감사합니다. 혹시 이해가 되지 않는 부분이나 추가적인 설명이 필요한 부분이 있다면 말씀해 주세요. 더 자세히 설명드리겠습니다.

이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.