Polynomial Fitting
Just as with the linear fit and the transit model, we can fit our data with a polynomial model. The difference from the linear fit tutorial is that here we generate a slightly different polynomial for each wavelength and then check how well the fit recovers the parameters.
```python
from chromatic_fitting import *
from pymc3 import Normal, Uniform
plt.matplotlib.style.use('default')
```
First we'll create a simulated Rainbow object with chromatic and then inject a wavelength-dependent polynomial signal into its flux:
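```python
# create a simulated Rainbow with 1-minute cadence, spectral resolution R=50,
# and noise at a signal-to-noise of 20
r = SimulatedRainbow(dt=1 * u.minute, R=50).inject_noise(signal_to_noise=20)

# bin in wavelength and into 5-minute time bins
rb = r.bin(nwavelengths=int(r.nwave/5), dt=5 * u.minute)

# inject a wavelength-dependent cubic polynomial, a + b*i*x + c*i*x**2 + d*i*x**3,
# where i is the wavelength index and x is the time in days
a = 0.0
b = 0.05
c = 0.0
d = 5.0
x = rb.time.to_value("day")
true_a, true_b, true_c, true_d, poly = [], [], [], [], []
for i in range(rb.nwave):
    true_a.append(a + 1)  # the baseline flux of 1 is already part of rb.flux
    true_b.append(b*i)
    true_c.append(c*i)
    true_d.append(d*i)
    poly.append(a + (d*i*(x**3)) + (c*i*(x**2)) + (b*i*x))
rb.fluxlike['flux'] = rb.flux + np.array(poly)

# plot our Rainbow to see how it looks
rb.imshow_quantities();
```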
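In polynomial terms, the loop above gives wavelength bin i the coefficients [a, b*i, c*i, d*i] for powers 0 to 3 of time in days. These are the values the fit should recover, with p_0 close to 1 because the baseline flux is already in rb.flux. The quick standalone check below uses a hypothetical time grid (not rb.time) and confirms that NumPy's `numpy.polynomial.polynomial.polyval`, which takes coefficients lowest power first, reproduces the injected signal:

```python
import numpy as np
from numpy.polynomial import polynomial as P

# Hypothetical time grid in days (the tutorial itself uses rb.time).
x = np.linspace(-0.1, 0.1, 61)

# Injection parameters from the cell above.
a, b, c, d = 0.0, 0.05, 0.0, 5.0
nwave = 5  # number of wavelength bins in this tutorial (w0 to w4)

for i in range(nwave):
    # Signal added to wavelength bin i (a = 0 in this tutorial).
    injected = a + d * i * x**3 + c * i * x**2 + b * i * x
    # The same signal written as polynomial coefficients, lowest power first.
    coeffs = [a, b * i, c * i, d * i]
    assert np.allclose(P.polyval(x, coeffs), injected)
    print(f"bin {i}: coefficients {coeffs}")
```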
Let's also plot the individual light curves so we can see the polynomial shapes we've added:
```python
rb.plot_lightcurves();
```
Create Polynomial Model
We set up the PolynomialModel much as in the linear model tutorial, except that we also need to provide the degree of the polynomial, which sets how many coefficients (p_0 up to p_degree) the model has. We can also fix individual coefficients to zero by not setting them up: for example, if we wanted a linear model with no constant offset, we could leave out the p_0 parameter, which would then be fixed to 0 by default.
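To make the "omitted coefficients are fixed to 0" idea concrete, here is a small NumPy helper that evaluates p_0 + p_1 t + ... + p_degree t^degree and treats any coefficient you don't pass as 0. `evaluate_polynomial` is our own hypothetical illustration of the idea, not chromatic_fitting's implementation:

```python
import numpy as np

def evaluate_polynomial(t, degree, **coefficients):
    """Hypothetical helper: sum of p_k * t**k for k = 0..degree.

    Any coefficient p_k that is not passed in is treated as 0.
    """
    t = np.asarray(t, dtype=float)
    total = np.zeros_like(t)
    for k in range(degree + 1):
        total += coefficients.get(f"p_{k}", 0.0) * t**k
    return total

t = np.array([0.0, 1.0, 2.0])
# A linear model with no constant offset: p_0 is omitted, so it acts as 0.
print(evaluate_polynomial(t, degree=1, p_1=0.5))
```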
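```python
# set up a cubic (degree 3) polynomial model:
p = PolynomialModel(degree=3)
# WavelikeFitted: a separate copy of each parameter is fitted for every wavelength
p.setup_parameters(
    p_0=WavelikeFitted(Uniform, testval=0.01, upper=1, lower=-1),
    p_1=WavelikeFitted(Uniform, testval=0.01, upper=1, lower=-1),
    p_2=WavelikeFitted(Uniform, testval=0.01, upper=1, lower=-1),
    # the injected cubic coefficients reach d*4 = 20, so p_3 needs a wider prior
    p_3=WavelikeFitted(Uniform, testval=0.01, upper=50, lower=-1),
)
# print a summary of all params:
p.summarize_parameters()
```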
```
polynomial_p_0 = <🧮 WavelikeFitted Uniform(testval=0.01, upper=1, lower=-1, name='polynomial_p_0') for each wavelength 🧮>
polynomial_p_1 = <🧮 WavelikeFitted Uniform(testval=0.01, upper=1, lower=-1, name='polynomial_p_1') for each wavelength 🧮>
polynomial_p_2 = <🧮 WavelikeFitted Uniform(testval=0.01, upper=1, lower=-1, name='polynomial_p_2') for each wavelength 🧮>
polynomial_p_3 = <🧮 WavelikeFitted Uniform(testval=0.01, upper=50, lower=-1, name='polynomial_p_3') for each wavelength 🧮>
```
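Why does p_3 get upper=50 while the other coefficients use (-1, 1)? The true coefficients follow directly from the simulation parameters (b = 0.05, c = 0 and d = 5, each scaled by the wavelength index 0 to 4), so we can check them against the prior bounds with a few lines of NumPy. This standalone sketch recomputes the true values instead of reading true_a to true_d, and assumes the 5 wavelength bins used in this tutorial. It also turns up something worth knowing: the true p_0 = 1 sits exactly on the upper edge of its Uniform(-1, 1) prior. That is consistent with the hdi_84% of p_0 being at, or within about 0.0001 of, 1.0 for every wavelength in the summary table further down, and it likely contributes to all the fitted p_0 values landing slightly below 1.

```python
import numpy as np

# Injection parameters from the simulation cell.
a, b, c, d = 0.0, 0.05, 0.0, 5.0
nwave = 5  # number of wavelength bins in this tutorial (w0 to w4)

i = np.arange(nwave)
true_coeffs = {
    "p_0": np.full(nwave, a + 1),  # baseline flux
    "p_1": b * i,
    "p_2": c * i,
    "p_3": d * i,
}
# (lower, upper) bounds of the Uniform priors passed to setup_parameters above.
priors = {"p_0": (-1, 1), "p_1": (-1, 1), "p_2": (-1, 1), "p_3": (-1, 50)}

inside_prior = {}
for name, values in true_coeffs.items():
    lower, upper = priors[name]
    # strictly inside the open interval (lower, upper)?
    inside_prior[name] = bool(np.all((values > lower) & (values < upper)))
    print(f"{name}: true values from {values.min():g} to {values.max():g}, "
          f"prior ({lower}, {upper}), strictly inside: {inside_prior[name]}")
```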
```python
# set up the model the same way as for the transit model:
p.attach_data(rb)
p.setup_lightcurves()
p.setup_likelihood()
```
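`setup_likelihood()` is where the model gets compared to the data. We haven't checked exactly which likelihood chromatic_fitting builds, so treat this as an assumption: a common choice is an independent Gaussian for every flux point, with that point's uncertainty as the standard deviation. Under that assumption, the log-likelihood for one wavelength looks like this NumPy sketch (the function name and the toy numbers are our own):

```python
import numpy as np
from numpy.polynomial import polynomial as P

def gaussian_log_likelihood(flux, model, uncertainty):
    """Sum of independent Gaussian log-densities of flux given model (assumed form)."""
    flux, model, uncertainty = (np.asarray(v, dtype=float) for v in (flux, model, uncertainty))
    resid = (flux - model) / uncertainty
    return -0.5 * np.sum(resid**2 + np.log(2 * np.pi * uncertainty**2))

# Toy example with hypothetical numbers: one wavelength, a cubic model, constant errors.
t = np.linspace(-0.1, 0.1, 5)
model = P.polyval(t, [1.0, 0.1, 0.0, 10.0])
flux = model + np.array([0.01, -0.02, 0.0, 0.015, -0.005])
sigma = np.full(t.size, 0.01)
print(gaussian_log_likelihood(flux, model, sigma))
```

Under this assumption, and because each wavelength has its own coefficients, the total log-likelihood is a sum of independent per-wavelength terms like this one.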
Let's check our PyMC3 model and make sure that the parameters have been set up correctly!
```python
p._pymc3_model
```
Looks good; now on to sampling our model...
Sampling our Model
Now we can fit our model! We first run an optimization step to give the sampler a good starting guess, and then run the NUTS sampler with the number of tuning steps, draws, and chains that we choose. We can also choose how many cores to assign to the sampling. Bear in mind that we are fitting 20 parameters simultaneously (4 polynomial coefficients for each of 5 wavelengths), so we want to make sure the MCMC runs for enough steps!
```python
# optimize for initial values!
opt = p.optimize(plot=False)
# put those initial values into the sampling and define the number of tuning and draw steps,
# as well as the number of chains.
p.sample(start=opt, tune=1000, draws=1000, chains=4, cores=4)
```
```
optimizing logp for variables: [polynomial_p_3_w4, polynomial_p_2_w4, polynomial_p_1_w4, polynomial_p_0_w4, polynomial_p_3_w3, polynomial_p_2_w3, polynomial_p_1_w3, polynomial_p_0_w3, polynomial_p_3_w2, polynomial_p_2_w2, polynomial_p_1_w2, polynomial_p_0_w2, polynomial_p_3_w1, polynomial_p_2_w1, polynomial_p_1_w1, polynomial_p_0_w1, polynomial_p_3_w0, polynomial_p_2_w0, polynomial_p_1_w0, polynomial_p_0_w0]
message: Optimization terminated successfully.
logp: -6764473.525476757 -> 1160.4319625602034
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [polynomial_p_3_w4, polynomial_p_2_w4, polynomial_p_1_w4, polynomial_p_0_w4, polynomial_p_3_w3, polynomial_p_2_w3, polynomial_p_1_w3, polynomial_p_0_w3, polynomial_p_3_w2, polynomial_p_2_w2, polynomial_p_1_w2, polynomial_p_0_w2, polynomial_p_3_w1, polynomial_p_2_w1, polynomial_p_1_w1, polynomial_p_0_w1, polynomial_p_3_w0, polynomial_p_2_w0, polynomial_p_1_w0, polynomial_p_0_w0]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 8 seconds.
```
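`p.optimize()` searches numerically for the best-fitting parameters to use as a starting point. A polynomial is linear in its coefficients, so for Gaussian noise and flat priors (as long as the answer lies inside the prior bounds) the same best-fit point can be computed in closed form by weighted least squares. The sketch below swaps in NumPy's `numpy.polynomial.polynomial.polyfit` for a single wavelength of hypothetical data; it is a cross-check of the idea, not what `p.optimize()` does internally.

```python
import numpy as np
from numpy.polynomial import polynomial as P

rng = np.random.default_rng(42)

# Hypothetical single-wavelength data: time in days, a cubic similar to the one
# injected for wavelength index 2 (p_1 = 0.1, p_3 = 10) on top of a flux of 1,
# plus Gaussian noise with a known uncertainty.
t = np.linspace(-0.1, 0.1, 61)
true = np.array([1.0, 0.1, 0.0, 10.0])  # p_0, p_1, p_2, p_3
sigma = 0.005
flux = P.polyval(t, true) + rng.normal(0.0, sigma, t.size)

# Weighted least squares: NumPy applies the weights (1/sigma) to the unsquared residuals.
best = P.polyfit(t, flux, deg=3, w=np.full(t.size, 1.0 / sigma))
print("best-fit p_0..p_3:", best)
```

The coefficients come back ordered from p_0 upward, matching the tutorial's naming. Because the cubic term is tiny over this short time span, p_3 is far less tightly constrained than p_0.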
Now we can look at our results:
```python
p.summarize(round_to=7, hdi_prob=0.68, fmt='wide')
```
| | mean | sd | hdi_16% | hdi_84% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| polynomial_p_0_w0 | 0.999470 | 0.000445 | 0.999366 | 1.000000 | 0.000007 | 0.000005 | 2338.520013 | 1386.705486 | 1.001233 |
| polynomial_p_1_w0 | -0.026794 | 0.022087 | -0.047304 | -0.002319 | 0.000449 | 0.000323 | 2435.190134 | 2998.404890 | 1.000490 |
| polynomial_p_2_w0 | 0.025505 | 0.140288 | -0.107670 | 0.163992 | 0.002105 | 0.002230 | 4488.081328 | 2670.702295 | 0.999631 |
| polynomial_p_3_w0 | 4.784377 | 3.016049 | 1.120042 | 7.335429 | 0.064071 | 0.045311 | 1992.733666 | 1440.856833 | 1.000867 |
| polynomial_p_0_w1 | 0.998949 | 0.000679 | 0.998581 | 0.999920 | 0.000014 | 0.000010 | 1701.106444 | 1000.227790 | 1.000451 |
| polynomial_p_1_w1 | 0.034036 | 0.022448 | 0.012189 | 0.057369 | 0.000510 | 0.000374 | 1891.408521 | 2145.717611 | 1.000257 |
| polynomial_p_2_w1 | 0.261071 | 0.164144 | 0.094727 | 0.420512 | 0.002997 | 0.002475 | 3078.894987 | 2377.462127 | 1.002427 |
| polynomial_p_3_w1 | 5.066624 | 3.068378 | 1.250641 | 7.569463 | 0.074046 | 0.052367 | 1450.728445 | 800.865889 | 1.000224 |
| polynomial_p_0_w2 | 0.999388 | 0.000495 | 0.999247 | 1.000000 | 0.000009 | 0.000007 | 1485.920215 | 1071.628199 | 1.000114 |
| polynomial_p_1_w2 | 0.105216 | 0.025445 | 0.078866 | 0.127631 | 0.000667 | 0.000528 | 1558.850601 | 908.510733 | 1.001078 |
| polynomial_p_2_w2 | 0.081249 | 0.146249 | -0.069188 | 0.213364 | 0.002352 | 0.002451 | 4092.521660 | 2023.375715 | 1.001122 |
| polynomial_p_3_w2 | 10.456508 | 3.629041 | 7.346513 | 14.245256 | 0.104220 | 0.073713 | 1365.785903 | 574.955677 | 1.001041 |
| polynomial_p_0_w3 | 0.999715 | 0.000268 | 0.999666 | 0.999999 | 0.000004 | 0.000003 | 3039.213124 | 1938.752322 | 1.001264 |
| polynomial_p_1_w3 | 0.142064 | 0.024788 | 0.119439 | 0.168665 | 0.000364 | 0.000261 | 4637.995498 | 3114.807890 | 0.999956 |
| polynomial_p_2_w3 | -0.042956 | 0.130786 | -0.161065 | 0.104178 | 0.001839 | 0.002072 | 5044.293920 | 2656.659960 | 1.001428 |
| polynomial_p_3_w3 | 16.052315 | 3.499595 | 12.319439 | 19.318084 | 0.051303 | 0.036279 | 4661.225377 | 2513.138199 | 0.999800 |
| polynomial_p_0_w4 | 0.999529 | 0.000395 | 0.999428 | 1.000000 | 0.000006 | 0.000004 | 2451.083872 | 1393.260638 | 1.001639 |
| polynomial_p_1_w4 | 0.205724 | 0.024723 | 0.182979 | 0.232134 | 0.000350 | 0.000248 | 5003.024751 | 3254.419781 | 1.000587 |
| polynomial_p_2_w4 | 0.056294 | 0.133682 | -0.081951 | 0.177289 | 0.002027 | 0.001957 | 4385.534270 | 2718.407511 | 1.001792 |
| polynomial_p_3_w4 | 21.870544 | 3.514664 | 18.363020 | 25.285231 | 0.048679 | 0.034423 | 5218.928066 | 2984.910011 | 1.000509 |
All r_hat values are close to 1 (the largest is about 1.0024), which is a good sign that our chains have converged!
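R-hat compares the spread between chains with the spread within each chain; if all chains are exploring the same distribution, the ratio is close to 1. As a sketch, here is the classic split-R-hat (Gelman and Rubin's diagnostic, computed on half-chains) in NumPy. The r_hat column above comes from ArviZ, which in recent versions reports a rank-normalized variant, so this version will not reproduce those numbers exactly; the function and the simulated chains are our own illustration.

```python
import numpy as np

def split_rhat(chains):
    """Classic (non-rank-normalized) split-R-hat.

    chains: array of shape (n_chains, n_draws) for a single parameter.
    """
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    # Split every chain in two, so a chain that drifts also shows up as disagreement.
    halves = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = halves.shape[1]                            # draws per half-chain
    within = halves.var(axis=1, ddof=1).mean()     # W: mean within-chain variance
    between = n * halves.mean(axis=1).var(ddof=1)  # B: between-chain variance
    pooled = (n - 1) / n * within + between / n    # pooled variance estimate
    return np.sqrt(pooled / within)

rng = np.random.default_rng(0)
good = rng.normal(size=(4, 1000))                    # 4 chains sampling the same distribution
bad = good + np.array([[0.0], [0.0], [0.0], [3.0]])  # the 4th chain is stuck 3 units away
print(f"well-mixed chains: {split_rhat(good):.4f}")
print(f"one stuck chain:   {split_rhat(bad):.4f}")
```

A chain stuck in a different region pushes R-hat well above 1; a common rule of thumb is to look more closely when it exceeds about 1.01.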
```python
p.get_results(uncertainty=['hdi_16%','hdi_84%'])
```
| | polynomial_p_0 | polynomial_p_0_hdi_16% | polynomial_p_0_hdi_84% | polynomial_p_1 | polynomial_p_1_hdi_16% | polynomial_p_1_hdi_84% | polynomial_p_2 | polynomial_p_2_hdi_16% | polynomial_p_2_hdi_84% | polynomial_p_3 | polynomial_p_3_hdi_16% | polynomial_p_3_hdi_84% | wavelength |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| w0 | 0.99947 | 0.999366 | 1.0 | -0.026794 | -0.047304 | -0.002319 | 0.025505 | -0.10767 | 0.163992 | 4.784377 | 1.120042 | 7.335429 | 0.639572482934883 micron |
| w1 | 0.998949 | 0.998581 | 0.99992 | 0.034036 | 0.012189 | 0.057369 | 0.261071 | 0.094727 | 0.420512 | 5.066624 | 1.250641 | 7.569463 | 1.013209338074884 micron |
| w2 | 0.999388 | 0.999247 | 1.0 | 0.105216 | 0.078866 | 0.127631 | 0.081249 | -0.069188 | 0.213364 | 10.456508 | 7.346513 | 14.245256 | 1.604998553797903 micron |
| w3 | 0.999715 | 0.999666 | 0.999999 | 0.142064 | 0.119439 | 0.168665 | -0.042956 | -0.161065 | 0.104178 | 16.052315 | 12.319439 | 19.318084 | 2.542436455025025 micron |
| w4 | 0.999529 | 0.999428 | 1.0 | 0.205724 | 0.182979 | 0.232134 | 0.056294 | -0.081951 | 0.177289 | 21.870544 | 18.36302 | 25.285231 | 4.027407446906737 micron |
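The hdi_16% and hdi_84% column names come from asking for a 68% highest-density interval (hdi_prob=0.68). They are the edges of the narrowest interval that holds 68% of the posterior samples, which matches the 16th and 84th percentiles only when the posterior is symmetric. Here is a sketch of the usual sample-based calculation (the `hdi` function below is our own, not the chromatic_fitting or ArviZ API), compared with percentiles for a symmetric and a skewed distribution:

```python
import numpy as np

def hdi(samples, prob=0.68):
    """Narrowest interval containing a fraction `prob` of the samples."""
    x = np.sort(np.asarray(samples, dtype=float))
    n_in = int(np.floor(prob * len(x)))      # index span of each candidate interval
    widths = x[n_in:] - x[:len(x) - n_in]    # widths of all candidate intervals
    start = int(np.argmin(widths))
    return x[start], x[start + n_in]

rng = np.random.default_rng(1)
symmetric = rng.normal(size=200_000)
skewed = rng.exponential(size=200_000)

for name, s in [("normal", symmetric), ("exponential", skewed)]:
    print(name, "HDI:", hdi(s), "16th/84th percentiles:", np.percentile(s, [16, 84]))
```

For the skewed case the HDI hugs the peak at 0, while the percentile interval does not, which is also the situation for p_0 here, whose posterior is squeezed against its upper prior bound of 1.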
```python
model = p.get_model()
model.keys()
```
```
dict_keys(['w0', 'w1', 'w2', 'w3', 'w4'])
```
Plot Results
Remember the handy plot_lightcurves() function from earlier? Now that we have fitted a model, it will also overplot the models on top of the data.
```python
p.plot_lightcurves()
```
We can also use the chromatic plotting functions wrapped by chromatic_fitting to look at the residuals:
```python
p.plot_with_model_and_residuals()
```
```
No model attached to data. Running `add_model_to_rainbow` now. You can access this data later using [self].data_with_model
```
```python
p.imshow_with_models(vlimits_data=[0.96, 1.04])
```
Compare Results to the True Values
We can also compare our fitted results to the true values we put in:
```python
results = p.get_results(uncertainty=['sd','sd'])
results
```
| | polynomial_p_0 | polynomial_p_0_sd | polynomial_p_1 | polynomial_p_1_sd | polynomial_p_2 | polynomial_p_2_sd | polynomial_p_3 | polynomial_p_3_sd | wavelength |
|---|---|---|---|---|---|---|---|---|---|
| w0 | 0.99947 | 0.000445 | -0.026794 | 0.022087 | 0.025505 | 0.140288 | 4.784377 | 3.016049 | 0.639572482934883 micron |
| w1 | 0.998949 | 0.000679 | 0.034036 | 0.022448 | 0.261071 | 0.164144 | 5.066624 | 3.068378 | 1.013209338074884 micron |
| w2 | 0.999388 | 0.000495 | 0.105216 | 0.025445 | 0.081249 | 0.146249 | 10.456508 | 3.629041 | 1.604998553797903 micron |
| w3 | 0.999715 | 0.000268 | 0.142064 | 0.024788 | -0.042956 | 0.130786 | 16.052315 | 3.499595 | 2.542436455025025 micron |
| w4 | 0.999529 | 0.000395 | 0.205724 | 0.024723 | 0.056294 | 0.133682 | 21.870544 | 3.514664 | 4.027407446906737 micron |
```python
# print the true (injected) coefficients next to the fitted values and their sd
print("\t\t\tTrue, \tFitted")
for w in range(p.data.nwave):
    for i, coeff in zip(range(p.degree+1), [true_a, true_b, true_c, true_d]):
        print(f"wavelength {w}, p_{i}:\t {round(coeff[w],2)}, \t",
              results.loc[f'w{w}'][f"{p.name}_p_{i}"], "+/-",
              results.loc[f'w{w}'][f"{p.name}_p_{i}_sd"])
```
```
                    True,  Fitted
wavelength 0, p_0:  1.0,   0.9994704 +/- 0.0004447
wavelength 0, p_1:  0.0,   -0.0267944 +/- 0.0220872
wavelength 0, p_2:  0.0,   0.0255054 +/- 0.1402877
wavelength 0, p_3:  0.0,   4.7843773 +/- 3.0160491
wavelength 1, p_0:  1.0,   0.9989494 +/- 0.0006792
wavelength 1, p_1:  0.05,  0.0340356 +/- 0.0224478
wavelength 1, p_2:  0.0,   0.2610712 +/- 0.1641436
wavelength 1, p_3:  5.0,   5.0666243 +/- 3.0683784
wavelength 2, p_0:  1.0,   0.9993876 +/- 0.0004948
wavelength 2, p_1:  0.1,   0.1052157 +/- 0.0254455
wavelength 2, p_2:  0.0,   0.0812492 +/- 0.1462489
wavelength 2, p_3:  10.0,  10.4565084 +/- 3.629041
wavelength 3, p_0:  1.0,   0.9997147 +/- 0.0002683
wavelength 3, p_1:  0.15,  0.1420635 +/- 0.0247884
wavelength 3, p_2:  0.0,   -0.0429563 +/- 0.1307864
wavelength 3, p_3:  15.0,  16.0523151 +/- 3.4995951
wavelength 4, p_0:  1.0,   0.9995292 +/- 0.0003954
wavelength 4, p_1:  0.2,   0.2057236 +/- 0.0247229
wavelength 4, p_2:  0.0,   0.0562941 +/- 0.1336821
wavelength 4, p_3:  20.0,  21.8705439 +/- 3.5146637
```
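To put a number on how well each coefficient was recovered, we can express the difference between the fitted and true values in units of the fitted sd. Here is that arithmetic for p_3, as a standalone sketch with the values copied from the results above rather than read from `results`:

```python
import numpy as np

# True (injected) p_3, fitted mean and sd for wavelengths 0 to 4, copied from the results above.
true_p3 = np.array([0.0, 5.0, 10.0, 15.0, 20.0])
fit_p3 = np.array([4.784377, 5.066624, 10.456508, 16.052315, 21.870544])
sd_p3 = np.array([3.016049, 3.068378, 3.629041, 3.499595, 3.514664])

# How many standard deviations each fit lies from the true value.
z = (fit_p3 - true_p3) / sd_p3
for w, score in enumerate(z):
    print(f"wavelength {w}: p_3 is {score:+.2f} sigma from the true value")
```

Wavelength 0 stands out at about 1.6 sigma, even though its true polynomial is flat; the other wavelengths all lie within about 0.6 sigma.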
Let's plot the data, the true polynomial, and our fit to see how they compare. (If we had used store_model=True at the .setup_lightcurves() stage, we could also easily generate a 1-sigma region for the model using the errors stored in the summary table!)
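```python
fig, ax = plt.subplots(p.data.nwave, figsize=(12,18))
# plot the data and the fitted model for each wavelength
p.plot_model(ax=ax)
# overplot the true (injected) polynomial, shifted up by the baseline flux of 1
for i in range(len(poly)):
    ax[i].plot(p.data.time, poly[i] + 1, label="True polynomial")
    ax[i].legend()
plt.tight_layout();
```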
This example shows both where the model does a good job of fitting and where it can overfit when the noise is larger than the signal. At wavelength 0, for instance, the injected polynomial is flat (p_1 = p_2 = p_3 = 0), yet the fit finds p_3 = 4.8 ± 3.0.
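The plotting note above mentions a 1-sigma region for the model. One generic way to build such a band, sketched here without relying on chromatic_fitting's store_model option, is to evaluate the polynomial for many posterior draws of the coefficients and take percentiles at each time. The time grid is hypothetical, and for illustration we draw coefficients from independent Gaussians using the w2 means and sds from the summary table; that ignores correlations between coefficients, so with a real fit you would use the actual posterior draws from the MCMC trace instead. `model_band` is our own helper name.

```python
import numpy as np
from numpy.polynomial import polynomial as P

def model_band(t, coeff_samples, lower=16, upper=84):
    """Evaluate the polynomial for each row of coeff_samples (ordered p_0..p_degree)
    and return the median, lower and upper percentile curves at each time."""
    curves = np.array([P.polyval(t, c) for c in coeff_samples])  # (n_samples, n_times)
    median, lo, hi = np.percentile(curves, [50, lower, upper], axis=0)
    return median, lo, hi

# Hypothetical inputs: a time grid in days and stand-in "posterior" draws
# built from the w2 means and sds in the summary table (independent Gaussians).
rng = np.random.default_rng(2)
t = np.linspace(-0.1, 0.1, 61)
mean = np.array([0.999388, 0.105216, 0.081249, 10.456508])
sd = np.array([0.000495, 0.025445, 0.146249, 3.629041])
draws = rng.normal(mean, sd, size=(2000, 4))

median, lo, hi = model_band(t, draws)
```

On a Matplotlib axis you could then shade the band with `ax.fill_between(t, lo, hi, alpha=0.3)`.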