from mmtfPyspark.ml import SparkMultiClassClassifier, datasetBalancer
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, MultilayerPerceptronClassifier, RandomForestClassifier
# Start (or reuse) the Spark session used by the rest of this notebook.
spark = SparkSession.builder.appName("datasetClassifier").getOrCreate()

# Read the Parquet dataset and cache it: it is counted and previewed right
# below, so caching avoids re-reading the files for each action.
parquetFile = './input_features/'
data = spark.read.parquet(parquetFile).cache()
print(f"Total number of data: {data.count()}")

# Preview only: apply limit() on the Spark side so just five rows are sent to
# the driver, instead of converting the entire dataset to pandas for head().
data.limit(5).toPandas()
Total number of data: 18491
structureChainId | alpha | beta | coil | foldType | features | |
---|---|---|---|---|---|---|
0 | 1LBU.A | 0.361502 | 0.107981 | 0.530516 | other | [-0.03669819220865391, 0.13017714411934028, 0.... |
1 | 1LC0.A | 0.410345 | 0.275862 | 0.313793 | alpha+beta | [0.017792403316538488, 0.06889735366958401, 0.... |
2 | 1LC5.A | 0.428169 | 0.157746 | 0.414084 | alpha+beta | [0.12736012024892182, -0.0036459625095703716, ... |
3 | 1LFP.A | 0.427984 | 0.234568 | 0.337449 | alpha+beta | [0.07269115472257498, -0.010540929652990833, 0... |
4 | 1LFW.A | 0.322650 | 0.273504 | 0.403846 | alpha+beta | [-0.027897640212830196, 0.0941510383131058, 0.... |
# Keep only the 'alpha' and 'beta' fold types; append 'other' to the list to
# include that class as well.
data = data.where(data.foldType.isin(['alpha', 'beta']))
print(f"Total number of data: {data.count()}")

# Preview only: fetch five rows rather than converting the whole filtered
# dataset to pandas just to call head().
data.limit(5).toPandas()
Total number of data: 4937
structureChainId | alpha | beta | coil | foldType | features | |
---|---|---|---|---|---|---|
0 | 1LGH.A | 0.857143 | 0.0 | 0.142857 | alpha | [0.23627377279441464, 0.05140024884180589, 0.4... |
1 | 1LGH.B | 0.744186 | 0.0 | 0.255814 | alpha | [0.07006392560221933, -0.05091538017785007, 0.... |
2 | 1LGH.D | 0.857143 | 0.0 | 0.142857 | alpha | [0.23627377279441464, 0.05140024884180589, 0.4... |
3 | 1LGH.E | 0.744186 | 0.0 | 0.255814 | alpha | [0.07006392560221933, -0.05091538017785007, 0.... |
4 | 1LGH.G | 0.857143 | 0.0 | 0.142857 | alpha | [0.23627377279441464, 0.05140024884180589, 0.4... |
# Settings passed to every SparkMultiClassClassifier built further down.
label = 'foldType'
testFraction = 0.1
seed = 123

# Length of the feature vector, measured on the first row of the dataset.
vector = data.first()["features"]
featureCount = len(vector)
print(f"Feature count : {featureCount}")

# Number of distinct values in the label column.
classCount = data.select(label).distinct().count()
print(f"Class count : {classCount}")

# Class distribution before balancing.
print(f"Dataset size (unbalanced) : {data.count()}")
data.groupBy(label).count().show()

# Downsample via mmtfPyspark's datasetBalancer (third argument is presumably
# the sampling seed -- confirm against the library), then report again.
data = datasetBalancer.downsample(data, label, 1)
print(f"Dataset size (balanced) : {data.count()}")
data.groupBy(label).count().show()
Feature count : 50 Class count : 2 Dataset size (unbalanced) : 4937 +--------+-----+ |foldType|count| +--------+-----+ | beta| 1253| | alpha| 3684| +--------+-----+ Dataset size (balanced) : 2487 +--------+-----+ |foldType|count| +--------+-----+ | beta| 1253| | alpha| 1234| +--------+-----+
# Random forest with default hyperparameters, fitted and evaluated through
# the mmtfPyspark multi-class wrapper.
rfc = RandomForestClassifier()
mcc = SparkMultiClassClassifier(rfc, label, testFraction, seed)
metrics = mcc.fit(data)

# Print each reported metric as a tab-separated "name<TAB>value" line.
for name, value in metrics.items():
    print(f"{name}\t{value}")
Class Train Test beta 1129 124 alpha 1096 138 Sample predictions: RandomForestClassifier +----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+ |structureChainId| alpha| beta| coil|foldType| features|indexedLabel| rawPrediction| probability|prediction|predictedLabel| +----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+ | 3C5X.C|0.037037037|0.56790125|0.39506173| beta|[0.09558909519236...| 0.0|[17.9023113878481...|[0.89511556939240...| 0.0| beta| | 4D7C.A|0.044715445| 0.5406504|0.41463414| beta|[-0.0319722487827...| 0.0|[16.1381744010601...|[0.80690872005300...| 0.0| beta| | 5LTG.B| 0.0| 0.5786164|0.42138365| beta|[0.05144563539998...| 0.0|[10.7898740235799...|[0.53949370117899...| 0.0| beta| | 2B4H.B|0.018348623|0.63761467| 0.3440367| beta|[-0.0295000851021...| 0.0|[18.2689413108828...|[0.91344706554414...| 0.0| beta| | 2QF4.A| 0.01764706| 0.5117647|0.47058824| beta|[-0.0272243778869...| 0.0|[15.5572626952883...|[0.77786313476441...| 0.0| beta| +----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+ only showing top 5 rows Total time taken: 8.132889032363892 Method RandomForestClassifier AUC 0.8634291725105189 F 0.862691591032037 Accuracy 0.8625954198473282 Precision 0.8636924144557732 Recall 0.8625954198473282 False Positive Rate 0.13573707482629038 True Positive Rate 0.8625954198473282 Confusion Matrix ['beta', 'alpha'] DenseMatrix([[109., 15.], [ 21., 117.]])
# Logistic regression with default hyperparameters, fitted and evaluated
# through the mmtfPyspark multi-class wrapper.
lr = LogisticRegression()
mcc = SparkMultiClassClassifier(lr, label, testFraction, seed)
metrics = mcc.fit(data)

# Print each reported metric as a tab-separated "name<TAB>value" line.
for name, value in metrics.items():
    print(f"{name}\t{value}")
Class Train Test beta 1129 124 alpha 1096 138 Sample predictions: LogisticRegression +----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+ |structureChainId| alpha| beta| coil|foldType| features|indexedLabel| rawPrediction| probability|prediction|predictedLabel| +----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+ | 3C5X.C|0.037037037|0.56790125|0.39506173| beta|[0.09558909519236...| 0.0|[2.91709717105090...|[0.94868516858396...| 0.0| beta| | 4D7C.A|0.044715445| 0.5406504|0.41463414| beta|[-0.0319722487827...| 0.0|[1.74714903999137...|[0.85159285165984...| 0.0| beta| | 5LTG.B| 0.0| 0.5786164|0.42138365| beta|[0.05144563539998...| 0.0|[0.51536824728804...|[0.62606407134874...| 0.0| beta| | 2B4H.B|0.018348623|0.63761467| 0.3440367| beta|[-0.0295000851021...| 0.0|[4.29881287897821...|[0.98659739396346...| 0.0| beta| | 2QF4.A| 0.01764706| 0.5117647|0.47058824| beta|[-0.0272243778869...| 0.0|[3.10840376209434...|[0.95723806408012...| 0.0| beta| +----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+ only showing top 5 rows Total time taken: 9.848273038864136 Method LogisticRegression AUC 0.8859864422627396 F 0.8855630038274853 Accuracy 0.8854961832061069 Precision 0.8860427110630056 Recall 0.8854961832061068 False Positive Rate 0.11352329868062769 True Positive Rate 0.8854961832061068 Confusion Matrix ['beta', 'alpha'] DenseMatrix([[111., 13.], [ 17., 121.]])
# Network shape: one input unit per feature, two hidden layers of 32 units,
# and one output unit per class.
layers = [featureCount, 32, 32, classCount]
mpc = MultilayerPerceptronClassifier(
    layers=layers,
    blockSize=128,
    seed=1234,
    maxIter=100,
)
mcc = SparkMultiClassClassifier(mpc, label, testFraction, seed)
metrics = mcc.fit(data)

# Print each reported metric as a tab-separated "name<TAB>value" line.
for name, value in metrics.items():
    print(f"{name}\t{value}")
Class Train Test beta 1129 124 alpha 1096 138 Sample predictions: MultilayerPerceptronClassifier +----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+ |structureChainId| alpha| beta| coil|foldType| features|indexedLabel| rawPrediction| probability|prediction|predictedLabel| +----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+ | 3C5X.C|0.037037037|0.56790125|0.39506173| beta|[0.09558909519236...| 0.0|[2.10432098158542...|[0.93830746536760...| 0.0| beta| | 4D7C.A|0.044715445| 0.5406504|0.41463414| beta|[-0.0319722487827...| 0.0|[1.65768386789357...|[0.85989879208471...| 0.0| beta| | 5LTG.B| 0.0| 0.5786164|0.42138365| beta|[0.05144563539998...| 0.0|[0.79207503061248...|[0.52331890951242...| 0.0| beta| | 2B4H.B|0.018348623|0.63761467| 0.3440367| beta|[-0.0295000851021...| 0.0|[3.15463251641009...|[0.99212519827765...| 0.0| beta| | 2QF4.A| 0.01764706| 0.5117647|0.47058824| beta|[-0.0272243778869...| 0.0|[3.96416203784699...|[0.99871776545954...| 0.0| beta| +----------------+-----------+----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+ only showing top 5 rows Total time taken: 11.974912881851196 Method MultilayerPerceptronClassifier AUC 0.8944600280504909 F 0.8932045402757206 Accuracy 0.8931297709923665 Precision 0.894944749906582 Recall 0.8931297709923665 False Positive Rate 0.10420971489138464 True Positive Rate 0.8931297709923665 Confusion Matrix ['beta', 'alpha'] DenseMatrix([[114., 10.], [ 18., 120.]])
# Shut down the Spark session created at the top of the notebook.
spark.stop()