import pandas as pd
import numpy as np
import dask.dataframe as dd
import matplotlib.pyplot as plt
import seaborn as sns

pd.set_option("display.max_rows", 10)

# Read the two string-coded columns as pandas categoricals.
dtype = dict(
    vendor_name='category',
    Payment_Type='category',
)

df = pd.read_csv(
    "data/yellow_tripdata_2009-01.csv",
    dtype=dtype,
    parse_dates=['Trip_Pickup_DateTime', 'Trip_Dropoff_DateTime'],
)

df.head()

from sklearn.model_selection import train_test_split

# Target: whether the recorded tip amount is positive.
# Features: every other column.
y = df['Tip_Amt'].gt(0)
X = df.drop(columns="Tip_Amt")

X_train, X_test, y_train, y_test = train_test_split(X, y)

len(X_train)

len(X_test)

df['Payment_Type'].cat.categories

df['Payment_Type'].str.lower()

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

class ColumnSelector(TransformerMixin, BaseEstimator):
    """Select `columns` from `X`.

    Stateless transformer: ``fit`` learns nothing and ``transform`` returns
    ``X[columns]``.

    Parameters
    ----------
    columns : list-like
        Column labels to keep, in the order they should appear.
    """
    # BaseEstimator (with the mixin on the left, per scikit-learn convention)
    # provides get_params/set_params, so this step can be cloned and tuned.
    # BaseEstimator requires __init__ to store each argument unchanged under
    # the same name, which is why `columns` is assigned as-is.

    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        """No-op; return ``self`` so the step can be chained."""
        return self

    def transform(self, X, y=None):
        """Return ``X`` restricted to ``self.columns``."""
        return X[self.columns]

class HourExtractor(TransformerMixin, BaseEstimator):
    """Transform each datetime64 column in `columns` to integer hours.

    Parameters
    ----------
    columns : list-like
        Labels of datetime64 columns to replace with their ``.dt.hour`` value.
    """
    # BaseEstimator (mixin on the left) provides get_params/set_params/clone.

    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        """No-op; return ``self`` so the step can be chained."""
        return self

    def transform(self, X, y=None):
        """Return a copy of ``X`` with each listed column replaced by its hour."""
        # Compute each column eagerly instead of passing lambdas to `assign`.
        # Lambdas created inside a comprehension all close over the same
        # `col` variable, so every one of them would read the *last* column.
        hours = {col: X[col].dt.hour for col in self.columns}
        return X.assign(**hours)

def payment_lowerer(X):
    """Return a copy of `X` whose ``Payment_Type`` column is lowercased.

    The input frame is left untouched; every other column passes through.
    """
    lowered = X["Payment_Type"].str.lower()
    return X.assign(Payment_Type=lowered)

class CategoricalEncoder(TransformerMixin, BaseEstimator):
    """Convert to Categorical with specific `categories`.

    Parameters
    ----------
    categories : dict
        Mapping of ``{column: list of allowed categories}``.
    """
    # BaseEstimator (mixin on the left) provides get_params/set_params/clone.

    def __init__(self, categories):
        self.categories = categories

    def fit(self, X, y=None):
        """No-op; return ``self`` so the step can be chained."""
        return self

    def transform(self, X, y=None):
        """Return a copy of ``X`` with each mapped column made categorical.

        Values that are not in a column's category list become NaN.
        """
        # Work on a copy so the caller's DataFrame is never modified in place
        # (and no SettingWithCopy issues arise when X is a slice).
        X = X.copy()
        for col, cats in self.categories.items():
            # Pinning the category set makes a later pd.get_dummies emit the
            # same dummy columns regardless of which values appear in X.
            X[col] = X[col].astype('category').cat.set_categories(cats)
        return X

