In [330]:
#read in data from csv file
#using Pkg
#Pkg.add("CSV")
using CSV
# CSV File with oxyegn adsoprtion energies, columns are crystal number, crystal chemical formula
# adsorption site and oxyegn adsoprtion energies
#Oads=CSV.File("/Users/yusuliu/Downloads/18.337_2018/project/3colOads.csv") 
Oads=CSV.File("3colOads.csv")
Out[330]:
CSV.File("3colOads.csv", rows=1000):
Tables.Schema:
 :Count  Union{Missing, Int64}  
 :Name   Union{Missing, String} 
 :Site   Union{Missing, String} 
 :E      Union{Missing, Float64}
In [332]:
# CSV File with data from the periodic table
# columns are element name, A site or B site element in a perovskite, atomic number, period
# covalent radius, d-electron count, electronegativity, most common oxidation state
# polarizability, ionization energy
PD=CSV.File("PeriodicData.csv")
Out[332]:
CSV.File("PeriodicData.csv", rows=28):
Tables.Schema:
 :Element          Union{Missing, String} 
 :AorB             Union{Missing, String} 
 :Z                Union{Missing, Int64}  
 :period           Union{Missing, Int64}  
 :covalent_radius  Union{Missing, Float64}
 :d_elect          Union{Missing, Int64}  
 :Eneg             Union{Missing, Float64}
 :Ox               Union{Missing, Float64}
 :Pol              Union{Missing, Float64}
 :IE               Union{Missing, Float64}
In [334]:
#Make dictionaries from Periodic date
AorB =Dict{String,String}()
Z=Dict{String,Int64}()
period=Dict{String,Int64}()
radius=Dict{String,Float64}()
d_elect=Dict{String,Int64}()
Eneg=Dict{String,Float64}()
Ox=Dict{String,Float64}()
Pol=Dict{String,Float64}()
IE=Dict{String,Float64}()

for row in PD
    AorB[row.Element]=row.AorB
    Z[row.Element]=row.Z
    period[row.Element]=row.period
    radius[row.Element]=row.covalent_radius
    d_elect[row.Element]=row.d_elect
    Eneg[row.Element]=row.Eneg
    Ox[row.Element]=row.Ox
    Pol[row.Element]=row.Pol
    IE[row.Element]=row.IE
end
In [335]:
# functions to parse a chemical formula 

# get captical letters which are the indicator for the start of a new element
function getcaps(string)
    pos=[]
    for (index, letter) in enumerate(string)
        if (letter==uppercase(letter))
            append!(pos,index)
        end
    end
    return pos
end

# determine if an element is A or B site
function getelements(string)
    set=getcaps(string)
    j=length(set)
    element_list_A=String[]
    element_list_B=String[]
    for i in 1:j-2
        Astart=set[i]
        Bstart=set[i+1]
        new_el=string[Astart:(Bstart-1)]
        if AorB[new_el]=="A"
            push!(element_list_A, String(new_el))
        elseif AorB[new_el]=="B"
            push!(element_list_B, String(new_el))
        end 
    end
    return element_list_A, element_list_B
end

# function to vectorize each element
function makevec(element)
    vec=zeros(8)
    vec[1]=Z[element]
    vec[2]=period[element]
    vec[3]=radius[element]
    vec[4]=d_elect[element]
    vec[5]=Eneg[element]
    vec[6]=Ox[element]
    vec[7]=Pol[element]
    vec[8]=IE[element]
    return vec
end
Out[335]:
makevec (generic function with 1 method)
In [336]:
# make input array, 24 attributes per cyrstal, 1000 crystals
using Statistics
input_mat=zeros(24,1000)
for i=1:1003
    for row in Oads
        if (row.Count)==i
            set_A,set_B=getelements(row.Name)
            size_A=length(set_A)
            size_B=length(set_B)
            vecsA=[]
            vecsB=[]
            for i in 1:size_A
                newvec=makevec((set_A[i]))
                push!(vecsA,newvec)
            end
            for i in 1:size_B
                newvec=makevec((set_B[i]))
                push!(vecsB,newvec)
            end
            input_mat[1:8,i]=mean(vecsA)
            input_mat[9:16,i]=mean(vecsB)
            input_mat[17:24,i]=makevec(row.Site)
        end
    end
