Demo to create a feature vector for protein fold classification. In this demo we try to classify a protein chain as either an all-alpha or an all-beta protein based on its protein sequence. We use overlapping n-grams and a Word2Vec representation of the protein sequence as the feature vector.
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from mmtfPyspark.io import mmtfReader
from mmtfPyspark.webfilters import Pisces
from mmtfPyspark.filters import ContainsLProteinChain
from mmtfPyspark.mappers import StructureToPolymerChains
from mmtfPyspark.datasets import secondaryStructureExtractor
from mmtfPyspark.ml import ProteinSequenceEncoder
# Create (or reuse) the Spark session used by this demo.
session_builder = SparkSession.builder.appName("1-Features")
spark = session_builder.getOrCreate()

# Read the sample structures, split each one with StructureToPolymerChains,
# and keep only the entries accepted by the ContainsLProteinChain filter.
structures = mmtfReader.read_sequence_file('../resources/mmtf_reduced_sample/')
chains = structures.flatMap(StructureToPolymerChains())
pdb = chains.filter(ContainsLProteinChain())

# Build the secondary structure dataset for the chains and preview five rows.
data = secondaryStructureExtractor.get_dataset(pdb)
data.show(5)
+----------------+--------------------+----------+----------+----------+--------------------+--------------------+ |structureChainId| sequence| alpha| beta| coil| dsspQ8Code| dsspQ3Code| +----------------+--------------------+----------+----------+----------+--------------------+--------------------+ | 4WMY.A|TDWSHPQFEKSTDEANT...|0.19081272|0.26855123|0.54063606|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...| | 4WMY.B|TDWSHPQFEKSTDEANT...|0.17081851|0.26334518| 0.5658363|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...| | 4WN5.A|GSHMGRGAFLSRHSLDM...| 0.2962963|0.37962964|0.32407406|XXCCCCCCEEEEECTTC...|XXCCCCCCEEEEECCCC...| | 4WN5.B|GSHMGRGAFLSRHSLDM...|0.33333334|0.37142858| 0.2952381|XXXXXCCCEEEEECTTC...|XXXXXCCCEEEEECCCC...| | 4WND.A|GPGSMEASCLELALEGE...| 0.8358663| 0.0|0.16413374|XXXXCCSCHHHHHHHHH...|XXXXCCCCHHHHHHHHH...| +----------------+--------------------+----------+----------+----------+--------------------+--------------------+ only showing top 5 rows
def add_protein_fold_type(data, minThreshold, maxThreshold):
    '''
    Appends a "foldType" column that labels every row with one of four
    secondary structure classes, decided from the "alpha" and "beta" columns:

        "alpha"      : alpha > maxThreshold and beta < minThreshold
        "beta"       : beta > maxThreshold and alpha < minThreshold
        "alpha+beta" : alpha > maxThreshold and beta > maxThreshold
        "other"      : none of the conditions above holds

    The conditions are evaluated in the order listed and the first match wins.

    This function relies on two names being in scope:
        from pyspark.sql.functions import when
        from pyspark.sql.functions import col

    Attributes:
        data (Dataset<Row>): input dataset with "alpha" and "beta" columns
        minThreshold (float): a structure type counts as absent when its
            fraction is below this value
        maxThreshold (float): a structure type counts as present when its
            fraction is above this value

    Returns:
        the input dataset with the additional "foldType" column
    '''
    alpha_fraction = col("alpha")
    beta_fraction = col("beta")

    # Presence/absence tests shared by the class rules below.
    alpha_present = alpha_fraction > maxThreshold
    beta_present = beta_fraction > maxThreshold
    alpha_absent = alpha_fraction < minThreshold
    beta_absent = beta_fraction < minThreshold

    fold_type = (
        when(alpha_present & beta_absent, "alpha")
        .when(beta_present & alpha_absent, "beta")
        .when(alpha_present & beta_present, "alpha+beta")
        .otherwise("other")
    )
    return data.withColumn("foldType", fold_type)