class StandardScaler(TransformerMixin, BaseEstimator):
    """Scale a subset of the columns in a DataFrame.

    ``fit`` records each column's mean and standard deviation (pandas
    ``std``, ddof=1). ``transform`` returns a copy of ``X`` in which those
    columns are replaced by ``(x - mean) / std``.

    Parameters
    ----------
    columns : list-like
        Labels of the numeric columns to standardize; others pass through.
    """
    # BaseEstimator (mixin on the left) provides get_params/set_params/clone.

    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        """Learn per-column means (``μs``) and standard deviations (``σs``)."""
        self.μs = X[self.columns].mean()
        σs = X[self.columns].std()
        # A constant column has std 0 and a single-row column has std NaN.
        # Dividing by either would turn the column into NaN/inf, so use 1
        # instead, as sklearn.preprocessing.StandardScaler does.
        self.σs = σs.where(σs > 0, 1.0)
        return self

    def transform(self, X, y=None):
        """Return a copy of ``X`` with ``self.columns`` standardized."""
        X = X.copy()
        X[self.columns] = X[self.columns].sub(self.μs).div(self.σs)
        return X

# LogisticRegression is not imported anywhere else in this file; without this
# line the make_pipeline call below raises NameError.
from sklearn.linear_model import LogisticRegression

# The columns at the start of the pipeline
columns = ['vendor_name', 'Trip_Pickup_DateTime', 'Passenger_Count',
           'Trip_Distance', 'Payment_Type', 'Fare_Amt', 'surcharge']

# The mapping of {column: set of categories}.
# The Payment_Type entries are lowercase, so payment_lowerer must run before
# CategoricalEncoder in the pipeline.
categories = {
    'vendor_name': ['CMT', 'DDS', 'VTS'],
    'Payment_Type': ['cash', 'credit', 'dispute', 'no charge'],
}

# Numeric columns standardized by the custom StandardScaler step
scale = ['Trip_Distance', 'Fare_Amt', 'surcharge']

pipe = make_pipeline(
    ColumnSelector(columns),
    HourExtractor(['Trip_Pickup_DateTime']),
    FunctionTransformer(payment_lowerer, validate=False),
    CategoricalEncoder(categories),
    # Dummy-encode only after the categories are pinned, so train and test
    # produce identical dummy columns.
    FunctionTransformer(pd.get_dummies, validate=False),
    StandardScaler(scale),
    LogisticRegression(),
)
pipe

pipe.steps

pipe.fit(X_train, y_train)

# Accuracy on the training split, then on the held-out split
pipe.score(X_train, y_train)

pipe.score(X_test, y_test)

def mkpipe():
    """Build a new, unfitted copy of the tip-prediction pipeline.

    The steps are the same as in the ``pipe`` constructed earlier in this
    file. The module-level ``columns``, ``categories`` and ``scale`` are read
    at call time, so they must already be defined.

    Returns
    -------
    sklearn.pipeline.Pipeline
        An unfitted pipeline ending in a ``LogisticRegression``.
    """
    # Imported locally so this function does not depend on a module-level
    # import of LogisticRegression.
    from sklearn.linear_model import LogisticRegression

    return make_pipeline(
        ColumnSelector(columns),
        HourExtractor(['Trip_Pickup_DateTime']),
        FunctionTransformer(payment_lowerer, validate=False),
        CategoricalEncoder(categories),
        FunctionTransformer(pd.get_dummies, validate=False),
        StandardScaler(scale),
        LogisticRegression(),
    )

## Scaling it Out
# The same preparation as above, but lazily over every CSV in data/ with dask.

import dask.dataframe as dd

df = dd.read_csv(
    "data/*.csv",
    dtype=dtype,
    parse_dates=['Trip_Pickup_DateTime', 'Trip_Dropoff_DateTime'],
)

y = df['Tip_Amt'].gt(0)
X = df.drop(columns="Tip_Amt")

# Since scikit-learn isn't "dask-aware" at the moment, we use dask's
# `map_partitions` method to run the fitted pandas pipeline on each partition.
# It is a good escape hatch for applying non-daskified code to a dask DataFrame.

def _predict_partition(part):
    """Return the fitted ``pipe``'s probability of a positive tip for one partition.

    The result reuses the partition's index so it stays aligned with ``X``.
    Empty partitions short-circuit, because scikit-learn estimators reject
    input with 0 samples.
    """
    if len(part) == 0:
        return pd.Series([], index=part.index, name='yhat', dtype='f8')
    # Column 1 of predict_proba is the probability of the True class
    # (Tip_Amt > 0).
    proba = pipe.predict_proba(part)[:, 1]
    return pd.Series(proba, index=part.index, name='yhat')


yhat = X.map_partitions(_predict_partition, meta=('yhat', 'f8'))

yhat.to_frame().to_parquet("data/predictions.parq")