end
display("text/plain",input_mat)
24×1000 Array{Float64,2}:
 59.5    60.0     61.0    61.0    …  60.0     61.0    60.0     60.0   
  6.0     6.0      6.0     6.0        6.0      6.0     6.0      6.0   
  1.915   1.8      1.63    1.63       1.8      1.63    1.8      1.8   
  0.0     0.5      0.0     0.0        0.5      0.0     0.5      0.5   
  1.045   1.045    1.13    1.13       1.045    1.13    1.045    1.045 
  2.25    2.5      3.0     3.0    …   2.5      3.0     2.5      2.5   
 33.7    31.6     30.1    30.1       31.6     30.1    31.6     31.6   
  5.438   5.6755   5.554   5.554      5.6755   5.554   5.6755   5.6755
 41.0    44.0     25.0    44.0       40.0     29.0    25.0     46.0   
  5.0     5.0      4.0     5.0        5.0      4.0     4.0      5.0   
  1.34    1.25     1.17    1.25   …   1.25     1.17    1.17     1.28  
  3.0     6.0      5.0     6.0        2.0      9.0     5.0      8.0   
  1.6     2.2      1.55    2.2        1.33     1.9     1.55     2.2   
  4.0     4.6      4.4     4.6        4.0      1.5     4.4      3.0   
 15.7     9.6      9.4     9.6       17.9      6.2     9.4      4.8   
  6.76    7.36     7.43    7.36   …   6.64     7.73    7.43     8.34  
 41.0    44.0     25.0    44.0       40.0     29.0    25.0     46.0   
  5.0     5.0      4.0     5.0        5.0      4.0     4.0      5.0   
  1.34    1.25     1.17    1.25       1.25     1.17    1.17     1.28  
  3.0     6.0      5.0     6.0        2.0      9.0     5.0      8.0   
  1.6     2.2      1.55    2.2    …   1.33     1.9     1.55     2.2   
  4.0     4.6      4.4     4.6        4.0      1.5     4.4      3.0   
 15.7     9.6      9.4     9.6       17.9      6.2     9.4      4.8   
  6.76    7.36     7.43    7.36       6.64     7.73    7.43     8.34  
In [338]:
# make target vector: oxygen adsorption energies
target=Float64[]
for row in Oads
    append!(target,row.E)
end

y=zeros(1,1000)
for i in 1:1000
    y[1,i]=target[i]
end
display("text/plain",y)
1×1000 Array{Float64,2}:
 -20.9685  -18.3228  -16.6349  …  2.60126  7.46994  7.78974  8.87288
In [339]:
# normalize x input array

using Statistics, Random
x=copy(input_mat)
x_norm=copy(x)
for i=1:size(x,1)
    m=mean(x[i,:])
    stdev=std(x[i,:])
    for j=1:size(x,2)
        x_norm[i,j]=(x[i,j]-m)/stdev
    end
end
In [340]:
# randomize and seperate into 700 training points and 300 test points

index=randperm(Random.seed!(1),size(x,2))
xtrn=x_norm[:,index[1:700]]
xtst=x_norm[:,index[701:end]]
ytrn=y[:,index[1:700]]
ytst=y[:,index[701:end]]
Out[340]:
1×300 Array{Float64,2}:
 -2.80295  -5.24203  -5.29491  -7.87468  …  -4.73293  -3.69288  -9.0944
In [341]:
# make weight matrix
using Pkg
#Pkg.add("Distributions")
using Distributions
d=Normal()
weights=rand(d,24)*0.1
w=(weights,[0.0])
Out[341]:
([0.0635889, -0.185359, 0.0810713, 0.0158052, 0.0580127, -0.011097, -0.0968496, -0.0393294, -0.0243981, -0.0775562  …  0.114181, -0.165903, -0.12511, -0.181291, 0.00803023, 0.0067685, -0.00588284, 0.0341452, 0.0741828, -0.202932], [0.0])
In [342]:
# function to take in weight and input, gives output prediction vector
function predict_lr(w_mat, att_mat) # takes weight matrix and attributes matrix as arguments
    v_out=zeros(size(att_mat,2)) # initialize output matrix
    v_out=w_mat[1]'*att_mat.+w_mat[2]
    return v_out