# Label every chain with its fold type using the demo's threshold settings,
# then print the default preview of the result.
fold_thresholds = {"minThreshold": 0.05, "maxThreshold": 0.15}
data = add_protein_fold_type(data, **fold_thresholds)
data.show()
+----------------+--------------------+-----------+-----------+----------+--------------------+--------------------+----------+ |structureChainId| sequence| alpha| beta| coil| dsspQ8Code| dsspQ3Code| foldType| +----------------+--------------------+-----------+-----------+----------+--------------------+--------------------+----------+ | 4WMY.A|TDWSHPQFEKSTDEANT...| 0.19081272| 0.26855123|0.54063606|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...|alpha+beta| | 4WMY.B|TDWSHPQFEKSTDEANT...| 0.17081851| 0.26334518| 0.5658363|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...|alpha+beta| | 4WN5.A|GSHMGRGAFLSRHSLDM...| 0.2962963| 0.37962964|0.32407406|XXCCCCCCEEEEECTTC...|XXCCCCCCEEEEECCCC...|alpha+beta| | 4WN5.B|GSHMGRGAFLSRHSLDM...| 0.33333334| 0.37142858| 0.2952381|XXXXXCCCEEEEECTTC...|XXXXXCCCEEEEECCCC...|alpha+beta| | 4WND.A|GPGSMEASCLELALEGE...| 0.8358663| 0.0|0.16413374|XXXXCCSCHHHHHHHHH...|XXXXCCCCHHHHHHHHH...| alpha| | 4WND.B|GPLGSDLPPKVVPSKQL...|0.115384616| 0.0|0.88461536|XXXXXXXXXXXXXXXCC...|XXXXXXXXXXXXXXXCC...| other| | 4WP6.A|GSHHHHHHSQDPMQAAQ...| 0.45695364|0.119205296|0.42384106|XXXXXXXXXXXXXXXXX...|XXXXXXXXXXXXXXXXX...| other| | 4WP9.A|FQGAMGSRVVILFTDIE...| 0.3939394| 0.3151515|0.29090908|XXCCSSEEEEEEEEEET...|XXCCCCEEEEEEEEEEC...|alpha+beta| | 4WP9.B|FQGAMGSRVVILFTDIE...| 0.4| 0.3125| 0.2875|XXXCCSEEEEEEEEEET...|XXXCCCEEEEEEEEEEC...|alpha+beta| | 4WPG.A|GPLLEMILITGSNGQLG...| 0.39372823| 0.17073171|0.43554008|XCCSCCEEEESTTSHHH...|XCCCCCEEEECCCCHHH...|alpha+beta| | 4WPK.A|MHHHHHHGMASMTARPL...| 0.4122807|0.114035085|0.47368422|XXXXXXXXXXCTTTSCH...|XXXXXXXXXXCCCCCCH...| other| | 4WQD.A|MEPPTVALTVPAAALLP...| 0.3991228|0.057017542|0.54385966|XXXXCBCCCCCCGGGCC...|XXXXCECCCCCCHHHCC...| other| | 4WRI.A|GILANLKEPSAHWCRKM...| 0.62032086|0.053475935| 0.3262032|XXXXXCCCCCHHHHHHH...|XXXXXCCCCCHHHHHHH...| other| | 4WSF.A|TTDTRRRVKLYALNAER...| 0.16216215| 0.4774775|0.36036035|XXCCTTEEEEEEECTTS...|XXCCCCEEEEEEECCCC...|alpha+beta| | 4WSF.B| PDESSADVVFKKPLAPAPR| 0.0| 0.0| 1.0| 
XXXXXXXCCSCCCSSCCCX| XXXXXXXCCCCCCCCCCCX| other| | 1GWM.A|MNVRATYTVIFKNASGL...|0.039215688| 0.503268|0.45751634|CCCSCCEEEEESSCSSS...|CCCCCCEEEEECCCCCC...| beta| | 1GXM.A|GLVPRGSHMTGRMLTLD...| 0.42901236| 0.13580246| 0.4351852|XXXXXXXXCBTTBCCCT...|XXXXXXXXCECCECCCC...| other| | 1GXM.B|GLVPRGSHMTGRMLTLD...| 0.4186747| 0.12951808|0.45180723|CCCCTTTTCBTTBCCCT...|CCCCCCCCCECCECCCC...| other| | 1GXR.A|DYFQGAMGSKPAYSFHV...| 0.0| 0.5432836|0.45671642|CCEEEEEEEEECCEEEE...|CCEEEEEEEEECCEEEE...| beta| | 1GXR.B|DYFQGAMGSKPAYSFHV...| 0.0| 0.5555556|0.44444445|CCEEEEEEEEECCEEET...|CCEEEEEEEEECCEEEC...| beta| +----------------+--------------------+-----------+-----------+----------+--------------------+--------------------+----------+ only showing top 20 rows
# Encoding settings. They are passed to the encoder by name below so that
# editing a value here actually changes the encoding (previously the call
# repeated the literals and these variables were unused).
n = 2  # create 2-grams
windowSize = 25  # 25-amino residue window size for Word2Vector
vectorSize = 50  # dimension of feature vector

