import scipy,pylab,numpy from numpy import * from scipy import * from pylab import * from urllib import urlopen from gzip import GzipFile from scipy.spatial import distance from scipy.spatial.distance import cdist data = random_sample((1000,2)) labels = (data[:,0]*0.7+data[:,1]*0.4>0.5) sum(labels) d0 = data[labels==False] d1 = data[labels] figure(figsize=(6,6)); xlim((0,1)); ylim((0,1)) plot(d0[:,0],d0[:,1],"bo") plot(d1[:,0],d1[:,1],"ro") print len(d0),len(d1) augmented = concatenate([ones((len(data),1)),data],axis=1) augmented[:3,:] labels = 2.0*labels-1.0 a = dot(linalg.pinv(augmented),labels) d,a0,a1 = a print d,a0,a1 d0 = data[labels<=0] d1 = data[labels>0] figure(figsize=(6,6)); xlim((0,1)); ylim((0,1)) plot(d0[:,0],d0[:,1],"bo") plot(d1[:,0],d1[:,1],"ro") plot([0,-d/a0],[-d/a1,0],"g") savefig("tmp.png") print len(d0),len(d1) xs = linspace(0.0,1.0,100)[:,newaxis] ys = linspace(0.0,1.0,100)[newaxis,:] xs = xs + 0*ys ys = ys + 0*xs image = xs*a0+ys*a1+d imshow(image.T[::-1]); savefig("temp.png")