Squiz Matrix  4.12.2
 All Data Structures Namespaces Functions Variables Pages
CurveFitter.java
1 package ij.measure;
2 import ij.*;
3 import ij.gui.*;
4 
22 public class CurveFitter {
/** Fit type codes accepted by doFit(); each value is also the index into fitList and fList. */
public static final int STRAIGHT_LINE = 0;
public static final int POLY2 = 1;
public static final int POLY3 = 2;
public static final int POLY4 = 3;
public static final int EXPONENTIAL = 4;
public static final int POWER = 5;
public static final int LOG = 6;
public static final int RODBARD = 7;
public static final int GAMMA_VARIATE = 8;
public static final int LOG2 = 9;

/** Factor of the default iteration limit, which is IterFactor * numParams * numParams. */
public static final int IterFactor = 500;

/** Display names of the fit types, indexed by fit type code. */
public static final String[] fitList = {
    "Straight Line", "2nd Degree Polynomial",
    "3rd Degree Polynomial", "4th Degree Polynomial", "Exponential", "Power",
    "log", "Rodbard", "Gamma Variate", "y = a+b*ln(x-c)"
};

/**
 * Formulas of the fit types, indexed by fit type code. The letters a, b, c, ...
 * stand for elements 0, 1, 2, ... of the parameter array, which is how f()
 * evaluates them and how getResultString() labels them. The gamma variate
 * entry is written accordingly: f() uses element 0 as the x offset and
 * element 1 as the scale factor.
 */
public static final String[] fList = {
    "y = a+bx", "y = a+bx+cx^2",
    "y = a+bx+cx^2+dx^3", "y = a+bx+cx^2+dx^3+ex^4", "y = a*exp(bx)", "y = ax^b",
    "y = a*ln(bx)", "y = d+(a-d)/(1+(x/c)^b)", "y = b*(x-a)^c*exp(-(x-a)/d)", "y = a+b*ln(x-c)"
};

private static final double alpha = -1.0;     // reflection coefficient
private static final double beta = 0.5;       // contraction coefficient
private static final double gamma = 2.0;      // expansion coefficient
private static final double root2 = 1.414214; // square root of 2

private int fit;                // code of the curve type to fit
private double[] xData, yData;  // x,y data to fit
private int numPoints;          // number of data points
private int numParams;          // number of parameters
private int numVertices;        // numParams+1 (each vertex also stores its sum of squared residuals)
private int worst;              // index of the vertex with the worst current parameter estimates
private int nextWorst;          // index of the vertex with the 2nd worst current parameter estimates
private int best;               // index of the vertex with the best current parameter estimates
private double[][] simp;        // the simplex (the last element of each vertex is the sum of the squared residuals)
private double[] next;          // new vertex to be tested
private int numIter;            // number of iterations so far
private int maxIter;            // maximum number of iterations per restart
private int restarts;           // number of times to restart the simplex after the first solution
private double maxError;        // maximum error tolerance
55 
/**
 * Creates a fitter for the given data points. The arrays are stored by
 * reference, not copied.
 *
 * @param xData x values of the data points; its length defines the number of points
 * @param yData y values; must provide at least one value per x value
 * @throws IllegalArgumentException if either array is null or yData is shorter
 *         than xData (previously this surfaced later as a NullPointerException
 *         or ArrayIndexOutOfBoundsException while fitting)
 */
public CurveFitter (double[] xData, double[] yData) {
    if (xData == null || yData == null)
        throw new IllegalArgumentException("xData and yData must not be null");
    if (yData.length < xData.length)
        throw new IllegalArgumentException("yData has " + yData.length
            + " values but xData has " + xData.length);
    this.xData = xData;
    this.yData = yData;
    numPoints = xData.length;
}
62 
70  public void doFit(int fitType) {
71  doFit(fitType, false);
72  }
73 
74  public void doFit(int fitType, boolean showSettings) {
75  if (fitType < STRAIGHT_LINE || fitType > LOG2)
76  throw new IllegalArgumentException("Invalid fit type");
77  fit = fitType;
78  initialize();
79  if (showSettings) settingsDialog();
80  restart(0);
81 
82  numIter = 0;
83  boolean done = false;
84  double[] center = new double[numParams]; // mean of simplex vertices
85  while (!done) {
86  numIter++;
87  for (int i = 0; i < numParams; i++) center[i] = 0.0;
88  // get mean "center" of vertices, excluding worst
89  for (int i = 0; i < numVertices; i++)
90  if (i != worst)
91  for (int j = 0; j < numParams; j++)
92  center[j] += simp[i][j];
93  // Reflect worst vertex through centre
94  for (int i = 0; i < numParams; i++) {
95  center[i] /= numParams;
96  next[i] = center[i] + alpha*(simp[worst][i] - center[i]);
97  }
98  sumResiduals(next);
99  // if it's better than the best...
100  if (next[numParams] <= simp[best][numParams]) {
101  newVertex();
102  // try expanding it
103  for (int i = 0; i < numParams; i++)
104  next[i] = center[i] + gamma * (simp[worst][i] - center[i]);
105  sumResiduals(next);
106  // if this is even better, keep it
107  if (next[numParams] <= simp[worst][numParams])
108  newVertex();
109  }
110  // else if better than the 2nd worst keep it...
111  else if (next[numParams] <= simp[nextWorst][numParams]) {
112  newVertex();
113  }
114  // else try to make positive contraction of the worst
115  else {
116  for (int i = 0; i < numParams; i++)
117  next[i] = center[i] + beta*(simp[worst][i] - center[i]);
118  sumResiduals(next);
119  // if this is better than the second worst, keep it.
120  if (next[numParams] <= simp[nextWorst][numParams]) {
121  newVertex();
122  }
123  // if all else fails, contract simplex in on best
124  else {
125  for (int i = 0; i < numVertices; i++) {
126  if (i != best) {
127  for (int j = 0; j < numVertices; j++)
128  simp[i][j] = beta*(simp[i][j]+simp[best][j]);
129  sumResiduals(simp[i]);
130  }
131  }
132  }
133  }
134  order();
135 
136  double rtol = 2 * Math.abs(simp[best][numParams] - simp[worst][numParams]) /
137  (Math.abs(simp[best][numParams]) + Math.abs(simp[worst][numParams]) + 0.0000000001);
138 
139  if (numIter >= maxIter) done = true;
140  else if (rtol < maxError) {
141  //System.out.print(getResultString());
142  restarts--;
143  if (restarts < 0) {
144  done = true;
145  }
146  else {
147  restart(best);
148  }
149  }
150  }
151  }
152 
155  void initialize() {
156  // Calculate some things that might be useful for predicting parametres
157  numParams = getNumParams();
158  numVertices = numParams + 1; // need 1 more vertice than parametres,
159  simp = new double[numVertices][numVertices];
160  next = new double[numVertices];
161 
162  double firstx = xData[0];
163  double firsty = yData[0];
164  double lastx = xData[numPoints-1];
165  double lasty = yData[numPoints-1];
166  double xmean = (firstx+lastx)/2.0;
167  double ymean = (firsty+lasty)/2.0;
168  double slope;
169  if ((lastx - firstx) != 0.0)
170  slope = (lasty - firsty)/(lastx - firstx);
171  else
172  slope = 1.0;
173  double yintercept = firsty - slope * firstx;
174  maxIter = IterFactor * numParams * numParams; // Where does this estimate come from?
175  restarts = 1;
176  maxError = 1e-9;
177  switch (fit) {
178  case STRAIGHT_LINE:
179  simp[0][0] = yintercept;
180  simp[0][1] = slope;
181  break;
182  case POLY2:
183  simp[0][0] = yintercept;
184  simp[0][1] = slope;
185  simp[0][2] = 0.0;
186  break;
187  case POLY3:
188  simp[0][0] = yintercept;
189  simp[0][1] = slope;
190  simp[0][2] = 0.0;
191  simp[0][3] = 0.0;
192  break;
193  case POLY4:
194  simp[0][0] = yintercept;
195  simp[0][1] = slope;
196  simp[0][2] = 0.0;
197  simp[0][3] = 0.0;
198  simp[0][4] = 0.0;
199  break;
200  case EXPONENTIAL:
201  simp[0][0] = 0.1;
202  simp[0][1] = 0.01;
203  break;
204  case POWER:
205  simp[0][0] = 0.0;
206  simp[0][1] = 1.0;
207  break;
208  case LOG:
209  simp[0][0] = 0.5;
210  simp[0][1] = 0.05;
211  break;
212  case RODBARD:
213  simp[0][0] = firsty;
214  simp[0][1] = 1.0;
215  simp[0][2] = xmean;
216  simp[0][3] = lasty;
217  break;
218  case GAMMA_VARIATE:
219  // First guesses based on following observations:
220  // t0 [b] = time of first rise in gamma curve - so use the user specified first limit
221  // tm = t0 + a*B [c*d] where tm is the time of the peak of the curve
222  // therefore an estimate for a and B is sqrt(tm-t0)
223  // K [a] can now be calculated from these estimates
224  simp[0][0] = firstx;
225  double ab = xData[getMax(yData)] - firstx;
226  simp[0][2] = Math.sqrt(ab);
227  simp[0][3] = Math.sqrt(ab);
228  simp[0][1] = yData[getMax(yData)] / (Math.pow(ab, simp[0][2]) * Math.exp(-ab/simp[0][3]));
229  break;
230  case LOG2:
231  simp[0][0] = 0.5;
232  simp[0][1] = 0.05;
233  simp[0][2] = 0.0;
234  break;
235  }
236  }
237 
239  private void settingsDialog() {
240  GenericDialog gd = new GenericDialog("Simplex Fitting Options", IJ.getInstance());
241  gd.addMessage("Function name: " + fitList[fit] + "\n" +
242  "Formula: " + fList[fit]);
243  char pChar = 'a';
244  for (int i = 0; i < numParams; i++) {
245  gd.addNumericField("Initial "+(new Character(pChar)).toString()+":", simp[0][i], 2);
246  pChar++;
247  }
248  gd.addNumericField("Maximum iterations:", maxIter, 0);
249  gd.addNumericField("Number of restarts:", restarts, 0);
250  gd.addNumericField("Error tolerance [1*10^(-x)]:", -(Math.log(maxError)/Math.log(10)), 0);
251  gd.showDialog();
252  if (gd.wasCanceled() || gd.invalidNumber()) {
253  IJ.error("Parameter setting canceled.\nUsing default parameters.");
254  }
255  // Parametres:
256  for (int i = 0; i < numParams; i++) {
257  simp[0][i] = gd.getNextNumber();
258  }
259  maxIter = (int) gd.getNextNumber();
260  restarts = (int) gd.getNextNumber();
261  maxError = Math.pow(10.0, -gd.getNextNumber());
262  }
263 
265  void restart(int n) {
266  // Copy nth vertice of simplex to first vertice
267  for (int i = 0; i < numParams; i++) {
268  simp[0][i] = simp[n][i];
269  }
270  sumResiduals(simp[0]); // Get sum of residuals^2 for first vertex
271  double[] step = new double[numParams];
272  for (int i = 0; i < numParams; i++) {
273  step[i] = simp[0][i] / 2.0; // Step half the parametre value
274  if (step[i] == 0.0) // We can't have them all the same or we're going nowhere
275  step[i] = 0.01;
276  }
277  // Some kind of factor for generating new vertices
278  double[] p = new double[numParams];
279  double[] q = new double[numParams];
280  for (int i = 0; i < numParams; i++) {
281  p[i] = step[i] * (Math.sqrt(numVertices) + numParams - 1.0)/(numParams * root2);
282  q[i] = step[i] * (Math.sqrt(numVertices) - 1.0)/(numParams * root2);
283  }
284  // Create the other simplex vertices by modifing previous one.
285  for (int i = 1; i < numVertices; i++) {
286  for (int j = 0; j < numParams; j++) {
287  simp[i][j] = simp[i-1][j] + q[j];
288  }
289  simp[i][i-1] = simp[i][i-1] + p[i-1];
290  sumResiduals(simp[i]);
291  }
292  // Initialise current lowest/highest parametre estimates to simplex 1
293  best = 0;
294  worst = 0;
295  nextWorst = 0;
296  order();
297  }
298 
299  // Display simplex [Iteration: s0(p1, p2....), s1(),....] in ImageJ window
300  void showSimplex(int iter) {
301  ij.IJ.write("" + iter);
302  for (int i = 0; i < numVertices; i++) {
303  String s = "";
304  for (int j=0; j < numVertices; j++)
305  s += " "+ ij.IJ.d2s(simp[i][j], 6);
306  ij.IJ.write(s);
307  }
308  }
309 
311  public int getNumParams() {
312  switch (fit) {
313  case STRAIGHT_LINE: return 2;
314  case POLY2: return 3;
315  case POLY3: return 4;
316  case POLY4: return 5;
317  case EXPONENTIAL: return 2;
318  case POWER: return 2;
319  case LOG: return 2;
320  case RODBARD: return 4;
321  case GAMMA_VARIATE: return 4;
322  case LOG2: return 3;
323  }
324  return 0;
325  }
326 
328  public static double f(int fit, double[] p, double x) {
329  switch (fit) {
330  case STRAIGHT_LINE:
331  return p[0] + p[1]*x;
332  case POLY2:
333  return p[0] + p[1]*x + p[2]* x*x;
334  case POLY3:
335  return p[0] + p[1]*x + p[2]*x*x + p[3]*x*x*x;
336  case POLY4:
337  return p[0] + p[1]*x + p[2]*x*x + p[3]*x*x*x + p[4]*x*x*x*x;
338  case EXPONENTIAL:
339  return p[0]*Math.exp(p[1]*x);
340  case POWER:
341  if (x == 0.0)
342  return 0.0;
343  else
344  return p[0]*Math.exp(p[1]*Math.log(x)); //y=ax^b
345  case LOG:
346  if (x == 0.0)
347  x = 0.5;
348  return p[0]*Math.log(p[1]*x);
349  case RODBARD:
350  double ex;
351  if (x == 0.0)
352  ex = 0.0;
353  else
354  ex = Math.exp(Math.log(x/p[2])*p[1]);
355  double y = p[0]-p[3];
356  y = y/(1.0+ex);
357  return y+p[3];
358  case GAMMA_VARIATE:
359  if (p[0] >= x) return 0.0;
360  if (p[1] <= 0) return -100000.0;
361  if (p[2] <= 0) return -100000.0;
362  if (p[3] <= 0) return -100000.0;
363 
364  double pw = Math.pow((x - p[0]), p[2]);
365  double e = Math.exp((-(x - p[0]))/p[3]);
366  return p[1]*pw*e;
367  case LOG2:
368  double tmp = x-p[2];
369  if (tmp<0.001) tmp = 0.001;
370  return p[0]+p[1]*Math.log(tmp);
371  default:
372  return 0.0;
373  }
374  }
375 
377  public double[] getParams() {
378  order();
379  return simp[best];
380  }
381 
383  public double[] getResiduals() {
384  double[] params = getParams();
385  double[] residuals = new double[numPoints];
386  for (int i = 0; i < numPoints; i++)
387  residuals[i] = yData[i] - f(fit, params, xData[i]);
388  return residuals;
389  }
390 
391  /* Last "parametre" at each vertex of simplex is sum of residuals
392  * for the curve described by that vertex
393  */
394  public double getSumResidualsSqr() {
395  double sumResidualsSqr = (getParams())[getNumParams()];
396  return sumResidualsSqr;
397  }
398 
401  public double getSD() {
402  double sd = Math.sqrt(getSumResidualsSqr() / numVertices);
403  return sd;
404  }
405 
409  public double getFitGoodness() {
410  double sumY = 0.0;
411  for (int i = 0; i < numPoints; i++) sumY += yData[i];
412  double mean = sumY / numVertices;
413  double sumMeanDiffSqr = 0.0;
414  int degreesOfFreedom = numPoints - getNumParams();
415  double fitGoodness = 0.0;
416  for (int i = 0; i < numPoints; i++) {
417  sumMeanDiffSqr += sqr(yData[i] - mean);
418  }
419  if (sumMeanDiffSqr > 0.0 && degreesOfFreedom != 0)
420  fitGoodness = 1.0 - (getSumResidualsSqr() / degreesOfFreedom) * ((numParams) / sumMeanDiffSqr);
421 
422  return fitGoodness;
423  }
424 
428  public String getResultString() {
429  StringBuffer results = new StringBuffer("\nNumber of iterations: " + getIterations() +
430  "\nMaximum number of iterations: " + getMaxIterations() +
431  "\nSum of residuals squared: " + getSumResidualsSqr() +
432  "\nStandard deviation: " + getSD() +
433  "\nGoodness of fit: " + getFitGoodness() +
434  "\nParameters:");
435  char pChar = 'a';
436  double[] pVal = getParams();
437  for (int i = 0; i < numParams; i++) {
438  results.append("\n" + pChar + " = " + pVal[i]);
439  pChar++;
440  }
441  return results.toString();
442  }
443 
444  double sqr(double d) { return d * d; }
445 
447  void sumResiduals (double[] x) {
448  x[numParams] = 0.0;
449  for (int i = 0; i < numPoints; i++) {
450  x[numParams] = x[numParams] + sqr(f(fit,x,xData[i])-yData[i]);
451  // if (IJ.debugMode) ij.IJ.log(i+" "+x[n-1]+" "+f(fit,x,xData[i])+" "+yData[i]);
452  }
453  }
454 
456  void newVertex() {
457  for (int i = 0; i < numVertices; i++)
458  simp[worst][i] = next[i];
459  }
460 
462  void order() {
463  for (int i = 0; i < numVertices; i++) {
464  if (simp[i][numParams] < simp[best][numParams]) best = i;
465  if (simp[i][numParams] > simp[worst][numParams]) worst = i;
466  }
467  nextWorst = best;
468  for (int i = 0; i < numVertices; i++) {
469  if (i != worst) {
470  if (simp[i][numParams] > simp[nextWorst][numParams]) nextWorst = i;
471  }
472  }
473  // IJ.write("B: " + simp[best][numParams] + " 2ndW: " + simp[nextWorst][numParams] + " W: " + simp[worst][numParams]);
474  }
475 
477  public int getIterations() {
478  return numIter;
479  }
480 
482  public int getMaxIterations() {
483  return maxIter;
484  }
485 
487  public void setMaxIterations(int x) {
488  maxIter = x;
489  }
490 
492  public int getRestarts() {
493  return restarts;
494  }
495 
497  public void setRestarts(int x) {
498  restarts = x;
499  }
500 
507  public static int getMax(double[] array) {
508  double max = array[0];
509  int index = 0;
510  for(int i = 1; i < array.length; i++) {
511  if(max < array[i]) {
512  max = array[i];
513  index = i;
514  }
515  }
516  return index;
517  }
518 
519 }