# Encode each sequence as overlapping n-grams plus a Word2Vec feature vector;
# the result is cached because it is reused by the steps that follow.
encoder = ProteinSequenceEncoder(data)
data = encoder.overlapping_ngram_word2vec_encode(
    n=n, windowSize=windowSize, vectorSize=vectorSize
).cache()

# Preview five rows. Limiting before the pandas conversion means only those
# rows are collected to the driver instead of the whole dataset.
data.limit(5).toPandas()
structureChainId | sequence | alpha | beta | coil | dsspQ8Code | dsspQ3Code | foldType | ngram | features | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 4WMY.A | TDWSHPQFEKSTDEANTYFKEWTCSSSPSLPRSCKEIKDECPSAFD... | 0.190813 | 0.268551 | 0.540636 | XXXXXXXXXXXXXXXXXXXXXXXCCCCCCCCSSHHHHHHHCTTCCS... | XXXXXXXXXXXXXXXXXXXXXXXCCCCCCCCCCHHHHHHHCCCCCC... | alpha+beta | [TD, DW, WS, SH, HP, PQ, QF, FE, EK, KS, ST, T... | [0.028354697964596942, 0.06656068684991266, 0.... |
1 | 4WMY.B | TDWSHPQFEKSTDEANTYFKEWTCSSSPSLPRSCKEIKDECPSAFD... | 0.170819 | 0.263345 | 0.565836 | XXXXXXXXXXXXXXXXXXXXXCCCXXXXCCCSSHHHHHHHCTTCCS... | XXXXXXXXXXXXXXXXXXXXXCCCXXXXCCCCCHHHHHHHCCCCCC... | alpha+beta | [TD, DW, WS, SH, HP, PQ, QF, FE, EK, KS, ST, T... | [0.028354697964596942, 0.06656068684991266, 0.... |
2 | 4WN5.A | GSHMGRGAFLSRHSLDMKFTYCDDRIAEVAGYSPDDLIGCSAYEYI... | 0.296296 | 0.379630 | 0.324074 | XXCCCCCCEEEEECTTCBEEEECGGHHHHHSCCHHHHBTSBGGGGB... | XXCCCCCCEEEEECCCCEEEEECHHHHHHHCCCHHHHECCEHHHHE... | alpha+beta | [GS, SH, HM, MG, GR, RG, GA, AF, FL, LS, SR, R... | [-0.04048257577641491, 0.1233881547426184, 0.3... |
3 | 4WN5.B | GSHMGRGAFLSRHSLDMKFTYCDDRIAEVAGYSPDDLIGCSAYEYI... | 0.333333 | 0.371429 | 0.295238 | XXXXXCCCEEEEECTTCBEEEECGGHHHHHSCCHHHHBTSBGGGGB... | XXXXXCCCEEEEECCCCEEEEECHHHHHHHCCCHHHHECCEHHHHE... | alpha+beta | [GS, SH, HM, MG, GR, RG, GA, AF, FL, LS, SR, R... | [-0.04048257577641491, 0.1233881547426184, 0.3... |
4 | 4WND.A | GPGSMEASCLELALEGERLCKSGDCRAGVSFFEAAVQVGTEDLKTL... | 0.835866 | 0.000000 | 0.164134 | XXXXCCSCHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHCCSCHHHH... | XXXXCCCCHHHHHHHHHHHHHCCCHHHHHHHHHHHHHHCCCCHHHH... | alpha | [GP, PG, GS, SM, ME, EA, AS, SC, CL, LE, EL, L... | [-0.009619595496742813, 0.03677304709491171, 0... |
# Keep only the identifier, composition, label, and feature columns.
feature_columns = [
    'structureChainId',
    'alpha',
    'beta',
    'coil',
    'foldType',
    'features',
]
data = data.select(feature_columns)

# Save the selected columns as parquet, replacing any earlier output.
parquet_writer = data.write.mode('overwrite').format('parquet')
parquet_writer.save('./input_features')

# Shut down the Spark session now that the features are written.
spark.stop()