看 paper 看烦了来贴点代码玩.
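Lagrange 插值就是多项式乘法: 对每个节点 $x_i$, 把基函数 $\ell_i(x) = \prod_{j \ne i} \frac{x - x_j}{x_i - x_j}$ 乘开成单项式系数, 再乘上 $y_i$ 累加起来.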
// Interpolates f at n + 1 equally spaced nodes on [-5, 5] and prints the
// coefficients c_0..c_n of the interpolating polynomial sum_k c_k x^k.
// Output syntax is LaTeX by default, JS when JS_OUTPUT is defined, Python
// when PYTHON_OUTPUT is defined.
//
// The polynomial is assembled in Lagrange form: for each node p the product
// prod_{q != p} (x - x_q) is expanded into monomial coefficients in place,
// scaled by y_p / prod_{q != p} (x_p - x_q), and accumulated into coef[].
void genLagrangePoly(int n) {
    const int cnt = n + 1;              // number of interpolation nodes
    double *node = new double[cnt];
    double *val = new double[cnt];
    double *coef = new double[cnt];     // result: coef[k] multiplies x^k
    double *basis = new double[cnt];    // scratch: one expanded basis product
    const double step = 10. / n;
    for (int p = 0; p < cnt; ++p) {
        node[p] = -5. + step * p;
        val[p] = f(node[p]);
        coef[p] = 0;
    }
    for (int p = 0; p < cnt; ++p) {
        for (int k = 0; k < cnt; ++k) {
            basis[k] = 0.;
        }
        basis[0] = 1.;
        double weight = val[p];
        for (int q = 0; q < cnt; ++q) {
            if (q == p) {
                continue;
            }
            weight /= node[p] - node[q];
            // Multiply basis by (x - node[q]); walk from the top degree down
            // so basis[k] is still the old coefficient when basis[k + 1] uses it.
            const int top = q < n ? q : n - 1;
            for (int k = top; k >= 0; --k) {
                basis[k + 1] = basis[k] - basis[k + 1] * node[q];
            }
            basis[0] *= -node[q];
        }
        for (int k = 0; k < cnt; ++k) {
            coef[k] += basis[k] * weight;
        }
    }
    // LATEX_OUTPUT is defined here but is not consulted by the format
    // selection below; LaTeX is simply the fallback branch.
#define LATEX_OUTPUT
    for (int k = 0; k < cnt; ++k) {
#if defined(JS_OUTPUT)
        /* JS format output */
        printf(" %+.12lf*pow(x, %d)", coef[k], k);
#elif defined(PYTHON_OUTPUT)
        /* Python format output */
        printf(" %+.12lf* x ** %d", coef[k], k);
#else
        /* LaTeX format output */
        printf(" %+lf* x^{%d}", coef[k], k);
#endif
    }
    putchar('\n');
    delete [] basis;
    delete [] coef;
    delete [] val;
    delete [] node;
}
三次样条插值要解的是各节点处的二阶导数 $M_i = S''(x_i)$. 这是一个三对角线性方程组, 随手高斯消元 (追赶法) 一下就行.
// Builds the clamped cubic spline of f on n equal subintervals of [-5, 5]
// and prints it piecewise. Output is LaTeX by default, JS when JS_OUTPUT is
// defined, Python when PYTHON_OUTPUT is defined. Nothing is printed for
// n < 1, since at least two nodes are needed.
//
// On [x_i, x_{i+1}] (width h) the spline is written in moment form
//   S(x) = M_i (x_{i+1}-x)^3/(6h) + M_{i+1} (x-x_i)^3/(6h)
//        + (y_i/h - M_i h/6)(x_{i+1}-x) + (y_{i+1}/h - M_{i+1} h/6)(x-x_i),
// where the moments M_i = S''(x_i) solve the tridiagonal system
//   2 M_0 + M_1                       = 6/h (f[x_0,x_1] - f'(x_0))
//   M_{i-1}/2 + 2 M_i + M_{i+1}/2     = 6 f[x_{i-1},x_i,x_{i+1}]   (0 < i < n)
//   M_{n-1} + 2 M_n                   = 6/h (f'(x_n) - f[x_{n-1},x_n])
// NOTE(review): assumes df(x) is f'(x) and adf(a, b) is the first divided
// difference (f(b) - f(a)) / (b - a) — confirm against their definitions.
void genTriSmpPoly(int n) {
    if (n < 1) {
        return;     // the boundary rows below read x[1] and x[n - 1]
    }
    double h = 10. / n;
    double *x = new double[n + 1];
    double *y = new double[n + 1];
    double *d = new double[n + 1];  // right-hand side, reduced in place
    double *m = new double[n + 1];  // moments M_i
    double *a = new double[n + 1];  // diagonal after forward elimination
    for (int i = 0; i <= n; ++ i) {
        x[i] = -5. + 10. / n * i;
        y[i] = f(x[i]);
    }
    d[0] = 6. / h * (adf(x[0], x[1]) - df(x[0]));
    d[n] = 6. / h * (df(x[n]) - adf(x[n - 1], x[n]));
    for (int i = 1; i < n; ++ i) {
        d[i] = 6. * (adf(x[i], x[i + 1]) - adf(x[i], x[i - 1])) / (2 * h);
    }
    // Thomas algorithm. Interior rows have off-diagonal entries 1/2, but the
    // clamped boundary rows couple with weight 1: row 0 has super-diagonal 1
    // and row n has sub-diagonal 1.
    a[0] = 2.;
    for (int i = 1; i <= n; ++ i) {
        double sub = (i == n) ? 1. : .5;   // A[i][i-1]
        double sup = (i == 1) ? 1. : .5;   // A[i-1][i]
        double rat(sub / a[i - 1]);
        a[i] = 2. - sup * rat;
        d[i] -= d[i - 1] * rat;
    }
    m[n] = d[n] / a[n];
    for (int i = n - 1; i >= 0; -- i) {
        double sup = (i == 0) ? 1. : .5;   // A[i][i+1]
        m[i] = (d[i] - m[i + 1] * sup) / a[i];
    }
    for (int i = 0; i < n; ++ i) {
        double a0(m[i] / 6. / h);
        double a1(m[i + 1] / 6. / h);
        double b0(y[i] / h - m[i] * h / 6.);
        double b1(y[i + 1] / h - m[i + 1] * h / 6.);
        // Define JS_OUTPUT or PYTHON_OUTPUT before this point to switch format.
#ifdef JS_OUTPUT
        /* JS format output */
        printf("+ ((x >= %.12lf && x < %.12lf) ? (", x[i], x[i + 1]);
        printf(" %+.12lf*pow(%.12lf - x, 3)", a0, x[i + 1]);
        printf(" %+.12lf*pow(x %+.12lf, 3)", a1, -x[i]);
        printf(" %+.12lf*(%.12lf - x)", b0, x[i + 1]);
        printf(" %+.12lf*(x %+.12lf)", b1, -x[i]);
        printf(") : 0)");
#else
#ifdef PYTHON_OUTPUT
        /* Python format output */
        printf("+ ((");
        printf(" %+.12lf*(%.12lf - x)**3", a0, x[i + 1]);
        printf(" %+.12lf*(x %+.12lf)**3", a1, -x[i]);
        printf(" %+.12lf*(%.12lf - x)", b0, x[i + 1]);
        printf(" %+.12lf*(x %+.12lf)", b1, -x[i]);
        printf(") if x >= %.12lf and x < %.12lf else 0)", x[i], x[i + 1]);
#else
        /* LaTeX format output */
        printf(" %.3lf(%.1lf - x)^3", a0, x[i + 1]);
        printf(" %+.3lf(x %+.1lf)^3", a1, -x[i]);
        printf(" %+.3lf(%.1lf - x)", b0, x[i + 1]);
        printf(" %+.3lf(x %+.1lf)", b1, -x[i]);
        printf("& \\texttt{for } x \\in [%.1lf, %.1lf) \\\\\n", x[i], x[i + 1]);
#endif
#endif
    }
    putchar(10);
    delete [] x;
    delete [] y;
    delete [] d;
    delete [] m;
    delete [] a;
}
多项式最小二乘就是解法方程组: 系数矩阵是基函数 $x^i$ 两两之间的离散内积 $\sum_k x_k^{i+j}$, 右端是基函数与函数值的内积 $\sum_k y_k x_k^i$. 还是搞个高斯消元就完了.
// Least-squares fit of a degree-n polynomial P(x) = sum_i s_i x^i to the m
// samples (x[k], y[k]).
//
// Forms the normal equations G s = b with G[i][j] = sum_k x_k^(i+j) and
// b[i] = sum_k y_k x_k^i (i, j = 0..n). Solves them by Gaussian elimination
// without pivoting and prints s_0..s_n.
//
// Unless PYTHON_FORMAT is defined, it then prints a LaTeX table row: "$P_{n}$",
// P(x_k) for every sample, and the accumulated sqr(P(x_k) - y_k) over all
// samples (sqr is defined elsewhere in the file).
//
// n: polynomial degree; nothing is printed when n < 0.
// m: number of samples; x and y must each hold at least m values.
// G is singular (division by zero below) unless the samples contain at
// least n + 1 distinct x values.
void minSqr(int n, int m, double* x, double* y) {
    if (n < 0) {
        return;     // would otherwise index s[-1] / a[-1]
    }
    ++ n;           // from here on n is the number of unknowns (degree + 1)
    double** a = new double*[n];
    double* s = new double[n];
    for (int i = 0; i < n; ++ i) {
        a[i] = new double[n + 1];   // column n holds the right-hand side b
    }
    for (int i = 0; i < n; ++ i) {
        for (int j = 0; j < n; ++ j) {
            a[i][j] = 0;
            for (int k = 0; k < m; ++ k) {
                a[i][j] += pow(x[k], i + j);
            }
        }
        a[i][n] = 0;
        for (int k = 0; k < m; ++ k) {
            a[i][n] += y[k] * pow(x[k], i);
        }
    }
    // Forward elimination: the multiplier that zeroes entry (j, i) is
    // a[j][i] / a[i][i] — taken from row j's own column i, so it does not
    // rely on the matrix staying exactly symmetric under rounding.
    for (int i = 0; i < n; ++ i) {
        for (int j = i + 1; j < n; ++ j) {
            double rat(a[j][i] / a[i][i]);
            for (int k = i; k <= n; ++ k) {
                a[j][k] -= a[i][k] * rat;
            }
        }
    }
    // Back substitution on the resulting upper-triangular system.
    for (int i = n - 1; i >= 0; -- i) {
        s[i] = a[i][n];
        for (int k = i + 1; k < n; ++ k) {
            s[i] -= a[i][k] * s[k];
        }
        s[i] /= a[i][i];
    }
    for (int i = 0; i < n; ++ i) {
#ifdef PYTHON_FORMAT
        printf("%+.12lf * x**%d ", s[i], i);
#else
        printf("%+.3lf * x^%d ", s[i], i);
#endif
    }
    putchar(10);
    double mn(0);   // accumulated squared residual
#ifndef PYTHON_FORMAT
    printf("$P_{%d}$ & ", n - 1);
    for (int i = 0; i < m; ++ i) {
        double p(0);
        for (int j = 0; j < n; ++ j) {
            p += pow(x[i], j) * s[j];
        }
        printf("$%.3lf$ & ", p);
        mn += sqr(p - y[i]);
    }
    printf("$%.3lf$ \\tabularnewline \\hline \n", mn);
#endif
    for (int i = 0; i < n; ++ i) {
        delete [] a[i];
    }
    delete [] a;
    delete [] s;
}