# Plot the data: (a) n and k of GaAs and Ge; (b) k of AlGaAs for each Al fraction.
# NOTE(review): wl, GaAs_n, GaAs_k, Ge_n, Ge_k, Al_fraction and AlGaAs_k are
# defined earlier in the file (not visible here); AlGaAs_k[i] is assumed to be
# the k curve for Al_fraction[i], sampled on wl — confirm.

# One colour per AlGaAs curve from the jet colormap. One extra sample is drawn
# so that the first (darkest blue) colour can be skipped via colors[i + 1].
colors = plt.cm.jet(np.linspace(0, 1, len(Al_fraction) + 1))
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(11.25, 4))

# Panel (a): n on the left y-axis, k on a twin right y-axis.
ax1b = ax1.twinx()
lns1 = ax1.plot(wl, GaAs_n, color='blue', label='n, GaAs')
lns2 = ax1b.plot(wl, GaAs_k, color='red', label='k, GaAs')
lns3 = ax1.plot(wl, Ge_n, ls='--', color='blue', label='n, Ge')
lns4 = ax1b.plot(wl, Ge_k, ls='--', color='red', label='k, Ge')
# twinx() shares the x-axis with ax1, so this limit applies to ax1b as well.
ax1.set_xlim([300, 1800])
ax1b.set_ylim([0, 3.8])
# Each axis only tracks its own lines, so collect the handles from both axes
# and build one legend that lists all four curves.
lns = lns1 + lns2 + lns3 + lns4
labs = [line.get_label() for line in lns]
ax1.legend(lns, labs, loc="upper right", frameon=False)
ax1.text(0.05, 0.9, '(a)', transform=ax1.transAxes, fontsize=12)
ax1.set_xlabel("Wavelength (nm)")
ax1.set_ylabel("Refractive Index, n")
ax1b.set_ylabel("Extinction Coefficient, k")

# Panel (b): one k curve per Al fraction.
# NOTE(review): the label appends '%', so Al_fraction presumably already holds
# percentages; int() truncates any fractional part — confirm.
for i, frac in enumerate(Al_fraction):
    ax2.plot(wl, AlGaAs_k[i], color=colors[i + 1], label=f'{int(frac)}%')
ax2.set_xlim([300, 900])
ax2.set_ylim([0, 2.8])
ax2.set_xlabel("Wavelength (nm)")
ax2.set_ylabel("Extinction Coefficient, k")
ax2.legend(loc="upper right", frameon=False)
ax2.text(0.05, 0.9, '(b)', transform=ax2.transAxes, fontsize=12)
plt.tight_layout(w_pad=4)
plt.show()