import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
pd.options.display.max_columns = 99
os.chdir('..') # change to root directory
# Load data
import pickle
from firenet.util import add_uncertainty_features
with open('./data/d_data.pkl', 'rb') as infile:
d_data = pickle.load(infile)
d_data = add_uncertainty_features(d_data)
d_data.keys()
dict_keys(['fullbay', 'fullbayerr', 'shortbay', 'shortbayerr', 'redshift', 'observed', 'observederr', 'obserr_to_short', 'obs_to_short'])
from firenet.ml import RegUncPredictor
pred = RegUncPredictor(d_data)
pred.preprocess()
pred.train_regressor()
epoch train_loss valid_loss cp dur ------- ------------ ------------ ---- ------ 1 0.8752 0.6712 + 0.0751 2 0.5768 0.3935 + 0.0875 3 0.3913 0.3257 + 0.0869 4 0.3369 0.2845 + 0.0880 5 0.3090 0.2638 + 0.0847 6 0.2873 0.2496 + 0.0851 7 0.2706 0.2386 + 0.0788 8 0.2578 0.2300 + 0.1118 9 0.2475 0.2228 + 0.0995 10 0.2391 0.2171 + 0.0814 11 0.2319 0.2120 + 0.0902 12 0.2256 0.2080 + 0.0681 13 0.2199 0.2049 + 0.0727 14 0.2150 0.2024 + 0.0774 15 0.2106 0.2003 + 0.0767 16 0.2065 0.1986 + 0.0759 17 0.2028 0.1971 + 0.0822 18 0.1996 0.1958 + 0.0802 19 0.1967 0.1944 + 0.0801 20 0.1940 0.1931 + 0.0821 21 0.1915 0.1916 + 0.0713 22 0.1891 0.1903 + 0.0701 23 0.1869 0.1890 + 0.0676 24 0.1847 0.1877 + 0.0714 25 0.1827 0.1866 + 0.0723 26 0.1810 0.1856 + 0.0802 27 0.1795 0.1847 + 0.0786 28 0.1781 0.1839 + 0.0851 29 0.1768 0.1832 + 0.0799 30 0.1757 0.1827 + 0.0749 31 0.1747 0.1822 + 0.0710 32 0.1738 0.1818 + 0.0777 33 0.1730 0.1815 + 0.0713 34 0.1722 0.1812 + 0.0817 35 0.1716 0.1809 + 0.0750 36 0.1710 0.1806 + 0.0856 37 0.1705 0.1804 + 0.0826 38 0.1700 0.1802 + 0.0807 39 0.1696 0.1800 + 0.0870 40 0.1693 0.1799 + 0.0656 41 0.1690 0.1798 + 0.0746 42 0.1688 0.1797 + 0.0687 43 0.1686 0.1796 + 0.0678 44 0.1684 0.1795 + 0.0654 45 0.1683 0.1795 + 0.0714 46 0.1682 0.1794 + 0.0722 47 0.1681 0.1794 + 0.0673 48 0.1680 0.1794 + 0.0727 49 0.1680 0.1794 + 0.0794 50 0.1680 0.1794 + 0.0657 51 0.1679 0.1794 + 0.0712 52 0.1679 0.1794 0.0836 53 0.1679 0.1794 + 0.0841 54 0.1679 0.1794 + 0.0815 55 0.1680 0.1793 + 0.0705 56 0.1679 0.1793 + 0.0697 57 0.1679 0.1793 + 0.0676 58 0.1679 0.1793 + 0.0721 59 0.1679 0.1792 + 0.0711 60 0.1678 0.1792 + 0.0687 61 0.1677 0.1791 + 0.0949 62 0.1676 0.1791 + 0.0694 63 0.1675 0.1790 + 0.0797 64 0.1674 0.1789 + 0.0811 65 0.1672 0.1787 + 0.0780 66 0.1670 0.1786 + 0.0779 67 0.1668 0.1785 + 0.0789 68 0.1665 0.1784 + 0.0730 69 0.1663 0.1782 + 0.0902 70 0.1660 0.1780 + 0.0688 71 0.1657 0.1779 + 0.0725 72 0.1654 0.1778 + 0.0732 73 0.1651 0.1777 + 0.0726 74 0.1648 0.1779 0.0721 75 0.1645 0.1779 0.0664 76 0.1642 0.1782 0.0701 77 0.1640 0.1786 0.0689 78 0.1638 0.1795 0.1036 79 0.1637 0.1802 0.0887 80 0.1634 0.1812 0.0943 81 0.1632 0.1824 0.0805 82 0.1630 0.1834 0.0740 83 0.1627 0.1843 0.0703 84 0.1623 0.1849 0.0695 85 0.1619 0.1853 0.0830 86 0.1614 0.1855 0.0907 87 0.1608 0.1854 0.0928 88 0.1603 0.1857 0.0917 89 0.1599 0.1855 0.0660 90 0.1593 0.1858 0.0701 91 0.1589 0.1865 0.0754 92 0.1584 0.1869 0.0839 93 0.1581 0.1871 0.0766 94 0.1577 0.1879 0.0702 95 0.1574 0.1882 0.1012 96 0.1569 0.1886 0.0750 97 0.1566 0.1897 0.0709 98 0.1562 0.1895 0.0688 99 0.1557 0.1900 0.0696 100 0.1554 0.1898 0.0720 101 0.1550 0.1898 0.0672 102 0.1545 0.1898 0.0960 103 0.1541 0.1895 0.0704 104 0.1536 0.1886 0.0700 105 0.1530 0.1876 0.0685 106 0.1523 0.1871 0.0678 107 0.1517 0.1864 0.0716 108 0.1513 0.1863 0.0718 109 0.1507 0.1863 0.0668 110 0.1504 0.1862 0.0699 111 0.1500 0.1864 0.0697 112 0.1498 0.1863 0.0685 113 0.1496 0.1868 0.0732 114 0.1495 0.1867 0.0687 115 0.1495 0.1865 0.0792 116 0.1496 0.1845 0.0678 117 0.1496 0.1823 0.0657 118 0.1493 0.1780 0.0684 119 0.1488 0.1743 + 0.0834 120 0.1477 0.1716 + 0.0839 121 0.1464 0.1704 + 0.0828 122 0.1449 0.1702 + 0.0749 123 0.1435 0.1704 0.0956 124 0.1426 0.1705 0.0792 125 0.1419 0.1708 0.0765 126 0.1415 0.1707 0.0772 127 0.1411 0.1709 0.0797 128 0.1407 0.1708 0.0772 129 0.1402 0.1709 0.0704 130 0.1399 0.1710 0.0799 131 0.1397 0.1709 0.0759 132 0.1393 0.1709 0.0761 133 0.1391 0.1710 0.0741 134 0.1389 0.1710 0.0751 135 0.1386 0.1710 0.0781 136 0.1384 0.1711 0.1264 137 0.1383 0.1711 0.1886 138 0.1381 0.1711 0.1071 139 0.1380 0.1711 0.0754 140 0.1378 0.1711 0.0904 141 0.1377 0.1711 0.1040 142 0.1376 0.1712 0.0931 143 0.1375 0.1712 0.1100 144 0.1375 0.1712 0.0926 145 0.1374 0.1712 0.0786 146 0.1373 0.1712 0.0752 147 0.1373 0.1712 0.0769 148 0.1373 0.1712 0.0732 149 0.1372 0.1712 0.0739 150 0.1372 0.1712 0.0718
pred.train_uncertainty()
epoch train_loss valid_loss cp dur ------- ------------ ------------ ---- ------ 1 313.3290 73.8520 + 0.1296 2 -254.7472 -648.8933 + 0.1005 3 -1139.7797 -1500.9908 + 0.0980 4 -1952.6871 -2081.4747 + 0.1172 5 -2464.5263 -2396.6403 + 0.0881 6 -2752.4977 -2560.7280 + 0.1292 7 -2915.1661 -2648.5552 + 0.0719 8 -3011.9995 -2698.6603 + 0.0987 9 -3074.2287 -2729.5791 + 0.1081 10 -3117.2766 -2749.8247 + 0.1248 11 -3148.7736 -2763.9027 + 0.0760 12 -3172.6325 -2773.9058 + 0.0727 13 -3190.9592 -2781.2553 + 0.0758 14 -3205.2596 -2786.8473 + 0.0737 15 -3216.7073 -2791.2217 + 0.0746 16 -3226.1320 -2794.9273 + 0.0759 17 -3234.1256 -2798.3066 + 0.0766 18 -3241.0154 -2801.3620 + 0.0709 19 -3247.1000 -2804.0809 + 0.0753 20 -3252.4108 -2806.3007 + 0.0702 21 -3257.0786 -2808.2655 + 0.0783 22 -3261.2049 -2810.0148 + 0.0824 23 -3264.8522 -2811.6030 + 0.0761 24 -3268.1912 -2812.9841 + 0.0761 25 -3271.2079 -2814.1714 + 0.0769 26 -3273.8869 -2815.2133 + 0.0698 27 -3276.3192 -2816.1795 + 0.0762 28 -3278.5045 -2816.9384 + 0.0716 29 -3280.4999 -2817.6092 + 0.0766 30 -3282.2594 -2818.2738 + 0.0753 31 -3283.8510 -2818.8569 + 0.0729 32 -3285.2652 -2819.3847 + 0.0748 33 -3286.5510 -2819.8746 + 0.0729 34 -3287.6819 -2820.2886 + 0.0721 35 -3288.6946 -2820.6672 + 0.0685 36 -3289.5922 -2821.0029 + 0.0708 37 -3290.3863 -2821.3010 + 0.0703 38 -3291.0772 -2821.5659 + 0.0753 39 -3291.6801 -2821.8107 + 0.0750 40 -3292.1984 -2822.0330 + 0.0712 41 -3292.6401 -2822.1971 + 0.0823 42 -3293.0049 -2822.3355 + 0.0755 43 -3293.3130 -2822.4540 + 0.0728 44 -3293.5602 -2822.5496 + 0.0744 45 -3293.7562 -2822.6277 + 0.0736 46 -3293.9070 -2822.6826 + 0.0736 47 -3294.0169 -2822.7217 + 0.0719 48 -3294.0936 -2822.7452 + 0.0744 49 -3294.1430 -2822.7592 + 0.0733 50 -3294.1708 -2822.7657 + 0.0735
pred.reg.test()
PACS_70 0.218260 PACS_100 0.191469 PACS_160 0.173414 SPIRE_250 0.185114 SPIRE_350 0.200334 SPIRE_500 0.214310 Name: rmse, dtype: float64
pred.unc.test()
PACS_70 1.055272 PACS_100 1.108588 PACS_160 1.208961 SPIRE_250 1.089232 SPIRE_350 1.023933 SPIRE_500 0.994720 Name: mean_chisq, dtype: float64
y_t, y_p, y_perr = pred.get_target_set()
# Storing and loading models
# from firenet.ml import ModelStore
# ModelStore().store(pred, name='nnet') # Save to './models/nnet.pkl' by default
# pred = ModelStore().load(d_data, name='nnet') # Load "nnet" model
from firenet.plotting.truevspred import TrueVSPredPlotter
firbands = d_data['fullbay'].columns[-6:]
idx = y_t.index
y_terr = d_data['fullbayerr'].loc[idx, firbands].divide(d_data['fullbay'].loc[idx] * np.log(10), axis=0)
tvpplot = TrueVSPredPlotter(figsize=(12.8, 8.8))
tvpplot.create_panels(nrows=1)
panel = tvpplot.get_panel(0)
panel.stylized_plot(y_t, y_p, y_terr=y_terr, y_perr=y_perr,
style='firflux')
Using a 4-fold train/test split, training 4 predictors. Each galaxy is used as a test sample once.
from firenet.ml.fullsetpredictor import FullSetPredictor
fspred = FullSetPredictor(d_data)
fspred.prepare_splits(shuffle_state=123)
fspred.train()
Start training model 1/4... Start training model 2/4... Start training model 3/4... Start training model 4/4...
# Storing and loading models
# from firenet.ml.modelstore import ModelStore
# ModelStore().store(fspred, name='fsnnet') # Save to './models/fsnnet.pkl' by default
# fspred = ModelStore().load(d_data, name='fsnnet') # Load "fsnnet" model
y_t, y_p, y_perr = fspred.get_combined_test()
# See paper Fig. 3
from firenet.plotting.truevspred import TrueVSPredPlotter
firbands = d_data['fullbay'].columns[-6:]
idx = y_t.index
y_terr = d_data['fullbayerr'].loc[idx, firbands].divide(d_data['fullbay'].loc[idx] * np.log(10), axis=0)
tvpplot = TrueVSPredPlotter(figsize=(12.8, 8.8))
tvpplot.create_panels(nrows=1)
panel = tvpplot.get_panel(0)
panel.stylized_plot(y_t, y_p, y_terr=y_terr, y_perr=y_perr,
style='firflux')
This model can then be used for other data sets (see notebook 04_predicting.ipynb
)
# Use all data for training and testing (no longer unbiased test set)
idx_tot = d_data['fullbay'].index.values.copy()
np.random.seed(123)
np.random.shuffle(idx_tot)
pred = RegUncPredictor(d_data)
pred.preprocess(idx_train=idx_tot, idx_test=idx_tot)
pred.train_regressor()
epoch train_loss valid_loss cp dur ------- ------------ ------------ ---- ------ 1 0.7882 0.5118 + 0.0984 2 0.4379 0.3321 + 0.0988 3 0.3485 0.3019 + 0.0940 4 0.3225 0.2850 + 0.0934 5 0.3032 0.2703 + 0.0950 6 0.2868 0.2591 + 0.1835 7 0.2726 0.2499 + 0.1148 8 0.2611 0.2430 + 0.1084 9 0.2525 0.2375 + 0.0899 10 0.2457 0.2334 + 0.0979 11 0.2403 0.2299 + 0.0918 12 0.2358 0.2271 + 0.0915 13 0.2318 0.2238 + 0.0897 14 0.2283 0.2207 + 0.0910 15 0.2250 0.2175 + 0.0961 16 0.2220 0.2137 + 0.0937 17 0.2187 0.2109 + 0.0933 18 0.2160 0.2078 + 0.0944 19 0.2131 0.2058 + 0.0927 20 0.2107 0.2034 + 0.1157 21 0.2083 0.2011 + 0.1047 22 0.2061 0.1988 + 0.0919 23 0.2039 0.1968 + 0.0929 24 0.2019 0.1950 + 0.0943 25 0.2000 0.1935 + 0.0931 26 0.1984 0.1920 + 0.0970 27 0.1969 0.1906 + 0.0958 28 0.1954 0.1894 + 0.0937 29 0.1940 0.1884 + 0.0965 30 0.1929 0.1875 + 0.0998 31 0.1917 0.1866 + 0.0990 32 0.1907 0.1860 + 0.0917 33 0.1898 0.1854 + 0.0958 34 0.1889 0.1849 + 0.0936 35 0.1882 0.1843 + 0.0941 36 0.1875 0.1839 + 0.0987 37 0.1869 0.1835 + 0.0947 38 0.1863 0.1832 + 0.0990 39 0.1859 0.1830 + 0.0959 40 0.1855 0.1827 + 0.0943 41 0.1851 0.1826 + 0.0993 42 0.1848 0.1824 + 0.0957 43 0.1845 0.1823 + 0.0961 44 0.1843 0.1822 + 0.0981 45 0.1842 0.1821 + 0.0962 46 0.1840 0.1821 + 0.1001 47 0.1839 0.1820 + 0.0913 48 0.1838 0.1820 + 0.0949 49 0.1838 0.1820 + 0.0898 50 0.1837 0.1820 + 0.0944 51 0.1837 0.1820 + 0.0965 52 0.1837 0.1820 0.0988 53 0.1837 0.1820 + 0.0952 54 0.1837 0.1819 + 0.0930 55 0.1837 0.1819 + 0.0943 56 0.1837 0.1819 + 0.0945 57 0.1837 0.1819 + 0.0934 58 0.1837 0.1819 + 0.1035 59 0.1837 0.1819 + 0.0955 60 0.1837 0.1819 + 0.0968 61 0.1836 0.1819 + 0.0919 62 0.1836 0.1819 + 0.0926 63 0.1835 0.1818 + 0.0985 64 0.1834 0.1818 0.0950 65 0.1833 0.1818 + 0.0917 66 0.1832 0.1818 0.0911 67 0.1830 0.1818 0.0962 68 0.1829 0.1819 0.0907 69 0.1827 0.1821 0.0960 70 0.1826 0.1824 0.0951 71 0.1825 0.1829 0.0958 72 0.1824 0.1835 0.0905 73 0.1824 0.1843 0.1277 74 0.1824 0.1852 0.0946 75 0.1824 0.1865 0.0899 76 0.1824 0.1879 0.0960 77 0.1826 0.1893 0.0928 78 0.1827 0.1904 0.0923 79 0.1827 0.1916 0.0902 80 0.1826 0.1923 0.0921 81 0.1827 0.1928 0.0984 82 0.1824 0.1925 0.1073 83 0.1821 0.1924 0.1049 84 0.1816 0.1921 0.0921 85 0.1810 0.1922 0.0920 86 0.1805 0.1920 0.0979 87 0.1799 0.1920 0.0933 88 0.1794 0.1921 0.0904 89 0.1790 0.1919 0.0915 90 0.1783 0.1918 0.0950 91 0.1777 0.1920 0.0944 92 0.1774 0.1923 0.0888 93 0.1770 0.1925 0.0933 94 0.1765 0.1927 0.0937 95 0.1761 0.1936 0.0925 96 0.1758 0.1942 0.1003 97 0.1755 0.1948 0.0921 98 0.1750 0.1944 0.0984 99 0.1742 0.1949 0.0943 100 0.1740 0.1944 0.0971 101 0.1731 0.1945 0.0962 102 0.1726 0.1942 0.0966 103 0.1721 0.1942 0.0966 104 0.1714 0.1938 0.0962 105 0.1711 0.1937 0.0978 106 0.1704 0.1929 0.0939 107 0.1697 0.1929 0.0947 108 0.1694 0.1915 0.0945 109 0.1685 0.1916 0.0934 110 0.1681 0.1900 0.0913 111 0.1673 0.1893 0.0947 112 0.1668 0.1875 0.0932 113 0.1659 0.1866 0.0924 114 0.1652 0.1858 0.0981 115 0.1647 0.1844 0.0921 116 0.1639 0.1831 0.0938 117 0.1631 0.1820 0.0935 118 0.1625 0.1818 + 0.0936 119 0.1619 0.1815 + 0.1019 120 0.1614 0.1816 0.0972 121 0.1609 0.1817 0.0956 122 0.1604 0.1823 0.0929 123 0.1600 0.1832 0.0954 124 0.1599 0.1831 0.0948 125 0.1595 0.1825 0.0961 126 0.1593 0.1807 + 0.0931 127 0.1588 0.1781 + 0.0934 128 0.1579 0.1762 + 0.0969 129 0.1567 0.1750 + 0.0938 130 0.1557 0.1746 + 0.0992 131 0.1548 0.1742 + 0.0905 132 0.1542 0.1741 + 0.0922 133 0.1538 0.1739 + 0.0985 134 0.1534 0.1737 + 0.0944 135 0.1530 0.1735 + 0.0981 136 0.1526 0.1735 + 0.0992 137 0.1524 0.1734 + 0.0947 138 0.1521 0.1733 + 0.0927 139 0.1519 0.1733 + 0.0973 140 0.1517 0.1732 + 0.0969 141 0.1515 0.1732 + 0.1019 142 0.1514 0.1731 + 0.0934 143 0.1512 0.1731 0.0951 144 0.1511 0.1731 + 0.0978 145 0.1510 0.1731 + 0.0929 146 0.1510 0.1731 + 0.1009 147 0.1509 0.1731 + 0.0925 148 0.1509 0.1730 + 0.0948 149 0.1508 0.1730 + 0.0933 150 0.1508 0.1730 + 0.0919
pred.train_uncertainty()
epoch train_loss valid_loss cp dur ------- ------------ ------------ ---- ------ 1 243.0562 -106.4027 + 0.1200 2 -646.7032 -1268.8473 + 0.1008 3 -1828.0494 -2154.9756 + 0.1004 4 -2527.2691 -2536.9697 + 0.0951 5 -2826.8179 -2685.8429 + 0.0972 6 -2954.9312 -2753.2275 + 0.1068 7 -3019.6931 -2791.9428 + 0.1022 8 -3059.4041 -2817.2525 + 0.1011 9 -3086.9265 -2834.6661 + 0.0957 10 -3107.3336 -2847.2882 + 0.0945 11 -3123.2371 -2857.1242 + 0.0979 12 -3136.0263 -2865.3026 + 0.0938 13 -3146.6214 -2872.5512 + 0.0950 14 -3155.5494 -2878.8792 + 0.1004 15 -3163.4479 -2884.4436 + 0.0964 16 -3170.2819 -2889.1628 + 0.0956 17 -3176.1572 -2893.1781 + 0.0993 18 -3181.2677 -2896.5506 + 0.0978 19 -3185.9359 -2899.4992 + 0.0959 20 -3190.0292 -2902.2631 + 0.0949 21 -3193.6522 -2904.8469 + 0.1009 22 -3196.9657 -2907.0520 + 0.0956 23 -3199.9758 -2908.9954 + 0.0961 24 -3202.6706 -2910.8420 + 0.0990 25 -3205.0487 -2912.4327 + 0.0988 26 -3207.3861 -2913.7792 + 0.0953 27 -3209.2724 -2915.1478 + 0.0970 28 -3211.1166 -2916.1755 + 0.0949 29 -3212.7076 -2917.0117 + 0.1010 30 -3214.2191 -2917.9324 + 0.0978 31 -3215.5549 -2918.7651 + 0.0972 32 -3216.7740 -2919.5100 + 0.0923 33 -3217.8496 -2920.2501 + 0.0956 34 -3218.8195 -2920.8881 + 0.0997 35 -3219.7173 -2921.4086 + 0.0970 36 -3220.4751 -2921.9233 + 0.0957 37 -3221.1644 -2922.3357 + 0.0981 38 -3221.7768 -2922.7467 + 0.0965 39 -3222.3233 -2923.0637 + 0.1053 40 -3222.7671 -2923.3394 + 0.0980 41 -3223.1699 -2923.5700 + 0.1221 42 -3223.5107 -2923.7631 + 0.0956 43 -3223.7875 -2923.9434 + 0.0984 44 -3224.0184 -2924.0480 + 0.0996 45 -3224.2026 -2924.1593 + 0.0986 46 -3224.3494 -2924.2309 + 0.0939 47 -3224.4520 -2924.2832 + 0.0955 48 -3224.5300 -2924.3175 + 0.0995 49 -3224.5810 -2924.3343 + 0.1013 50 -3224.6094 -2924.3422 + 0.0956
pred.reg.test(tset='train')
PACS_70 0.205549 PACS_100 0.176100 PACS_160 0.151044 SPIRE_250 0.161366 SPIRE_350 0.177310 SPIRE_500 0.192000 Name: rmse, dtype: float64
# Store all data model
# ModelStore().store(pred, name='nnet_alldata')