看 paper 看烦了，来贴点代码玩玩。

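Lagrange 插值本质上就是多项式乘法：对每个节点 $x_i$，把 $\prod_{j \ne i} (x - x_j)$ 逐个因子展开，再乘以 $y_i / \prod_{j \ne i} (x_i - x_j)$ 后累加到系数上。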


/*
 * Build the Lagrange interpolating polynomial of f on the n + 1 equally
 * spaced nodes x_i = -5 + 10 i / n (i = 0..n) and print its coefficients in
 * ascending-power order (a_0 first).
 *
 * For each node i the numerator prod_{j != i} (x - x_j) is expanded into
 * tmp[] one linear factor at a time, scaled by y_i / prod_{j != i} (x_i - x_j),
 * and accumulated into a[].
 *
 * The output format is chosen at compile time: JS_OUTPUT, PYTHON_OUTPUT, or
 * LaTeX (the fallback).
 *
 * n: polynomial degree (number of nodes minus one). Values below 1 print
 *    nothing, because equal spacing of the nodes needs at least two of them.
 */
void genLagrangePoly(int n) {
	// n == 0 would evaluate 10. / 0 (NaN nodes); n < -1 would request a
	// negative array size.
	if (n < 1) {
		return;
	}
	double *x = new double[n + 1];
	double *y = new double[n + 1];
	double *a = new double[n + 1];
	double *tmp = new double[n + 1];
	for (int i = 0; i <= n; ++ i) {
		x[i] = -5. + 10. / n * i;
		y[i] = f(x[i]);
		a[i] = 0;
	}
	for (int i = 0; i <= n; ++ i) {
		// tmp starts as the constant polynomial 1.
		memset(tmp, 0, sizeof(double) * (n + 1));
		tmp[0] = 1.;
		double s(y[i]);
		for (int j = 0; j <= n; ++ j) {
			if (j == i) {
				continue;
			}
			s /= x[i] - x[j];
			// Multiply tmp by (x - x[j]) in place, highest coefficient first so
			// each tmp[k] is read before it is overwritten. At most j factors
			// have been applied so far, so entries above index j are still
			// zero; capping at n - 1 keeps tmp[k + 1] in bounds.
			int top = j < n - 1 ? j : n - 1;
			for (int k = top; k >= 0; -- k) {
				tmp[k + 1] = tmp[k] - tmp[k + 1] * x[j];
			}
			tmp[0] *= -x[j];
		}
		for (int j = 0; j <= n; ++ j) {
			a[j] += tmp[j] * s;
		}
	}
	// NOTE: nothing in this function tests LATEX_OUTPUT (LaTeX is already the
	// fallback branch below). The define persists for the rest of the
	// translation unit, so it is kept in case later code checks it.
#define LATEX_OUTPUT
	for (int i = 0; i <= n; ++ i) {
#if defined(JS_OUTPUT)
		/* JS format output */
		printf(" %+.12lf*pow(x, %d)", a[i], i);
#elif defined(PYTHON_OUTPUT)
		/* Python format output */
		printf(" %+.12lf* x ** %d", a[i], i);
#else
		/* LaTeX format output */
		printf(" %+lf* x^{%d}", a[i], i);
#endif
	}
	putchar('\n');
	delete [] x;
	delete [] y;
	delete [] a;
	delete [] tmp;
}

三次样条插值要解的是各节点处的二阶导数 $M_i$：它们满足一个三对角线性方程组，随手高斯消元（追赶法）一下就行。


/*
 * Build the clamped cubic spline of f on the n + 1 equally spaced nodes
 * x_i = -5 + 10 i / n (i = 0..n) and print it piece by piece.
 *
 * The unknowns are the second derivatives M_i = S''(x_i). With uniform step h
 * they satisfy the tridiagonal system
 *       2 M_0 +   M_1                 = d_0
 *     1/2 M_{i-1} + 2 M_i + 1/2 M_{i+1} = d_i      (0 < i < n)
 *                 M_{n-1} + 2 M_n     = d_n
 * with d_0 = 6/h (f[x_0,x_1] - f'(x_0)), d_n = 6/h (f'(x_n) - f[x_{n-1},x_n])
 * and d_i = 6 (f[x_i,x_{i+1}] - f[x_{i-1},x_i]) / (2h). It is solved by
 * forward elimination followed by back substitution.
 *
 * On [x_i, x_{i+1}] the spline is
 *     M_i (x_{i+1} - x)^3 / (6h) + M_{i+1} (x - x_i)^3 / (6h)
 *   + (y_i / h - M_i h / 6) (x_{i+1} - x) + (y_{i+1} / h - M_{i+1} h / 6) (x - x_i).
 *
 * Output format: JS_OUTPUT, PYTHON_OUTPUT, or LaTeX (the fallback).
 * n: number of intervals; values below 1 print nothing.
 *
 * NOTE(review): adf(a, b) is assumed to be the divided difference
 * (f(b) - f(a)) / (b - a) and df the derivative of f; confirm where they are
 * defined.
 */
void genTriSmpPoly(int n) {
	// With fewer than one interval there is no x[1] to read.
	if (n < 1) {
		return;
	}
	double h = 10. / n;
	double *x = new double[n + 1];
	double *y = new double[n + 1];
	double *d = new double[n + 1];
	double *m = new double[n + 1];
	double *a = new double[n + 1];
	for (int i = 0; i <= n; ++ i) {
		x[i] = -5. + 10. / n * i;
		y[i] = f(x[i]);
	}
	d[0] = 6. / h * (adf(x[0], x[1]) - df(x[0]));
	d[n] = 6. / h * (df(x[n]) - adf(x[n - 1], x[n]));
	for (int i = 1; i < n; ++ i) {
		d[i] = 6. * (adf(x[i], x[i + 1]) - adf(x[i], x[i - 1])) / (2 * h);
	}
	// Forward elimination. Row i has sub-diagonal coefficient 1 when i == n
	// (else 1/2) and super-diagonal coefficient 1 when i == 0 (else 1/2);
	// the diagonal is always 2. a[i] holds the eliminated diagonal.
	a[0] = 2.;
	for (int i = 1; i <= n; ++ i) {
		double lower = (i == n) ? 1. : .5;
		double upperPrev = (i == 1) ? 1. : .5;
		double rat(lower / a[i - 1]);
		a[i] = 2. - upperPrev * rat;
		d[i] -= d[i - 1] * rat;
	}
	// Back substitution.
	m[n] = d[n] / a[n];
	for (int i = n - 1; i >= 0; -- i) {
		double upper = (i == 0) ? 1. : .5;
		m[i] = (d[i] - m[i + 1] * upper) / a[i];
	}
	for (int i = 0; i < n; ++ i) {
		double a0(m[i] / 6. / h);
		double a1(m[i + 1] / 6. / h);
		double b0(y[i] / h - m[i] * h / 6.);
		double b1(y[i + 1] / h - m[i + 1] * h / 6.);
		// The last piece also covers the right end node x_n; otherwise the
		// generated expression would evaluate to 0 at x = x_n.
		bool last = (i == n - 1);
		// Select the format with -DJS_OUTPUT or -DPYTHON_OUTPUT; LaTeX is the default.
#if defined(JS_OUTPUT)
		/* JS format output */
		printf("+ ((x >= %.12lf && x %s %.12lf) ? (", x[i], last ? "<=" : "<", x[i + 1]);
		printf(" %+.12lf*pow(%.12lf - x, 3)", a0, x[i + 1]);
		printf(" %+.12lf*pow(x %+.12lf, 3)", a1, -x[i]);
		printf(" %+.12lf*(%.12lf - x)", b0, x[i + 1]);
		printf(" %+.12lf*(x %+.12lf)", b1, -x[i]);
		printf(") : 0)");
#elif defined(PYTHON_OUTPUT)
		/* Python format output */
		printf("+ ((");
		printf(" %+.12lf*(%.12lf - x)**3", a0, x[i + 1]);
		printf(" %+.12lf*(x %+.12lf)**3", a1, -x[i]);
		printf(" %+.12lf*(%.12lf - x)", b0, x[i + 1]);
		printf(" %+.12lf*(x %+.12lf)", b1, -x[i]);
		printf(") if x >= %.12lf and x %s %.12lf else 0)", x[i], last ? "<=" : "<", x[i + 1]);
#else
		/* LaTeX format output */
		printf(" %.3lf(%.1lf - x)^3", a0, x[i + 1]);
		printf(" %+.3lf(x %+.1lf)^3", a1, -x[i]);
		printf(" %+.3lf(%.1lf - x)", b0, x[i + 1]);
		printf(" %+.3lf(x %+.1lf)", b1, -x[i]);
		printf("& \\texttt{for } x \\in [%.1lf, %.1lf%c \\\\\n", x[i], x[i + 1], last ? ']' : ')');
#endif
	}
	putchar('\n');
	delete [] x;
	delete [] y;
	delete [] d;
	delete [] m;
	delete [] a;
}

多项式最小二乘是解正规方程组：系数矩阵是基函数 $x^i$ 两两的离散内积 $\sum_k x_k^{i+j}$，右端是函数值与基函数的内积 $\sum_k y_k x_k^i$。还是搞个高斯消元就完了。


/*
 * Least-squares fit of a degree-n polynomial to the m points (x[k], y[k]).
 *
 * Builds the normal equations
 *     sum_j (sum_k x_k^(i+j)) c_j = sum_k y_k x_k^i,   i = 0..n,
 * solves them by Gaussian elimination without pivoting, and prints the
 * coefficients c_0..c_n. In the non-Python format it then prints a LaTeX
 * table row "$P_{n}$ & p(x_0) & ... & p(x_{m-1}) & SSE \tabularnewline",
 * where SSE is the sum of squared residuals.
 *
 * The Python format is selected by PYTHON_FORMAT or, for consistency with
 * the other generators in this file, PYTHON_OUTPUT.
 *
 * n: polynomial degree; negative values print nothing.
 * m: number of data points; values below 1 print nothing.
 * x, y: arrays of at least m values; only read.
 *
 * NOTE(review): the system is singular when there are fewer than n + 1
 * distinct x values, and the normal matrix grows badly conditioned with n.
 */
void minSqr(int n, int m, double* x, double* y) {
	if (n < 0 || m < 1) {
		return;
	}
	++ n;	// from here on n is the number of coefficients
	double** a = new double*[n];
	double* s = new double[n];
	// Power sums pw[p] = sum_k x_k^p for p = 0..2n-2; the normal matrix entry
	// a[i][j] is pw[i + j], so each sum is computed once instead of n times.
	double* pw = new double[2 * n - 1];
	for (int p = 0; p < 2 * n - 1; ++ p) {
		pw[p] = 0;
		for (int k = 0; k < m; ++ k) {
			pw[p] += pow(x[k], p);
		}
	}
	// Augmented matrix [A | b], one heap row per equation.
	for (int i = 0; i < n; ++ i) {
		a[i] = new double[n + 1];
		for (int j = 0; j < n; ++ j) {
			a[i][j] = pw[i + j];
		}
		a[i][n] = 0;
		for (int k = 0; k < m; ++ k) {
			a[i][n] += y[k] * pow(x[k], i);
		}
	}
	delete [] pw;
	// Forward elimination: clear column i below the diagonal. The multiplier
	// must use a[j][i], the entry being eliminated.
	for (int i = 0; i < n; ++ i) {
		for (int j = i + 1; j < n; ++ j) {
			double rat(a[j][i] / a[i][i]);
			for (int k = i; k <= n; ++ k) {
				a[j][k] -= a[i][k] * rat;
			}
		}
	}
	// Back substitution.
	for (int i = n - 1; i >= 0; -- i) {
		s[i] = a[i][n];
		for (int k = i + 1; k < n; ++ k) {
			s[i] -= a[i][k] * s[k];
		}
		s[i] /= a[i][i];
	}
	for (int i = 0; i < n; ++ i) {
#if defined(PYTHON_FORMAT) || defined(PYTHON_OUTPUT)
		printf("%+.12lf * x**%d ", s[i], i);
#else
		printf("%+.3lf * x^%d ", s[i], i);
#endif
	}
	putchar('\n');
#if !defined(PYTHON_FORMAT) && !defined(PYTHON_OUTPUT)
	// Fitted values at each sample and the accumulated squared residual.
	double mn(0);
	printf("$P_{%d}$ & ", n - 1);
	for (int i = 0; i < m; ++ i) {
		double p(0);
		for (int j = 0; j < n; ++ j) {
			p += pow(x[i], j) * s[j];
		}
		printf("$%.3lf$ & ", p);
		mn += sqr(p - y[i]);
	}
	printf("$%.3lf$ \\tabularnewline \\hline \n", mn);
#endif
	for (int i = 0; i < n; ++ i) {
		delete [] a[i];
	}
	delete [] a;
	delete [] s;
}