using JLD
using Glob
using DataFrames
using Query
using Plots
using StatPlots
WARNING: Method definition describe(AbstractArray) in module StatsBase at /home/ubuntu/.julia/v0.5/StatsBase/src/scalarstats.jl:573 overwritten in module DataFrames at /home/ubuntu/.julia/v0.5/DataFrames/src/abstractdataframe/abstractdataframe.jl:407. WARNING: Method definition require(Symbol) in module Base at loading.jl:345 overwritten in module Query at /home/ubuntu/.julia/v0.5/Requires/src/require.jl:12. INFO: Recompiling stale cache file /home/ubuntu/.julia/lib/v0.5/Ratios.ji for module Ratios. INFO: Recompiling stale cache file /home/ubuntu/.julia/lib/v0.5/KernelDensity.ji for module KernelDensity.
using MLLabelUtils
using ColoringNames
# Switch to the ColoringNames package directory; the `models/...` paths used below are relative.
cd(Pkg.dir("ColoringNames"))
"""
    add_run!(all_runs, filename, results_field="validation_set_results")

Append one row to `all_runs` built from the file at `filename`.

The file is read with `load`. Every top-level entry whose key, converted to a
`Symbol`, is one of the column names of `all_runs` is copied into the row. The
entries stored under `results_field` are then merged into the row, overwriting
any key already present. The row is pushed onto `all_runs`, which is returned.
"""
function add_run!(all_runs, filename, results_field="validation_set_results")
    contents = load(filename)
    columns = names(all_runs)

    record = Dict{Symbol,Any}()
    for (key, value) in contents
        colname = Symbol(key)
        if colname in columns
            record[colname] = value
        end
    end

    merge!(record, contents[results_field])
    push!(all_runs, record)
    return all_runs
end
# Empty table with one typed column per value recorded for a hyperparameter-validation run.
all_runs = DataFrame(
    embedding_dim = Int[],
    hidden_layer_size = Int[],
    perp = Float64[],
    mse_to_peak = Float64[],
    perp_sat = Float64[],
    perp_hue = Float64[],
    perp_val = Float64[],
    perp_uniform_baseline = Float64[])

# One row per `meta_v2.jld` file matched by the glob pattern.
foreach(glob(glob"models/hyperparam_validation/*/meta_v2.jld")) do meta_file
    add_run!(all_runs, meta_file, "results_validation_set")
end

# Order the runs by ascending `perp`.
sort!(all_runs, cols = :perp)
| | embedding_dim | hidden_layer_size | perp | mse_to_peak | perp_sat | perp_hue | perp_val | perp_uniform_baseline |
---|---|---|---|---|---|---|---|---|
1 | 32 | 256 | 26.90898256259195 | 0.14335936595135218 | 41.338195124626694 | 15.148114275611267 | 31.115855655806715 | 99048.01581624868 |
2 | 16 | 256 | 26.934645903802632 | 0.1444501937216382 | 41.38137921505924 | 15.166804764188534 | 31.13398800536437 | 99048.01581624868 |
3 | 64 | 256 | 26.93810838672824 | 0.1435878074284456 | 41.38600244458802 | 15.179770301362238 | 31.115917331142004 | 99048.01581624868 |
4 | 16 | 128 | 27.07751590109371 | 0.1436052077751108 | 41.539098633945464 | 15.28917015916051 | 31.259748246285312 | 99048.01581624868 |
5 | 64 | 128 | 27.08713107740773 | 0.14461119554545107 | 41.53503833758362 | 15.298638181725666 | 31.27675144761631 | 99048.01581624868 |
6 | 32 | 128 | 27.101861094032508 | 0.14373450283174846 | 41.5312461560398 | 15.321616335439217 | 31.283677463125095 | 99048.01581624868 |
7 | 3 | 256 | 27.367691498332302 | 0.14850155636478612 | 41.90608583915064 | 15.45736524655851 | 31.644768982260345 | 108545.01036974895 |
8 | 32 | 64 | 27.66846077524411 | 0.14921134498550873 | 42.18606434059774 | 15.720754208582974 | 31.938359893759444 | 99048.01581624868 |
9 | 64 | 64 | 27.793579624983007 | 0.1604225611077853 | 42.4715737738912 | 15.729653856244925 | 32.13778448180776 | 99048.01581624868 |
10 | 3 | 128 | 28.110542420488045 | 0.16816509831629103 | 42.62920543677656 | 15.952837825101788 | 32.66348272898015 | 108545.01036974895 |
11 | 16 | 64 | 29.222100228035952 | 0.19566695442802176 | 45.18248168781976 | 16.108634654991064 | 34.28510350692026 | 108545.01036974895 |
12 | 3 | 64 | 29.989463482194843 | 0.21400501074668526 | 44.99036196180689 | 16.763826996890938 | 35.76131096886941 | 108545.01036974895 |
13 | 32 | 32 | 30.83811331589243 | 0.22612610471616135 | 46.31343995981237 | 17.799671787846876 | 35.57495552385066 | 99048.01581624868 |
14 | 64 | 32 | 31.57827575383564 | 0.26116269690721206 | 47.905068360615395 | 17.572466875078728 | 37.406841629343184 | 99048.01581624868 |
15 | 3 | 32 | 33.500923222461594 | 0.29184005803662366 | 50.308574577902576 | 18.902133141734694 | 39.53825484523089 | 108545.01036974895 |
16 | 16 | 32 | 33.55135099587924 | 0.27066911300530655 | 49.154345269138005 | 19.372683016083638 | 39.6623414487058 | 108545.01036974895 |
# Rebind `all_runs` to a fresh, empty table for the spread-validation runs.
all_runs = DataFrame(
    splay_std_dev_in_bins = Float64[],
    perp = Float64[],
    mse_to_peak = Float64[],
    perp_sat = Float64[],
    perp_hue = Float64[],
    perp_val = Float64[],
    perp_uniform_baseline = Float64[])

# One row per `params.jld` file matched by the glob pattern.
foreach(glob(glob"models/spread_validation/*/params.jld")) do params_file
    add_run!(all_runs, params_file, "validation_set_results")
end

# The sort by `perp` is disabled here, so rows keep the order in which they were added.
#sort!(all_runs; cols=:perp)
all_runs
# Empty table for the `noml_validation` runs; it also records `output_res`.
noml_runs = DataFrame(
    splay_std_dev_in_bins = Float64[],
    output_res = Int[],
    perp = Float64[],
    mse_to_peak = Float64[],
    perp_sat = Float64[],
    perp_hue = Float64[],
    perp_val = Float64[],
    perp_uniform_baseline = Float64[])

# One row per `params_with_model.jld` file matched by the glob pattern.
foreach(glob(glob"models/noml_validation/*/params_with_model.jld")) do params_file
    add_run!(noml_runs, params_file, "validation_set_results")
end

# Order the runs by ascending `mse_to_peak`, then display the table.
sort!(noml_runs, cols = :mse_to_peak)
noml_runs
| | splay_std_dev_in_bins | output_res | perp | mse_to_peak | perp_sat | perp_hue | perp_val | perp_uniform_baseline |
---|---|---|---|---|---|---|---|---|
1 | 0.0625 | 32 | 13.919446644910431 | 0.1259636684991798 | 20.988543753244947 | 8.075007506919757 | 15.912582567625028 | 108545.01036974895 |
2 | 0.03125 | 32 | 13.930073330236636 | 0.12597586871485716 | 20.99190035877263 | 8.091090174902417 | 15.914808244689898 | 108545.01036974895 |
3 | 0.125 | 32 | 13.899763572334903 | 0.1271359672483436 | 20.98238824055258 | 8.045428811341461 | 15.908093122733215 | 108545.01036974895 |
4 | 0.25 | 32 | 13.866945348309283 | 0.12716111808947902 | 20.972304453509388 | 7.996363722400742 | 15.900241399954785 | 108545.01036974895 |
5 | 0.5 | 32 | 13.827639771977285 | 0.12785972850960312 | 20.9607556078467 | 7.938636702028436 | 15.888807713549932 | 108545.01036974895 |
6 | 1.0 | 32 | 13.855264943822736 | 0.13289876042655735 | 20.970472924689087 | 7.985933664110563 | 15.882197106086828 | 108545.01036974895 |
7 | 0.03125 | 64 | Inf | 0.1344654230541015 | Inf | Inf | Inf | 108545.01036974895 |
8 | 0.0625 | 64 | Inf | 0.1345590103947113 | Inf | Inf | Inf | 108545.01036974895 |
9 | 0.125 | 64 | Inf | 0.13459408004710655 | Inf | Inf | Inf | 108545.01036974895 |
10 | 0.25 | 64 | Inf | 0.1346025155196911 | Inf | Inf | Inf | 108545.01036974895 |
11 | 0.5 | 64 | Inf | 0.1361606358093347 | Inf | Inf | Inf | 108545.01036974895 |
12 | 1.0 | 64 | Inf | 0.13866139818291154 | Inf | Inf | Inf | 108545.01036974895 |
13 | 0.125 | 128 | Inf | 0.14122297853358728 | Inf | Inf | Inf | 108545.01036974895 |
14 | 0.25 | 128 | Inf | 0.14126555156452594 | Inf | Inf | Inf | 108545.01036974895 |
15 | 0.03125 | 128 | Inf | 0.141305511030235 | Inf | Inf | Inf | 108545.01036974895 |
16 | 0.0625 | 128 | Inf | 0.14135078952942265 | Inf | Inf | Inf | 108545.01036974895 |
17 | 2.0 | 32 | 14.34923140186697 | 0.14243460708581576 | 21.109303149726166 | 8.78490203350346 | 15.932176106367118 | 108545.01036974895 |
18 | 0.5 | 128 | Inf | 0.14269063302309423 | Inf | Inf | Inf | 108545.01036974895 |
19 | 2.0 | 64 | Inf | 0.1439852205818157 | Inf | Inf | Inf | 108545.01036974895 |
20 | 1.0 | 128 | Inf | 0.14463982309288878 | Inf | Inf | Inf | 108545.01036974895 |
21 | 2.0 | 128 | Inf | 0.14608109625516497 | Inf | Inf | Inf | 108545.01036974895 |
22 | 0.03125 | 256 | Inf | 0.14832759868050543 | Inf | Inf | Inf | 108545.01036974895 |
23 | 0.0625 | 256 | Inf | 0.14833133469491117 | Inf | Inf | Inf | 108545.01036974895 |
24 | 0.25 | 256 | Inf | 0.14846922251124808 | Inf | Inf | Inf | 108545.01036974895 |
25 | 0.125 | 256 | Inf | 0.14849634787865684 | Inf | Inf | Inf | 108545.01036974895 |
26 | 0.5 | 256 | Inf | 0.148670980397359 | Inf | Inf | Inf | 108545.01036974895 |
27 | 1.0 | 256 | Inf | 0.15048794916524227 | Inf | Inf | Inf | 108545.01036974895 |
28 | 4.0 | 128 | Inf | 0.15143878560665758 | Inf | Inf | Inf | 108545.01036974895 |
29 | 2.0 | 256 | Inf | 0.15146099973363153 | Inf | Inf | Inf | 108545.01036974895 |
30 | 4.0 | 256 | Inf | 0.15243308598558855 | Inf | Inf | Inf | 108545.01036974895 |
⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
# Load the saved contents of a single run (directory `sib0.5_or64`) for closer inspection.
m1 = load("models/noml_validation/sib0.5_or64/params_with_model.jld")
Dict{String,Any} with 7 entries: "git_hash" => "e50d814a6e3cae536b857849beadc246fe608725" "validation_set_results" => Dict(:perp_uniform_baseline=>108545.0,:perp=>Inf,… "executing_file" => "/mnt_volume/julia_dir/v0.5/ColoringNames/expr/ru… "model" => ColoringNames.TermToColorDistributionEmpirical{3}… "splay_std_dev" => 0.0078125 "output_res" => 64 "splay_std_dev_in_bins" => 0.5
# Pull the stored model object out of the loaded dictionary.
mdl = m1["model"]
ColoringNames.TermToColorDistributionEmpirical{3}(64,Dict("grass green"=>(Float32[3.46681f-22,1.47001f-18,2.34105f-15,1.41428f-12,3.30064f-10,3.07994f-8,1.22468f-6,2.30046f-5,0.000230938,0.0013672 … 0.0,0.0,0.0,0.0,0.0,0.0,5.4844f-41,1.21672f-35,1.00036f-30,3.053f-26],Float32[9.88503f-31,3.01526f-26,3.41744f-22,1.44233f-18,2.27352f-15,1.3438f-12,2.99453f-10,2.5342f-8,8.22183f-7,1.0343f-5 … 0.0382473,0.0380517,0.0370082,0.0365966,0.0381575,0.0397527,0.039755,0.0414449,0.0452282,0.0429944],Float32[5.15958f-42,1.39012f-36,1.39103f-31,5.17924f-27,7.19261f-23,3.73705f-19,7.29301f-16,5.37361f-13,1.50527f-10,1.61839f-8 … 0.0332029,0.0308241,0.0269464,0.02304,0.0202082,0.0181884,0.0163083,0.0142603,0.0128536,0.0105627]),"celery"=>(Float32[1.68503f-17,2.77729f-14,1.71639f-11,4.00065f-9,3.54985f-7,1.22147f-5,0.000171466,0.00115525,0.0052838,0.0191129 … 0.0,0.0,0.0,0.0,1.4013f-45,5.0542f-40,1.1736f-34,1.00949f-29,3.22077f-25,3.81779f-21],Float32[1.79162f-9,1.94243f-7,8.22826f-6,0.000139244,0.000975536,0.00307057,0.00535714,0.0070853,0.00860261,0.0102388 … 0.00701873,0.00553211,0.00730939,0.00868375,0.00733508,0.00648898,0.00802778,0.00956066,0.00831466,0.00481498],Float32[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 … 0.04718,0.0476924,0.0559202,0.0619617,0.0690172,0.0799618,0.0824976,0.0794836,0.0827476,0.0757195]),"lipstick"=>(Float32[0.076896,0.0325507,0.00924846,0.00142477,9.95811f-5,2.88291f-6,3.31354f-8,1.47903f-10,2.53017f-13,1.6401f-16 … 0.0134629,0.018935,0.0251707,0.0388219,0.0571045,0.0867097,0.139425,0.180491,0.176712,0.134063],Float32[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 … 0.0388251,0.0331862,0.0267442,0.0304258,0.0429207,0.0482906,0.0425362,0.0355146,0.0341751,0.0289133],Float32[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 … 0.0423933,0.0421078,0.0435519,0.0367886,0.0287822,0.031865,0.0380646,0.0436181,0.044922,0.0369993]),"chocolate"=>(Float32[0.0882984,0.109486,0.132059,0.142491,0.126044,0.0932202,0.0606339,0.0345131,0.0158982,0.00549846 … 
0.000877097,0.00128017,0.00232887,0.00441245,0.00780599,0.0113218,0.0158478,0.0254929,0.0434798,0.0674071],Float32[2.07756f-8,1.08833f-6,2.25156f-5,0.000191979,0.000732207,0.00147511,0.00202866,0.00206655,0.00126088,0.000539836 … 0.0239094,0.0222666,0.0245187,0.027089,0.0257464,0.0241908,0.0257125,0.0272286,0.0290488,0.0298928],Float32[8.18755f-11,1.24108f-8,7.26261f-7,1.69979f-5,0.00017212,0.000873755,0.00263357,0.0053411,0.00876873,0.0131248 … 0.000951814,0.000763513,0.000627662,0.000633281,0.000610054,0.000281021,5.23169f-5,3.82566f-6,1.08377f-7,1.17596f-9]),"pine green"=>(Float32[6.69423f-22,3.18575f-18,5.67539f-15,3.81012f-12,9.74517f-10,9.67427f-8,3.84883f-6,6.47931f-5,0.000503661,0.0020527 … 0.0,0.0,0.0,0.0,0.0,0.0,6.60068f-41,1.64934f-35,1.52681f-30,5.24362f-26],Float32[5.37808f-31,2.1754f-26,3.27558f-22,1.84083f-18,3.87439f-15,3.06769f-12,9.19233f-10,1.05059f-7,4.62577f-6,7.94501f-5 … 0.0341524,0.0336478,0.0310093,0.0306545,0.0302416,0.0295403,0.0286218,0.0262724,0.0266375,0.0270279],Float32[1.87205f-13,7.85436f-11,1.27489f-8,8.16143f-7,2.11177f-5,0.000227367,0.00105462,0.0022399,0.00267908,0.00328049 … 0.00181737,0.00174439,0.00149785,0.000848706,0.000478609,0.00121371,0.00255855,0.00274035,0.00167859,0.000611219]),"dark brown"=>(Float32[0.071476,0.081891,0.0898502,0.0934314,0.0945634,0.0884445,0.0742154,0.0565342,0.0385011,0.0235102 … 0.00266697,0.00467186,0.00719851,0.0102954,0.0151188,0.0219299,0.028107,0.0354671,0.0481991,0.0616601],Float32[0.000471385,0.000663127,0.000524629,0.000457729,0.000394906,0.000442707,0.00115227,0.00197361,0.00186314,0.00160463 … 0.0267686,0.0262669,0.0244771,0.0243011,0.0242803,0.0231389,0.0228862,0.0253964,0.0323548,0.0391965],Float32[0.000199822,0.000427043,0.000929838,0.00230669,0.00532991,0.00989786,0.0150463,0.0213858,0.0288437,0.0358879 … 0.000168965,6.03108f-5,8.50434f-6,4.68444f-7,9.95933f-9,8.08561f-11,2.4853f-13,2.87338f-16,1.16804f-19,0.0]),"twilight blue"=>(Float32[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 … 
0.000274638,1.83674f-5,4.75267f-7,4.70566f-9,1.76675f-11,2.49812f-14,1.28552f-17,0.0,0.0,0.0],Float32[4.4695f-12,1.50766f-9,1.94744f-7,9.78489f-6,0.00019525,0.00158591,0.00537416,0.00773836,0.00476406,0.00124382 … 0.0288614,0.0209865,0.0166133,0.0182791,0.0237345,0.0265354,0.0252637,0.025438,0.0229478,0.0125252],Float32[6.72395f-30,2.41205f-25,3.21859f-21,1.60237f-17,2.98895f-14,2.1016f-11,5.61787f-9,5.77752f-7,2.32245f-5,0.000372297 … 0.0260427,0.025851,0.0238042,0.0193153,0.0166621,0.0184081,0.0221719,0.0229999,0.0210574,0.0205289]),"light olive green"=>(Float32[4.06077f-13,1.50465f-10,2.11468f-8,1.14588f-6,2.49345f-5,0.000244022,0.00137072,0.0055195,0.0164126,0.038711 … 0.0,0.0,0.0,4.44212f-43,1.65691f-37,2.28558f-32,1.16865f-27,2.21822f-23,1.56597f-19,4.12229f-16],Float32[2.50659f-15,2.05024f-12,6.36863f-10,7.58042f-8,3.50128f-6,6.39533f-5,0.000477005,0.00155679,0.00268868,0.00367874 … 0.0130746,0.0125934,0.0124765,0.0133607,0.0154628,0.0160465,0.0129417,0.00890837,0.00795897,0.00782123],Float32[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.25101f-43 … 0.0385652,0.0360822,0.030239,0.0231902,0.0193576,0.0173853,0.0145774,0.0108553,0.0090812,0.00822056]),"heliotrope"=>(Float32[0.00398165,0.00641567,0.00572263,0.00311894,0.00257727,0.00249873,0.00113292,0.00020582,1.54418f-5,2.45312f-5 … 0.0793904,0.063045,0.0476435,0.0335383,0.0241015,0.0161161,0.00939483,0.0050472,0.00182605,0.00155534],Float32[8.96063f-7,2.68344f-5,0.000311915,0.00142424,0.00258105,0.00190477,0.00106296,0.00264512,0.00506033,0.00405047 … 0.0166903,0.0197962,0.0232766,0.0226476,0.0169195,0.0136662,0.0140698,0.0149134,0.0181363,0.0187987],Float32[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0 … 0.0590723,0.0607611,0.0637313,0.0649608,0.0599012,0.0536555,0.0580984,0.0687417,0.0706833,0.0589098]),"mud brown"=>(Float32[0.0219354,0.0461085,0.0736861,0.0922347,0.102704,0.123414,0.152776,0.154319,0.11172,0.0549519 … 
7.47119f-11,1.39153f-8,9.87167f-7,2.70319f-5,0.000292591,0.00131789,0.00279755,0.00370773,0.00501496,0.00951382],Float32[7.25553f-30,2.4137f-25,2.98323f-21,1.37295f-17,2.36018f-14,1.5223f-11,3.70871f-9,3.44864f-7,1.24491f-5,0.000179527 … 0.0271173,0.0241177,0.0225199,0.024522,0.0283274,0.0286492,0.0218187,0.0147949,0.0169549,0.026815],Float32[2.9575f-12,9.31273f-10,1.11051f-7,5.05847f-6,8.89719f-5,0.00061179,0.00167121,0.0019424,0.00160498,0.00210686 … 0.000707617,0.000438167,0.0013179,0.0019386,0.00114011,0.00026613,2.44217f-5,8.70586f-7,1.19171f-8,6.2031f-11])…))
# One-element range, so the evaluation below covers exactly one row.
# NOTE(review): `valid_text` and `valid_hsv` are not defined in the visible part of this file.
rng = 70_955:70_955
evaluate(mdl, valid_text[rng], valid_hsv[rng,:])
Dict{Symbol,Float64} with 6 entries: :perp_uniform_baseline => 1.0 :perp => Inf :mse_to_peak => 0.845673 :perp_sat => 308.476 :perp_hue => Inf :perp_val => 1.98883e14
# List the methods defined for `descretized_perplexity`.
methods(descretized_perplexity)
# Query the model for one entry of `valid_text`; the result unpacks into three arrays.
# NOTE(review): this indexes 70_954 while `rng` is 70_955 — confirm the one-off difference is intended.
hh,ss,vv = query(mdl, valid_text[70_954])
(Float32[0.0397098,0.0573034,0.0823717,0.106618,0.126799,0.134196,0.119233,0.0979447,0.0751083,0.0492329 … 0.00030164,0.000798041,0.000964212,0.0009284,0.00104436,0.00139643,0.00359676,0.00854104,0.0173626,0.0287236],Float32[2.6975f-6,6.22119f-5,0.000573984,0.00221708,0.00396639,0.0043194,0.00485628,0.00661597,0.00818797,0.00825639 … 0.00331553,0.00266177,0.00139798,0.000433133,0.000703525,0.00247234,0.00435376,0.00395598,0.00195483,0.00049983],Float32[0.0,3.72213f-41,1.00472f-35,1.00448f-30,3.72406f-26,5.12819f-22,2.62839f-18,5.02807f-15,3.60355f-12,9.72514f-10 … 0.0207113,0.0224995,0.0264417,0.0259635,0.0192976,0.0129981,0.00968241,0.00615069,0.00228175,0.000401225])
# Presumably maps column 1 of `valid_hsv` (rows `rng`) to a bin index out of `length(hh)` bins — confirm against `find_bin`.
find_bin(valid_hsv[rng,1], length(hh))
1-element Array{Int64,1}: 38
# Look up entry 38 of `hh`, the index printed by the preceding `find_bin` call.
hh[38]
0.0f0