end
Out[342]:
predict_lr (generic function with 1 method)
In [343]:
# function to calculate mean absolute error
function MAE_loss(w,x,y) # takes in weight, input matrix and ground truth
    ypred=predict_lr(w,x) 
    summ=sum(abs(ypred[i]-y[i]) for i=1:size(y,2))
    J=summ/(size(y,2))
    return J
end
Out[343]:
MAE_loss (generic function with 1 method)
In [344]:
# training function using autograd and knet
#Pkg.add("AutoGrad")
#Pkg.add("Knet")
using AutoGrad, Knet

function train(ww,lr, epochs)   
    lossgradient=grad(MAE_loss)
    println((0, :trnloss, MAE_loss(ww,xtrn,ytrn), :tstloss, MAE_loss(ww,xtst,ytst)))
    for epoch=1:epochs        
        dw=lossgradient(ww, xtrn, ytrn) 
      
        for i in 1:length(ww)   
            for j in 1:length(ww[i])
                ww[i][j] -=lr*dw[i][j]
            end
        end 
        println((epoch, :trnloss, MAE_loss(ww,xtrn,ytrn), :tstloss, MAE_loss(ww,xtst,ytst)))
    end
    return ww
end
Out[344]:
train (generic function with 1 method)
In [346]:
# 0.27 learning rate, 200 epochs
d=Normal()
weights=rand(d,24)*0.1
w=(weights,[0.0])
train(w,0.27,200)
println(w)
(0, :trnloss, 5.812201778528555, :tstloss, 5.783835376161)
(1, :trnloss, 5.548019689315449, :tstloss, 5.525624505210485)
(2, :trnloss, 5.285564574203059, :tstloss, 5.267626754794884)
(3, :trnloss, 5.025705015126596, :tstloss, 5.011697007088098)
(4, :trnloss, 4.766843824287708, :tstloss, 4.756326755896198)
(5, :trnloss, 4.5079826334488295, :tstloss, 4.5009565047043)
(6, :trnloss, 4.249121442609955, :tstloss, 4.245586253512402)
(7, :trnloss, 3.990260251771067, :tstloss, 3.9910634523070834)
(8, :trnloss, 3.7334727562295313, :tstloss, 3.7400313205768194)
(9, :trnloss, 3.4856369457724354, :tstloss, 3.496170285259635)
(10, :trnloss, 3.246019854309629, :tstloss, 3.2605256790940933)
(11, :trnloss, 3.01747829523752, :tstloss, 3.0315152785492496)
(12, :trnloss, 2.790565437978286, :tstloss, 2.803136889201372)
(13, :trnloss, 2.566320729814648, :tstloss, 2.5767744369662964)
(14, :trnloss, 2.3475822020259454, :tstloss, 2.351953680139021)
(15, :trnloss, 2.134820663170838, :tstloss, 2.1329710061845297)
(16, :trnloss, 1.9342131645893148, :tstloss, 1.9269479585603582)
(17, :trnloss, 1.759151054346452, :tstloss, 1.746296494649943)
(18, :trnloss, 1.6175475914693047, :tstloss, 1.6083689837941655)
(19, :trnloss, 1.4915446267642358, :tstloss, 1.474829795582169)
(20, :trnloss, 1.338413743408781, :tstloss, 1.3150650959980466)
(21, :trnloss, 1.257824434096353, :tstloss, 1.2295859571737142)
(22, :trnloss, 1.1950433039442825, :tstloss, 1.1677835205268698)
(23, :trnloss, 1.1474836875190355, :tstloss, 1.1185625280923903)
(24, :trnloss, 1.1116403313946643, :tstloss, 1.0848818667259228)
(25, :trnloss, 1.0829632900379171, :tstloss, 1.0595675622113558)
(26, :trnloss, 1.0559132352758378, :tstloss, 1.0406658153170094)
(27, :trnloss, 1.0346369613504578, :tstloss, 1.0326143841017148)
(28, :trnloss, 1.018326011994146, :tstloss, 1.0188314784178758)
(29, :trnloss, 1.0060920822313162, :tstloss, 1.014520708646363)
(30, :trnloss, 0.9957072795431668, :tstloss, 1.0038021534052404)
(31, :trnloss, 0.9854636843318866, :tstloss, 1.0011326087969499)
(32, :trnloss, 0.9770580228298781, :tstloss, 0.9989419851881501)
(33, :trnloss, 0.9697845661506997, :tstloss, 0.9929213837226206)
(34, :trnloss, 0.9632557952953695, :tstloss, 0.9886859043464633)
(35, :trnloss, 0.9567465433439317, :tstloss, 0.9845639675549508)
(36, :trnloss, 0.9517516707792248, :tstloss, 0.9791027541180768)
(37, :trnloss, 0.9466824935699027, :tstloss, 0.9782197994853673)
(38, :trnloss, 0.9433727303232518, :tstloss, 0.969751608101606)
(39, :trnloss, 0.9413193983161563, :tstloss, 0.9741614159622891)
(40, :trnloss, 0.9393445482518187, :tstloss, 0.9633312518085265)
(41, :trnloss, 0.9329633408555651, :tstloss, 0.96674306706153)
(42, :trnloss, 0.9316754974992817, :tstloss, 0.9578236158659043)
(43, :trnloss, 0.9260546460046806, :tstloss, 0.9619834823095225)
(44, :trnloss, 0.9247147520816122, :tstloss, 0.9526259275225698)
(45, :trnloss, 0.9206743719264446, :tstloss, 0.9555128493551066)
(46, :trnloss, 0.9183004385442103, :tstloss, 0.9484248412492055)
(47, :trnloss, 0.9158061249445144, :tstloss, 0.948606021411047)
(48, :trnloss, 0.9144384584664934, :tstloss, 0.9435082299125143)
(49, :trnloss, 0.9118400361644998, :tstloss, 0.9446883369122899)
(50, :trnloss, 0.9104181532832915, :tstloss, 0.9412306360846197)
(51, :trnloss, 0.9086730456453328, :tstloss, 0.9397553973460395)
(52, :trnloss, 0.9072309198356798, :tstloss, 0.938799276730259)
(53, :trnloss, 0.905133472754498, :tstloss, 0.9361864598614928)
(54, :trnloss, 0.9038033223314668, :tstloss, 0.9358538216963097)
(55, :trnloss, 0.9020564855188645, :tstloss, 0.9333331193342366)
(56, :trnloss, 0.9015738833383068, :tstloss, 0.9312535698549012)
(57, :trnloss, 0.901772479884398, :tstloss, 0.9357727014641125)
(58, :trnloss, 0.9023608752292492, :tstloss, 0.9307557318523484)
(59, :trnloss, 0.8993645332965414, :tstloss, 0.9320595537741407)
(60, :trnloss, 0.899727992134762, :tstloss, 0.9300669667671686)
(61, :trnloss, 0.8970229164129581, :tstloss, 0.9293904656097726)
(62, :trnloss, 0.8973679417748581, :tstloss, 0.9274744328162688)
(63, :trnloss, 0.895223301998018, :tstloss, 0.9278996945893331)
(64, :trnloss, 0.8963034512524829, :tstloss, 0.9271145450734709)
(65, :trnloss, 0.8944147690354284, :tstloss, 0.9267170360598891)
(66, :trnloss, 0.8963110868261079, :tstloss, 0.9256171765853588)
(67, :trnloss, 0.891194365786814, :tstloss, 0.9232637272928097)
(68, :trnloss, 0.8933699344158338, :tstloss, 0.9251740911961894)
(69, :trnloss, 0.8909748022319346, :tstloss, 0.9244642215721185)
(70, :trnloss, 0.8937784012771754, :tstloss, 0.921201047271471)
(71, :trnloss, 0.8934744078433794, :tstloss, 0.9271252515390798)
(72, :trnloss, 0.9048476557798962, :tstloss, 0.9263985161955287)
(73, :trnloss, 0.9080579560189628, :tstloss, 0.941317578697219)
(74, :trnloss, 0.911111108596098, :tstloss, 0.9313803915002338)
(75, :trnloss, 0.901401121043708, :tstloss, 0.9342012166889367)
(76, :trnloss, 0.9089807169917303, :tstloss, 0.9295881210365163)
(77, :trnloss, 0.9034237694411635, :tstloss, 0.937976945957796)
(78, :trnloss, 0.905617128563733, :tstloss, 0.9263545491011858)
(79, :trnloss, 0.904316128017374, :tstloss, 0.9389396430412243)
(80, :trnloss, 0.906005504374971, :tstloss, 0.9263284045315258)
(81, :trnloss, 0.9008815878486759, :tstloss, 0.9342181799620461)
(82, :trnloss, 0.9003981992114297, :tstloss, 0.9215154006312757)
(83, :trnloss, 0.9025570549712875, :tstloss, 0.9340091405469946)
(84, :trnloss, 0.9091903389510054, :tstloss, 0.9290800426075837)
(85, :trnloss, 0.9004879154299489, :tstloss, 0.9291762271543481)
(86, :trnloss, 0.9034975569645478, :tstloss, 0.9240875053705441)
(87, :trnloss, 0.9005968809840473, :tstloss, 0.9286153490487135)
(88, :trnloss, 0.9057306188675, :tstloss, 0.9255596047324892)
(89, :trnloss, 0.9024035834856496, :tstloss, 0.931018370350742)
(90, :trnloss, 0.9078659296984438, :tstloss, 0.927648599589352)
(91, :trnloss, 0.9006444885503765, :tstloss, 0.9295334107107888)
(92, :trnloss, 0.9057245629359066, :tstloss, 0.9248531241832167)
(93, :trnloss, 0.8995765493368879, :tstloss, 0.9268316593892992)
(94, :trnloss, 0.9060409085231667, :tstloss, 0.924840211343488)
(95, :trnloss, 0.9016066025965106, :tstloss, 0.9290301469966871)
(96, :trnloss, 0.9047244148175039, :tstloss, 0.9228329424526548)
(97, :trnloss, 0.8950392448523842, :tstloss, 0.920565456850984)
(98, :trnloss, 0.8997922400675108, :tstloss, 0.9200202435672892)
(99, :trnloss, 0.8920611028424618, :tstloss, 0.9183249796044098)
(100, :trnloss, 0.896633549655862, :tstloss, 0.917388600487293)
(101, :trnloss, 0.8913774935463937, :tstloss, 0.9189578366505468)
(102, :trnloss, 0.8992919470194204, :tstloss, 0.9205414246898737)
(103, :trnloss, 0.8939181529252447, :tstloss, 0.9214984074161001)
(104, :trnloss, 0.9013839809996885, :tstloss, 0.9218411085038055)
(105, :trnloss, 0.8965187791733782, :tstloss, 0.9241092735366077)
(106, :trnloss, 0.9037073783571279, :tstloss, 0.9241905080907876)
(107, :trnloss, 0.8963520807440422, :tstloss, 0.924392758092699)
(108, :trnloss, 0.9054687045712293, :tstloss, 0.9253626234939725)
(109, :trnloss, 0.895765055496697, :tstloss, 0.9222850379635137)
(110, :trnloss, 0.9019089114957232, :tstloss, 0.9222552136678543)
(111, :trnloss, 0.8939810937510294, :tstloss, 0.9215196131808002)
(112, :trnloss, 0.8935906202476159, :tstloss, 0.9150230932428648)
(113, :trnloss, 0.8912443868933895, :tstloss, 0.9189628016561814)
(114, :trnloss, 0.8986665320007957, :tstloss, 0.9201528083317748)
(115, :trnloss, 0.8990766283834624, :tstloss, 0.9259767721669023)
(116, :trnloss, 0.912466715350314, :tstloss, 0.9338175680193297)
(117, :trnloss, 0.8988626300872733, :tstloss, 0.9263299732423862)
(118, :trnloss, 0.9051381348341395, :tstloss, 0.9260349810227494)
(119, :trnloss, 0.9020091072896848, :tstloss, 0.9296097315534917)
(120, :trnloss, 0.9140111502361025, :tstloss, 0.9351220283875049)
(121, :trnloss, 0.8977931294813135, :tstloss, 0.9245333735042564)
(122, :trnloss, 0.9061069560192696, :tstloss, 0.9275009893657384)
(123, :trnloss, 0.8954650739414173, :tstloss, 0.9221065107845586)
(124, :trnloss, 0.9029533546241717, :tstloss, 0.9242108991255039)
(125, :trnloss, 0.8983704983312158, :tstloss, 0.9252369678988098)
(126, :trnloss, 0.9044606333378733, :tstloss, 0.9256870464211926)
(127, :trnloss, 0.9003310202749385, :tstloss, 0.9280289124367834)
(128, :trnloss, 0.9125232360653824, :tstloss, 0.9342835461049721)
(129, :trnloss, 0.8990769318904618, :tstloss, 0.9264635285795286)
(130, :trnloss, 0.9052974552095293, :tstloss, 0.9270230081898849)
(131, :trnloss, 0.8956290956347494, :tstloss, 0.9221462142786042)
(132, :trnloss, 0.9018343158116212, :tstloss, 0.9235526724687908)
(133, :trnloss, 0.8948713167851606, :tstloss, 0.9203400838725181)
(134, :trnloss, 0.9033117668783812, :tstloss, 0.9253848422333224)
(135, :trnloss, 0.896011341890243, :tstloss, 0.9223055926932253)
(136, :trnloss, 0.9087485227288742, :tstloss, 0.930488052859445)
(137, :trnloss, 0.8983520636741769, :tstloss, 0.9254762722394543)
(138, :trnloss, 0.9048187722726365, :tstloss, 0.9265497605029183)
(139, :trnloss, 0.8993454984721067, :tstloss, 0.9285673150970604)
(140, :trnloss, 0.906627995826929, :tstloss, 0.9289159798498976)
(141, :trnloss, 0.9012111593541248, :tstloss, 0.9301330598422997)
(142, :trnloss, 0.9134126862440952, :tstloss, 0.935597862077712)
(143, :trnloss, 0.8996235582331373, :tstloss, 0.9290523631069063)
(144, :trnloss, 0.9038866595673378, :tstloss, 0.9251398039632203)
(145, :trnloss, 0.9003171545203043, :tstloss, 0.931658378210581)
(146, :trnloss, 0.9032751342554101, :tstloss, 0.9244120340982697)
(147, :trnloss, 0.8971567690721913, :tstloss, 0.9261833097324992)
(148, :trnloss, 0.9015976911373491, :tstloss, 0.9225310063754256)
(149, :trnloss, 0.8973764401149527, :tstloss, 0.9257286836840534)
(150, :trnloss, 0.9054331218371825, :tstloss, 0.9261857069898006)
(151, :trnloss, 0.9018602541420041, :tstloss, 0.9303253522030042)
(152, :trnloss, 0.9089732910976409, :tstloss, 0.9303598203615561)
(153, :trnloss, 0.9012333733052686, :tstloss, 0.930404151056236)
(154, :trnloss, 0.911819959384576, :tstloss, 0.9336297643039709)
(155, :trnloss, 0.899247533042487, :tstloss, 0.9281309624203807)
(156, :trnloss, 0.9078720418498397, :tstloss, 0.9295639996542631)
(157, :trnloss, 0.9004203565806016, :tstloss, 0.9311670110859933)
(158, :trnloss, 0.9072911023740629, :tstloss, 0.9290518862481906)
(159, :trnloss, 0.9012385936030552, :tstloss, 0.9313083500225072)
(160, :trnloss, 0.9050797726814437, :tstloss, 0.9264097104857109)
(161, :trnloss, 0.8985633409003642, :tstloss, 0.9286130536294441)
(162, :trnloss, 0.9059948750202562, :tstloss, 0.9275411136689296)
(163, :trnloss, 0.8985937350200747, :tstloss, 0.9275783936424106)
(164, :trnloss, 0.903758371288957, :tstloss, 0.9256649505319023)
(165, :trnloss, 0.897786444248739, :tstloss, 0.9272250528183089)
(166, :trnloss, 0.9047730506778924, :tstloss, 0.9261011900561359)
(167, :trnloss, 0.8989177682756345, :tstloss, 0.9274009355339541)
(168, :trnloss, 0.9063267174821252, :tstloss, 0.9283874713120226)
(169, :trnloss, 0.9008310153072508, :tstloss, 0.9296107419997065)
(170, :trnloss, 0.9146646449830441, :tstloss, 0.9371319562392147)
(171, :trnloss, 0.8988934769432714, :tstloss, 0.9279116613063692)
(172, :trnloss, 0.9065410592645651, :tstloss, 0.9283567440594243)
(173, :trnloss, 0.9003247197710024, :tstloss, 0.9308285994434322)
(174, :trnloss, 0.9057301782034786, :tstloss, 0.9283320128002707)
(175, :trnloss, 0.8992600798996245, :tstloss, 0.9301896562891362)
(176, :trnloss, 0.9062342940582955, :tstloss, 0.9283212870012439)
(177, :trnloss, 0.896582972372566, :tstloss, 0.923737065044985)
(178, :trnloss, 0.9038332024811423, :tstloss, 0.9259271354450864)
(179, :trnloss, 0.8971453605539746, :tstloss, 0.9255301022179374)
(180, :trnloss, 0.9013365266011075, :tstloss, 0.9222655101765652)
(181, :trnloss, 0.8943205032017177, :tstloss, 0.9220310082549168)
(182, :trnloss, 0.895929323610895, :tstloss, 0.9165534543143576)
(183, :trnloss, 0.8897967818787755, :tstloss, 0.9156548474732377)
(184, :trnloss, 0.8916208338332373, :tstloss, 0.912130162432191)
(185, :trnloss, 0.8896569338220245, :tstloss, 0.9167793641889851)
(186, :trnloss, 0.8928154056678991, :tstloss, 0.9137378501821641)
(187, :trnloss, 0.8900888245230221, :tstloss, 0.9186306120520025)
(188, :trnloss, 0.8955835453611015, :tstloss, 0.9165019514916312)
(189, :trnloss, 0.8895123340622101, :tstloss, 0.9169913185789481)
(190, :trnloss, 0.8958614232276446, :tstloss, 0.9173436974117701)
(191, :trnloss, 0.8917330456871475, :tstloss, 0.9174489420461481)
(192, :trnloss, 0.8984664343911289, :tstloss, 0.919791638208556)
(193, :trnloss, 0.8926956190819777, :tstloss, 0.9174213851961696)
(194, :trnloss, 0.9026186453353725, :tstloss, 0.9236692604319782)
(195, :trnloss, 0.8955805006951891, :tstloss, 0.9203112807365503)
(196, :trnloss, 0.9091714342089782, :tstloss, 0.9304654632648824)
(197, :trnloss, 0.8987137345087355, :tstloss, 0.9253306987382912)
(198, :trnloss, 0.9080666014052294, :tstloss, 0.9298170084790482)
(199, :trnloss, 0.9006431624368567, :tstloss, 0.9307052120403404)
(200, :trnloss, 0.9097864794425826, :tstloss, 0.9327873959559893)
([-0.12104, -0.203275, 0.0797697, 0.0950643, -0.0227906, -0.122038, 0.13459, 0.0908772, 0.0722425, -0.350894, 0.573885, 0.708688, -0.301313, -0.0274322, -0.320048, 0.0215012, 0.0789374, -0.0365025, 0.304972, 0.780068, -0.728792, -0.198594, -0.635068, 0.191104], [-5.87674])
In [347]:
# analyze performance of test set
# first look at by A and B numbers
test_set=index[701:end]
#retrieve formula from Oads table, categorize into # of As and Bs
AB=Int64[]
A2B=Int64[]
AB2=Int64[]
A2B2=Int64[]
for i in test_set
    for row in Oads
        if (row.Count)==i
            set_A,set_B=getelements(row.Name)
            size_A=length(set_A)
            size_B=length(set_B)
            if size_A==1 
                if size_B==1
                    push!(AB,i)
                elseif size_B==2
                    push!(AB2,i)
                end
            elseif size_A==2
                if size_B==1
                    push!(A2B,i)
                elseif size_B==2
                    push!(A2B2,i)
                end
            end
        end
    end
