This file contains the following contents.

First, it defines a MyFloat64. The four arithmetic operations based on two MyFloat64 will be counted. The way to use it is @analyzeMyFloat.

Second, we define a similar @analyze macro, that counts arithmetic operations between any types.

Finally, we define the estimateComplexity function. Given a single argument function f to estimateComplexity, estimateComplexity(f) outputs a number e such that $n^c$ is an estimated complexity of f.

In [1]:
# A custom type that can be used to count the number of
# each operations.
#
struct MyFloat64
v::Float64
end

In [2]:
const a = 1

Out[2]:
1
In [3]:
import Base: +, *, -, /

In [4]:
# We redefine addition for this type, and increment
# the counter each time we do addition.
#
function +(a::MyFloat64, b::MyFloat64)
global AdditionCounter += 1
MyFloat64(a.v + b.v)
end
function +(a::Any, b::MyFloat64)
global AdditionCounter += 1
MyFloat64(a + b.v)
end
function +(a::MyFloat64, b::Any)
global AdditionCounter += 1
MyFloat64(a.v + b)
end

Out[4]:
+ (generic function with 166 methods)
In [5]:
SubtractionCounter = 0
function -(a::MyFloat64, b::MyFloat64)
global SubtractionCounter += 1
MyFloat64(a.v - b.v)
end
function -(a::Any, b::MyFloat64)
global SubtractionCounter += 1
MyFloat64(a - b.v)
end
function -(a::MyFloat64, b::Any)
global SubtractionCounter += 1
MyFloat64(a.v - b)
end

Out[5]:
- (generic function with 178 methods)
In [6]:
MultiplicationCounter = 0
function *(a::MyFloat64, b::MyFloat64)
global MultiplicationCounter += 1
MyFloat64(a.v * b.v)
end
function *(a::Any, b::MyFloat64)
#     global MultiplicationCounter += 1
MyFloat64(a * b.v)
end
function *(a::MyFloat64, b::Any)
#     global MultiplicationCounter += 1
MyFloat64(a.v * b)
end

Out[6]:
* (generic function with 346 methods)
In [7]:
DivisionCounter = 0
function /(a::MyFloat64, b::MyFloat64)
global DivisionCounter += 1
MyFloat64(a.v / b.v)
end
function /(a::Any, b::MyFloat64)
global DivisionCounter += 1
MyFloat64(a / b.v)
end
function /(a::MyFloat64, b::Any)
global DivisionCounter += 1
MyFloat64(a.v / b)
end

Out[7]:
/ (generic function with 107 methods)
In [8]:
# Define the @analyzeMyFloat macro.
#
macro analyzeMyFloat(ex)
global AdditionCounter = 0
global SubtractionCounter = 0
global MultiplicationCounter = 0
global DivisionCounter = 0
result = eval(ex)
println("Number of subtractions: ", SubtractionCounter)
println("Number of multiplications: ", MultiplicationCounter)
println("Number of divisions: ", DivisionCounter)
return result
end

Out[8]:
@analyzeMyFloat (macro with 1 method)
In [9]:
# Some simple tests.
#
# @analyzeMyFloat MyFloat64(1.0) + MyFloat64(2.0)

In [10]:
# @analyzeMyFloat MyFloat64(1.0) - MyFloat64(2.0)

In [11]:
# @analyzeMyFloat MyFloat64(1.0) * MyFloat64(2.0)

In [12]:
# @analyzeMyFloat MyFloat64(1.0) / MyFloat64(2.0)

In [13]:
# A = [MyFloat64(1) MyFloat64(2) MyFloat64(3); MyFloat64(1) MyFloat64(2) MyFloat64(3); MyFloat64(1) MyFloat64(2) MyFloat64(3)]

In [14]:
# B = [MyFloat64(1) MyFloat64(2) MyFloat64(3); MyFloat64(1) MyFloat64(2) MyFloat64(3); MyFloat64(1) MyFloat64(2) MyFloat64(3)]

In [15]:
# @analyzeMyFloat A*B

In [16]:
# @analyzeMyFloat A+B

In [17]:
# @analyzeMyFloat A-B

In [18]:
# B = [1 2 3; 4 5 6; 7 8 9]

In [19]:
# @analyzeMyFloat A+B

In [20]:
# @analyzeMyFloat A-B

In [21]:
# @analyzeMyFloat A*B

In [22]:
# @analyzeMyFloat B+A

In [23]:
# @analyzeMyFloat B*A

In [24]:
# @analyzeMyFloat B-A

In [25]:
# We use Cassette library to track the calls to functions.
using Cassette

In [26]:
Cassette.@context TraceCtx

Out[26]:
Cassette.Context{nametype(TraceCtx),M,P,T,B} where B<:Union{Nothing, IdDict{Module,Dict{Symbol,BindingMeta}}} where P<:Cassette.AbstractPass where T<:Union{Nothing, Tag} where M
In [27]:
const callerCounter = Dict()

Out[27]:
Dict{Any,Any} with 0 entries
In [28]:
# overwrite the Cassette execute function.
function Cassette.execute(ctx::TraceCtx, args...)
subtrace = Any[]
tmp = (args => subtrace)[1][1]
if (get(callerCounter, tmp, -1) != -1)
callerCounter[tmp] += 1
end
if Cassette.canoverdub(ctx, args...)
newctx = Cassette.similarcontext(ctx, metadata = subtrace)
return Cassette.overdub(newctx, args...)
else
return Cassette.fallback(ctx, args...)
end
end

In [29]:
# Define the @analyze macro.
#
macro analyze(ex)
callerCounter[+] = 0
callerCounter[-] = 0
callerCounter[*] = 0
callerCounter[/] = 0
result = :(trace = Any[];
Cassette.@overdub(TraceCtx(metadata = trace), \$(esc(ex)));
println("Number of additions: ", get(callerCounter, +, 0));
println("Number of subtractions: ", get(callerCounter, -, 0));
println("Number of multiplications: ", get(callerCounter, *, 0));
println("Number of divisions: ", get(callerCounter, /, 0)))
return result
end

Out[29]:
@analyze (macro with 1 method)

Now we start to define the estimateComplexity function.

In [31]:
# This is the loss function we use to estimate the quality of an estimated complexity. It is defined by
# \sum_{i} (cnts[i] - c(i^e))^2
function loss(c, e, cnts)
tot = 0
for i=1:size(cnts)[1]
tot += abs2(cnts[i] - c * (i ^ e))
end
end

Out[31]:
loss (generic function with 1 method)
In [32]:
# Fix the exponent e, we can calculate the value of c that minimizes loss(c, e, cnts).
#
function minimizer(e, cnts)
a = 0
b = 0
for i=1:size(cnts)[1]
a += (i^e * cnts[i])
b += i^(2*e)
end
c = a / b
return loss(c, e, cnts)
end

function estimateComplexity(f, terms=20)
cnts = zeros(terms)
for n=1:terms
trace = Any[]
callerCounter[+] = 0
callerCounter[-] = 0
callerCounter[*] = 0
callerCounter[/] = 0
Cassette.@overdub(TraceCtx(metadata = trace), f(n))
cnts[n] = callerCounter[+] + callerCounter[-] + callerCounter[*] + callerCounter[/]
end
# We enumerate all values of e up to some granularity.
beste = 0.1
bestval = minimizer(0.1, cnts)
for i = 0:1000
tmp = minimizer(i / 100.0 + 0.1, cnts)
if (tmp < bestval)
bestval = tmp
beste = i / 100.0 + 0.1
end
end
println(beste)
end

Out[32]:
estimateComplexity (generic function with 2 methods)

# Reference¶

Packages: Cassette

In [ ]: