-- GHCi session: multi-line declarations are wrapped in :{ ... :} so the
-- script can be pasted into GHCi or loaded with :script.
:set -XTypeFamilies
:{
-- | Things that can be minimised by gradient descent.
class GradientDescent a where
  -- | Type to represent the parameter space.
  data Params a

  -- | Compute the gradient at a location in parameter space.
  grad :: a -> Params a -> Params a

  -- | Move in parameter space.
  paramMove :: Double    -- ^ Scaling factor.
            -> Params a  -- ^ Direction vector.
            -> Params a  -- ^ Original location.
            -> Params a  -- ^ New location.
:}

-- We need flexible instances for declarations like these.
:set -XFlexibleInstances
:{
instance Floating a => GradientDescent (a -> a) where
  -- The parameter for a function is just its argument.
  data Params (a -> a) = Arg { unArg :: a }

  -- Use numeric (backward-difference) differentiation for taking the gradient.
  grad f (Arg value) = Arg $ (f value - f (value - epsilon)) / epsilon
    where
      epsilon = 0.0001

  -- New location = old location + scale * direction.
  paramMove scale (Arg vec) (Arg old) = Arg $ old + realToFrac scale * vec
:}

:{
-- | Define a way to decide when to stop.
-- This lets the user specify an error tolerance easily.
-- The function takes the previous and the current set of parameters and
-- returns 'True' to stop the descent and 'False' to continue it.
newtype StopCondition a = StopWhen (Params a -> Params a -> Bool)

-- | Minimise a function by repeatedly stepping against its gradient.
--
-- Every step moves from @x@ to @x - alpha * grad x@. The descent ends at the
-- first step whose (previous, new) parameters satisfy the stop condition and
-- returns the new parameters of that step. If the condition is never met,
-- the descent does not terminate.
gradientDescent :: GradientDescent a
                => a                -- ^ What to optimize.
                -> StopCondition a  -- ^ When to stop.
                -> Double           -- ^ Step size (alpha).
                -> Params a         -- ^ Initial point (x0).
                -> Params a         -- ^ Return: Location of minimum.
gradientDescent function (StopWhen stop) alpha x0 = go x0
  where
    -- Take one step; finish as soon as the (previous, next) pair meets the
    -- stop condition, otherwise keep descending from the new point.
    go prev
      | stop prev next = next
      | otherwise      = go next
      where
        next = takeStep prev

    -- Compute the gradient and move against it with a step size alpha.
    takeStep params = paramMove (-alpha) (grad function params) params
:}

-- Create a stop condition that respects a given error tolerance.
:{
stopCondition
  :: (Double -> Double) -> Double -> StopCondition (Double -> Double)
stopCondition f tolerance = StopWhen stop
  where
    -- Stop once the objective changes by less than the tolerance.
    stop (Arg prev) (Arg cur) = abs (f prev - f cur) < tolerance
:}

-- A demo function with minimum at -3/2
:{
function :: Num a => a -> a
function x = x ^ 2 + 3 * x
:}

let alpha = 1e-1
let tolerance = 1e-4
let initValue = 12.0
unArg $ gradientDescent function (stopCondition function tolerance) alpha (Arg initValue)

-- Redefine the class so that every step runs inside a monad m, which lets
-- computing the gradient or moving fail.
:set -XMultiParamTypeClasses
:{
class Monad m => GradientDescent m a where
  -- | Type to represent the parameter space.
  data Params a

  -- | Compute the gradient at a location in parameter space.
  grad :: a -> Params a -> m (Params a)

  -- | Move in parameter space.
  paramMove :: Double        -- ^ Scaling factor.
            -> Params a      -- ^ Direction vector.
            -> Params a      -- ^ Original location.
            -> m (Params a)  -- ^ New location.
:}

:{
-- Since we've redefined GradientDescent, we need to redefine StopCondition.
-- The function takes the previous and the current set of parameters and
-- returns 'True' to stop the descent and 'False' to continue it.
newtype StopCondition a = StopWhen (Params a -> Params a -> Bool)

-- | Minimise a function by gradient descent, sequencing every step in @m@.
--
-- Each step computes the gradient and moves against it by @alpha@; the first
-- step whose (previous, new) parameters satisfy the stop condition yields the
-- new parameters. A failing step in @m@ (e.g. 'Nothing') ends the descent.
gradientDescent :: (GradientDescent m a)
                => a                -- ^ What to optimize.
                -> StopCondition a  -- ^ When to stop.
                -> Double           -- ^ Step size (alpha).
                -> Params a         -- ^ Initial point (x0).
                -> m (Params a)     -- ^ Return: Location of minimum.
gradientDescent function (StopWhen stop) alpha x0 = go x0
  where
    -- Take the next step; if we stop, do so, otherwise keep descending.
    go prev = do
      next <- takeStep prev
      if stop prev next
        then pure next
        else go next

    -- Compute the gradient and move against it with a step size alpha.
    takeStep params = do
      gradients <- grad function params
      paramMove (-alpha) gradients params
:}

:{
instance (Ord a, Floating a) => GradientDescent Maybe (a -> a) where
  -- The parameter for a function is just its argument.
  data Params (a -> a) = Arg { unArg :: a }

  -- Use numeric (backward-difference) differentiation for taking the
  -- gradient. It is only computed for strictly positive arguments; anywhere
  -- else the step yields 'Nothing', which aborts the whole descent.
  grad f (Arg value)
    | value > 0 = Just $ Arg $ (f value - f (value - epsilon)) / epsilon
    | otherwise = Nothing
    where
      epsilon = 0.0001

  -- New location = old location + scale * direction; moving never fails.
  paramMove scale (Arg vec) (Arg old) =
    Just $ Arg $ old + realToFrac scale * vec
:}

:{
-- Stop once the objective changes by less than the tolerance.
stopCondition :: (Num b, Ord b) => (a -> b) -> b -> StopCondition (a -> a)
stopCondition f tolerance = StopWhen stop
  where
    stop (Arg prev) (Arg cur) = abs (f prev - f cur) < tolerance
:}

let x0 = Arg initValue
let stopper = stopCondition function tolerance
-- Starting from 12, the iterates reach x <= 0 before the stop condition holds,
-- so the gradient fails and this prints "Nothing!".
:{
case gradientDescent function stopper alpha x0 of
  Just x  -> print $ unArg x
  Nothing -> putStrLn "Nothing!"
:}