end
In [348]:
#Evaluate performance of each category by # of A sites and # of B sites
loss_AB=[]
for i in AB
    loss=abs((w[1]'*x_norm[:,i].+w[2])[1]-y[i])
    push!(loss_AB,loss)
end
MAE_AB=mean(loss_AB)

loss_A2B=[]
for i in A2B
    loss=abs((w[1]'*x_norm[:,i].+w[2])[1]-y[i])
    push!(loss_A2B,loss)
end
MAE_A2B=mean(loss_A2B)

loss_A2B2=[]
for i in A2B2
    loss=abs((w[1]'*x_norm[:,i].+w[2])[1]-y[i])
    push!(loss_A2B2,loss)
end
MAE_A2B2=mean(loss_A2B2)

loss_AB2=[]
for i in AB2
    loss=abs((w[1]'*x_norm[:,i].+w[2])[1]-y[i])
    push!(loss_AB2,loss)
end
MAE_AB2=mean(loss_AB2)

using Plots
x_dat=["AB", "A2B", "AB2", "A2B2"]
y_dat=[MAE_AB, MAE_A2B, MAE_AB2,MAE_A2B2]
plot(x_dat,y_dat,seriestype=:bar,xlabel = "Perovskite Type",ylabel = "MAE",legend=:none)
Out[348]:
AB AB AB A2B A2B A2B A2B A2B AB2 AB2 AB2 AB2 A2B2 A2B2 A2B2 0.0 0.5 1.0 1.5 Perovskite Type MAE
In [351]:
#Evaluate performance of each category by elements

set_el=String[]
for row in PD
    push!(set_el,row.Element)    
end
el_matt=zeros(length(set_el),1000)
el_dict=Dict{String,Int64}()
for i in 1:28
    el_dict[set_el[i]]=i
end
In [352]:
# el_matt: row number is the dict value of each element, if entry=1, then entry 
# has that element
# for example, row.Count=1, (22,1)=1 because it has Eu
for i in test_set
    for row in Oads
        if (row.Count)==i
            set_A,set_B=getelements(row.Name)
            bigset=vcat(set_A,set_B)
            for el in set_el
                if el in bigset
                    el_matt[el_dict[el],i]=1
                end
            end
        end
    end
end

#collect all non-zero entries
el_vec=[]
count_1=[]
dic_rev=Dict(value => key for (key, value) in el_dict)
# get average MAE vector by elements
for i in 1:28
    count=[]
    for j in 1:1000
        if el_matt[i,j]==1.0
            loss=abs((w[1]'*x_norm[:,j].+w[2])[1]-y[j])
            push!(count,loss)
        end
    end
    push!(el_vec,mean(count))
end
el_vec
Out[352]:
28-element Array{Any,1}:
 0.8254359950908214 
 0.5594722617440011 
 0.576613170208411  
 1.051258664451097  
 0.42115692480098776
 0.6553206058051882 
 0.7676780026758735 
 0.7855559945947036 
 1.6838770686007407 
 1.3971481474820888 
 0.6493086480146583 
 0.4563552533949525 
 0.7091403996409344 
 ⋮                  
 4.217212138646781  
 0.8106414651405175 
 1.888900905238512  
 0.6969945667286961 
 0.674820703488499  
 2.3567393121902525 
 3.8846422079000016 
 0.8048168819095773 
 2.4084748767324684 
 1.4870664461443481 
 0.7387288585587733 
 0.9575636234901876 
In [353]:
using Plots
x_dat_el=String[]
for i in 1:28
   
    element=dic_rev[i]
    push!(x_dat_el,element)
end
y_dat_el=el_vec
plot(x_dat_el,y_dat_el,seriestype=:bar,xlabel = "Element",ylabel = "MAE",legend=:none, xticks = :all,xtickfont = font(8, "Courier"))
Out[353]:
Ti V Cr Mn Fe Co Ni Cu Zr Nb Mo Tc Ru Rh Pd Ag Y Sc Ba La Ce Eu Gd Sr Nd Pr Sm Pm 0 1 2 3 4 Element MAE
In [ ]: