This file contains the following contents.

First, it defines a MyFloat64. The four arithmetic operations based on two MyFloat64 will be counted. The way to use it is @analyzeMyFloat.

Second, we define a similar @analyze macro, that counts arithmetic operations between any types.

Finally, we define the estimateComplexity function. Given a single argument function f to estimateComplexity, estimateComplexity(f) outputs a number e such that $n^c$ is an estimated complexity of f.

In [1]:
# A custom type that can be used to count the number of 
# each operations.
#
struct MyFloat64
    v::Float64
end
In [2]:
const a = 1
Out[2]:
1
In [3]:
import Base: +, *, -, /
In [4]:
# We redefine addition for this type, and increment
# the counter each time we do addition.
#
AdditionCounter = 0
function +(a::MyFloat64, b::MyFloat64)
    global AdditionCounter += 1
    MyFloat64(a.v + b.v)
end
function +(a::Any, b::MyFloat64)
    global AdditionCounter += 1
    MyFloat64(a + b.v)
end
function +(a::MyFloat64, b::Any)
    global AdditionCounter += 1
    MyFloat64(a.v + b)
end
Out[4]:
+ (generic function with 166 methods)
In [5]:
SubtractionCounter = 0
function -(a::MyFloat64, b::MyFloat64)
    global SubtractionCounter += 1
    MyFloat64(a.v - b.v)
end
function -(a::Any, b::MyFloat64)
    global SubtractionCounter += 1
    MyFloat64(a - b.v)
end
function -(a::MyFloat64, b::Any)
    global SubtractionCounter += 1
    MyFloat64(a.v - b)
end
Out[5]:
- (generic function with 178 methods)
In [6]:
MultiplicationCounter = 0
function *(a::MyFloat64, b::MyFloat64)
    global MultiplicationCounter += 1
    MyFloat64(a.v * b.v)
end
function *(a::Any, b::MyFloat64)
#     global MultiplicationCounter += 1
    MyFloat64(a * b.v)
end
function *(a::MyFloat64, b::Any)
#     global MultiplicationCounter += 1
    MyFloat64(a.v * b)
end
Out[6]:
* (generic function with 346 methods)
In [7]:
DivisionCounter = 0
function /(a::MyFloat64, b::MyFloat64)
    global DivisionCounter += 1
    MyFloat64(a.v / b.v)
end
function /(a::Any, b::MyFloat64)
    global DivisionCounter += 1
    MyFloat64(a / b.v)
end
function /(a::MyFloat64, b::Any)
    global DivisionCounter += 1
    MyFloat64(a.v / b)
end
Out[7]:
/ (generic function with 107 methods)
In [8]:
# Define the @analyzeMyFloat macro. 
#
macro analyzeMyFloat(ex)
    global AdditionCounter = 0
    global SubtractionCounter = 0
    global MultiplicationCounter = 0
    global DivisionCounter = 0
    result = eval(ex)
    println("Number of additions: ", AdditionCounter)
    println("Number of subtractions: ", SubtractionCounter)
    println("Number of multiplications: ", MultiplicationCounter)
    println("Number of divisions: ", DivisionCounter)
    return result
end
Out[8]:
@analyzeMyFloat (macro with 1 method)
In [9]:
# Some simple tests.
#
# @analyzeMyFloat MyFloat64(1.0) + MyFloat64(2.0)
In [10]:
# @analyzeMyFloat MyFloat64(1.0) - MyFloat64(2.0)
In [11]:
# @analyzeMyFloat MyFloat64(1.0) * MyFloat64(2.0)
In [12]:
# @analyzeMyFloat MyFloat64(1.0) / MyFloat64(2.0)
In [13]:
# A = [MyFloat64(1) MyFloat64(2) MyFloat64(3); MyFloat64(1) MyFloat64(2) MyFloat64(3); MyFloat64(1) MyFloat64(2) MyFloat64(3)]
In [14]:
# B = [MyFloat64(1) MyFloat64(2) MyFloat64(3); MyFloat64(1) MyFloat64(2) MyFloat64(3); MyFloat64(1) MyFloat64(2) MyFloat64(3)]
In [15]:
# @analyzeMyFloat A*B
In [16]:
# @analyzeMyFloat A+B
In [17]:
# @analyzeMyFloat A-B
In [18]:
# B = [1 2 3; 4 5 6; 7 8 9]
In [19]:
# @analyzeMyFloat A+B
In [20]:
# @analyzeMyFloat A-B
In [21]:
# @analyzeMyFloat A*B
In [22]:
# @analyzeMyFloat B+A
In [23]:
# @analyzeMyFloat B*A
In [24]:
# @analyzeMyFloat B-A
In [25]:
# We use Cassette library to track the calls to functions.
using Cassette
In [26]:
Cassette.@context TraceCtx
Out[26]:
Cassette.Context{nametype(TraceCtx),M,P,T,B} where B<:Union{Nothing, IdDict{Module,Dict{Symbol,BindingMeta}}} where P<:Cassette.AbstractPass where T<:Union{Nothing, Tag} where M
In [27]:
const callerCounter = Dict()
Out[27]:
Dict{Any,Any} with 0 entries
In [28]:
# overwrite the Cassette execute function.
function Cassette.execute(ctx::TraceCtx, args...)
    subtrace = Any[]
    tmp = (args => subtrace)[1][1]
    if (get(callerCounter, tmp, -1) != -1)
        callerCounter[tmp] += 1
    end
    if Cassette.canoverdub(ctx, args...)
        newctx = Cassette.similarcontext(ctx, metadata = subtrace)
        return Cassette.overdub(newctx, args...)
    else
        return Cassette.fallback(ctx, args...)
    end
end
In [29]:
# Define the @analyze macro. 
#
macro analyze(ex)
    callerCounter[+] = 0
    callerCounter[-] = 0
    callerCounter[*] = 0
    callerCounter[/] = 0
    result = :(trace = Any[];
    Cassette.@overdub(TraceCtx(metadata = trace), $(esc(ex)));
    println("Number of additions: ", get(callerCounter, +, 0));
    println("Number of subtractions: ", get(callerCounter, -, 0));
    println("Number of multiplications: ", get(callerCounter, *, 0));
    println("Number of divisions: ", get(callerCounter, /, 0)))
    return result
end
Out[29]:
@analyze (macro with 1 method)

Now we start to define the estimateComplexity function.

In [31]:
# This is the loss function we use to estimate the quality of an estimated complexity. It is defined by 
# \sum_{i} (cnts[i] - c(i^e))^2
function loss(c, e, cnts)
    tot = 0
    for i=1:size(cnts)[1]
        tot += abs2(cnts[i] - c * (i ^ e))
    end
    return tot
end
Out[31]:
loss (generic function with 1 method)
In [32]:
# Fix the exponent e, we can calculate the value of c that minimizes loss(c, e, cnts).
# 
function minimizer(e, cnts)
    a = 0
    b = 0
    for i=1:size(cnts)[1]
        a += (i^e * cnts[i])
        b += i^(2*e)
    end
    c = a / b
    return loss(c, e, cnts)
end

function estimateComplexity(f, terms=20)
    cnts = zeros(terms)
    for n=1:terms
        trace = Any[]
        callerCounter[+] = 0
        callerCounter[-] = 0
        callerCounter[*] = 0
        callerCounter[/] = 0
        Cassette.@overdub(TraceCtx(metadata = trace), f(n))
        cnts[n] = callerCounter[+] + callerCounter[-] + callerCounter[*] + callerCounter[/] 
    end
    # We enumerate all values of e up to some granularity.
    beste = 0.1
    bestval = minimizer(0.1, cnts)
    for i = 0:1000
        tmp = minimizer(i / 100.0 + 0.1, cnts)
        if (tmp < bestval)
            bestval = tmp
            beste = i / 100.0 + 0.1
        end
    end
    println(beste)
end
Out[32]:
estimateComplexity (generic function with 2 methods)

Reference

Packages: Cassette

In [ ]: