# Concept
- `Fourier Transform` : 시간에 대한 함수를 구성하고 있는 주파수 성분으로 분해하는 과정을 말한다.
- Fourier Transform 원리 (출처 : 공대생의 차고)
![[FT 원리.gif]]
- `DFT(Discrete Fourier Transform)` : 연속이 아닌 **이산 시간**에 대한 함수를 **이산 주파수**로 변환하는 과정을 말한다. 컴퓨터 속 퓨리에 변환은 대부분 DFT를 의미한다.
- FFT는 이러한 퓨리에 변환을 빠르게 하는 방법을 말하며 주로 알고리즘에서의 FFT는 이산 합성곱(두 벡터의 합성곱)을 `O(nlogn)`의 시간복잡도로 계산하는 알고리즘을 말한다. ex) (x+2)(x+3) = x<sup>2</sup>+5x+6 -> `v = [1,2] u = [1,3] v * u = [1, 5, 6]`
# FFT 원리
- 2부터 차례대로 배수를 구하면서 N이하인 2의 배수들을 배열에서 제거해나가는 방식이다.
#### 🖼️그림으로 이해하기
![[]]
# FFT CODE
- 1~n까지의 숫자들 중 소수가 무엇인지를 찾는 것이기 때문에 n이 커지면 커질 수록 계산 시간이 오래 걸린다.
#### ⌨️ Code
```cpp
```
##### ❓ 예제 Input
56
##### ⭐ 예제 Output
2 3 5 7 11 13 17 19 23 29 31 37 41 43 47 53
# FFT 응용문제
### 📑[17134 - 르모앙의 추측](https://www.acmicpc.net/problem/17134)
#### 🔓 KeyPoint
- 4자리 소수에서 한 자리씩 수를 바꾸어 최종 원하는 소수로 바꾸는 과정에서 존재하는 모든 수
#### ⌨️ Code
```cpp
#include <bits/stdc++.h>
#define MAX 1000000
using namespace std;
typedef complex<double> cpx;
const double PI = acos(-1);
int t;
bool IsPrime[MAX] = {0, };
void FFT(vector<cpx> &v_cpx, bool IsIDFT) {
int n = (int)v_cpx.size();
for ( int i = 1, j = 0; i < n; i++ ) {
int bit = n / 2;
while ( j >= bit ) {
j -= bit;
bit /= 2;
}
j += bit;
if ( i < j ) swap(v_cpx[i], v_cpx[j]);
}
for ( int i = 1; i < n; i *= 2 ) {
cpx w;
if ( IsIDFT ) w = cpx(cos(PI / i), sin(PI / i));
else w = cpx(cos(-PI / i), sin(-PI / i));
for ( int j = 0; j < n; j += i * 2 ) {
cpx nw(1, 0);
for ( int k = 0; k < i; k++ ) {
cpx even = v_cpx[j+k];
cpx odd = v_cpx[i+j+k];
v_cpx[j+k] = even + nw * odd;
v_cpx[i+j+k] = even - nw * odd;
nw *= w;
}
}
}
if ( IsIDFT ) {
for ( int i = 0; i < n; i++ ) v_cpx[i] /= n;
}
}
vector<cpx> multiply(vector<cpx> &v_cpx, vector<cpx> &u_cpx) {
int n = 2;
while ( n < (int)v_cpx.size() + (int)u_cpx.size() ) n *= 2;
v_cpx.resize(n);
u_cpx.resize(n);
FFT(v_cpx, false);
FFT(u_cpx, false);
vector<cpx> r_cpx(n);
for ( int i = 0; i < n; i++ ) r_cpx[i] = v_cpx[i] * u_cpx[i];
FFT(r_cpx, true);
return r_cpx;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL); cout.tie(NULL);
cin >> t;
memset(IsPrime, 1, sizeof(IsPrime));
for ( int i = 2; i <= sqrt(MAX); i++ ) {
if ( !IsPrime[i] ) continue;
for ( int j = i * i; j < MAX; j += i ) IsPrime[j] = false;
}
vector<cpx> v_cpx(MAX), u_cpx(MAX);
for ( int i = 2; i < MAX; i++ ) {
if ( IsPrime[i] ) {
v_cpx[i] = cpx(1, 0);
if ( i * 2 < MAX ) u_cpx[i*2] = cpx(1, 0);
}
}
vector<cpx> r_cpx = multiply(v_cpx, u_cpx);
while ( t-- ) {
int n;
cin >> n;
cout << round(r_cpx[n].real()) << '\n';
}
return 0;
}
```
### 📑[22289 - 큰 수 곱셈 (3)](https://www.acmicpc.net/problem/22289)
#### 🔓 KeyPoint
- 여러 수로 이루어진 배열을 두 개의 쌍(그룹)으로 나누어 짝지었을때, 짝지은 모든 수들의 합이
#### ⌨️ Code
```cpp
#include <bits/stdc++.h>
using namespace std;
typedef complex<double> cpx;
const double PI = acos(-1);
void FFT(vector<cpx> &v, bool IsIDFT) {
int n = (int)v.size();
for ( int i = 1, j = 0; i < n; i++ ) {
int bit = n / 2;
while ( j >= bit ) {
j -= bit;
bit /= 2;
}
j += bit;
if ( i < j ) swap(v[i], v[j]);
}
for ( int i = 1; i < n; i *= 2 ) {
cpx w;
if ( IsIDFT ) w = cpx(cos(PI / i), sin(PI / i));
else w = cpx(cos(-PI / i), sin(-PI / i));
for ( int j = 0; j < n; j += i*2 ) {
cpx nw(1, 0);
for ( int k = 0; k < i; k++ ) {
cpx even = v[j+k];
cpx odd = v[i+j+k];
v[j+k] = even + nw * odd;
v[i+j+k] = even - nw * odd;
nw *= w;
}
}
}
if ( IsIDFT ) {
for ( int i = 0; i < n; i++ ) v[i] /= n;
}
}
vector<int> multiply(vector<int> &v, vector<int> &u) {
vector<cpx> v_cpx, u_cpx;
for ( int i = 0; i < (int)v.size(); i++ ) v_cpx.push_back(cpx(v[i], 0));
for ( int i = 0; i < (int)u.size(); i++ ) u_cpx.push_back(cpx(u[i], 0));
int n = 2;
while ( n < (int)v.size() + (int)u.size() ) n *= 2;
v_cpx.resize(n);
u_cpx.resize(n);
FFT(v_cpx, false);
FFT(u_cpx, false);
vector<cpx> r_cpx(n);
for( int i = 0; i < n; i++ ) r_cpx[i] = v_cpx[i] * u_cpx[i];
FFT(r_cpx, true);
vector<int> result(n);
for( int i = 0; i < n; i++ ) result[i] = round(r_cpx[i].real());
return result;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL); cout.tie(NULL);
string a, b;
cin >> a >> b;
vector<int> v, u;
for ( int i = 0; i < (int)a.length(); i++ ) v.push_back(a[i] - '0');
for ( int i = 0; i < (int)b.length(); i++ ) u.push_back(b[i] - '0');
reverse(v.begin(), v.end());
reverse(u.begin(), u.end());
vector<int> result = multiply(v, u);
for ( int i = 0; i < (int)result.size() - 1; i++ ) {
if ( result[i] / 10 ) {
result[i+1] += result[i] / 10;
result[i] %= 10;
}
}
reverse(result.begin(), result.end());
int idx = 0;
while(result[idx] == 0) idx++;
if(idx >= (int)result.size()) {
cout << 0;
return 0;
}
while(idx < (int)result.size()) {
cout << result[idx];
idx++;
}
return 0;
}
```