This file contains the following contents.
First, it defines a `MyFloat64` type; arithmetic operations (`+`, `-`, `*`, `/`) between two `MyFloat64` values are counted. To use it, wrap an expression in the `@analyzeMyFloat` macro.
Second, we define a similar `@analyze` macro that counts arithmetic operations between values of any type.
Finally, we define the `estimateComplexity` function. Given a single-argument function `f`, `estimateComplexity(f)` outputs a number $e$ such that $n^e$ is an estimate of the complexity of `f`.
"""
    MyFloat64(v)

Thin wrapper around a `Float64` (stored in the field `v`). Its `+`, `-`, `*` and `/`
methods increment global counters, which lets the number of arithmetic operations
performed on these values be tallied.
"""
struct MyFloat64
    v::Float64
end
# NOTE(review): this constant is not read by any code visible in this file; the `a`s
# used further down are method parameters or function locals that shadow it.
# Confirm nothing else depends on it before removing.
const a = 1
1
import Base: +, *, -, /
# Running tally of `+` calls in which at least one operand is a `MyFloat64`.
# Every method below bumps the tally, adds the unwrapped `Float64` value(s) and
# wraps the sum back into a `MyFloat64`.
AdditionCounter = 0

function Base.:+(x::MyFloat64, y::MyFloat64)
    global AdditionCounter += 1
    return MyFloat64(x.v + y.v)
end

function Base.:+(x, y::MyFloat64)
    global AdditionCounter += 1
    return MyFloat64(x + y.v)
end

function Base.:+(x::MyFloat64, y)
    global AdditionCounter += 1
    return MyFloat64(x.v + y)
end
+ (generic function with 166 methods)
# Running tally of `-` calls in which at least one operand is a `MyFloat64`.
# Every method below bumps the tally, subtracts the unwrapped `Float64` value(s)
# and wraps the difference back into a `MyFloat64`.
SubtractionCounter = 0

function Base.:-(x::MyFloat64, y::MyFloat64)
    global SubtractionCounter += 1
    return MyFloat64(x.v - y.v)
end

function Base.:-(x, y::MyFloat64)
    global SubtractionCounter += 1
    return MyFloat64(x - y.v)
end

function Base.:-(x::MyFloat64, y)
    global SubtractionCounter += 1
    return MyFloat64(x.v - y)
end
- (generic function with 178 methods)
# Running tally of `*` calls between two `MyFloat64` values.
# NOTE(review): unlike `+`, `-` and `/`, the mixed-operand methods below do NOT bump
# the tally, so products with a plain number are invisible to it — confirm this
# asymmetry is intentional.
MultiplicationCounter = 0

function Base.:*(x::MyFloat64, y::MyFloat64)
    global MultiplicationCounter += 1
    return MyFloat64(x.v * y.v)
end

# Mixed operands: compute and wrap the product without counting it.
function Base.:*(x, y::MyFloat64)
    return MyFloat64(x * y.v)
end

function Base.:*(x::MyFloat64, y)
    return MyFloat64(x.v * y)
end
* (generic function with 346 methods)
# Running tally of `/` calls in which at least one operand is a `MyFloat64`.
# Every method below bumps the tally, divides the unwrapped `Float64` value(s)
# and wraps the quotient back into a `MyFloat64`.
DivisionCounter = 0

function Base.:/(x::MyFloat64, y::MyFloat64)
    global DivisionCounter += 1
    return MyFloat64(x.v / y.v)
end

function Base.:/(x, y::MyFloat64)
    global DivisionCounter += 1
    return MyFloat64(x / y.v)
end

function Base.:/(x::MyFloat64, y)
    global DivisionCounter += 1
    return MyFloat64(x.v / y)
end
/ (generic function with 107 methods)
# Reset the four MyFloat64 operation counters to zero.
function _reset_myfloat_counters!()
    global AdditionCounter = 0
    global SubtractionCounter = 0
    global MultiplicationCounter = 0
    global DivisionCounter = 0
    return nothing
end

# Print the current value of the four MyFloat64 operation counters.
function _print_myfloat_counters()
    println("Number of additions: ", AdditionCounter)
    println("Number of subtractions: ", SubtractionCounter)
    println("Number of multiplications: ", MultiplicationCounter)
    println("Number of divisions: ", DivisionCounter)
    return nothing
end

"""
    @analyzeMyFloat expr

Reset the `MyFloat64` operation counters, evaluate `expr`, print the counters
(which the `MyFloat64` operator methods increment) and return the value of `expr`.
"""
macro analyzeMyFloat(ex)
    # Generate code that runs at the call site: the counters are reset and `expr`
    # is evaluated every time the expanded code executes, in the caller's scope,
    # so local variables are visible and repeated calls each start from zero.
    # (Calling `eval` here would run `expr` once, at expansion time, in this
    # module's global scope.)
    return quote
        _reset_myfloat_counters!()
        local result = $(esc(ex))
        _print_myfloat_counters()
        result
    end
end
@analyzeMyFloat (macro with 1 method)
# Some simple tests.
#
# @analyzeMyFloat MyFloat64(1.0) + MyFloat64(2.0)
# @analyzeMyFloat MyFloat64(1.0) - MyFloat64(2.0)
# @analyzeMyFloat MyFloat64(1.0) * MyFloat64(2.0)
# @analyzeMyFloat MyFloat64(1.0) / MyFloat64(2.0)
# A = [MyFloat64(1) MyFloat64(2) MyFloat64(3); MyFloat64(1) MyFloat64(2) MyFloat64(3); MyFloat64(1) MyFloat64(2) MyFloat64(3)]
# B = [MyFloat64(1) MyFloat64(2) MyFloat64(3); MyFloat64(1) MyFloat64(2) MyFloat64(3); MyFloat64(1) MyFloat64(2) MyFloat64(3)]
# @analyzeMyFloat A*B
# @analyzeMyFloat A+B
# @analyzeMyFloat A-B
# B = [1 2 3; 4 5 6; 7 8 9]
# @analyzeMyFloat A+B
# @analyzeMyFloat A-B
# @analyzeMyFloat A*B
# @analyzeMyFloat B+A
# @analyzeMyFloat B*A
# @analyzeMyFloat B-A
# We use the Cassette library to track function calls.
using Cassette
# Define the Cassette context type `TraceCtx`. `@analyze` and `estimateComplexity`
# run code under this context via `Cassette.@overdub`; the
# `Cassette.execute(ctx::TraceCtx, args...)` method in this file is specialized on it
# to tally arithmetic calls.
Cassette.@context TraceCtx
Cassette.Context{nametype(TraceCtx),M,P,T,B} where B<:Union{Nothing, IdDict{Module,Dict{Symbol,BindingMeta}}} where P<:Cassette.AbstractPass where T<:Union{Nothing, Tag} where M
# Per-callee call tallies shared by `Cassette.execute`, `@analyze` and
# `estimateComplexity`. Keys are the callees being counted (registered as the
# arithmetic operators); values are integer counts, so the value type is fixed.
const callerCounter = Dict{Any,Int}()
Dict{Any,Any} with 0 entries
# Override of Cassette's `execute` hook for the `TraceCtx` context.
#
# `args[1]` is the callee and `args[2:end]` are its arguments. If the callee is a key
# of `callerCounter` (where `+`, `-`, `*`, `/` are registered before tracing), its
# tally is incremented regardless of the argument types. The call is then recursed
# into with a child context when Cassette can overdub it, or handed to
# `Cassette.fallback` otherwise.
function Cassette.execute(ctx::TraceCtx, args...)
    callee = first(args)
    # Only callees that were registered in `callerCounter` are counted.
    if haskey(callerCounter, callee)
        callerCounter[callee] += 1
    end
    if Cassette.canoverdub(ctx, args...)
        # NOTE(review): the child context's metadata vector is never read in the
        # visible code; it is kept so the context setup is unchanged.
        childctx = Cassette.similarcontext(ctx, metadata = Any[])
        return Cassette.overdub(childctx, args...)
    else
        return Cassette.fallback(ctx, args...)
    end
end
"""
    @analyze expr

Evaluate `expr` under the `TraceCtx` Cassette context, print how many times `+`, `-`,
`*` and `/` were called while doing so (for any argument types), and return the
value of `expr`.
"""
macro analyze(ex)
    return quote
        # Reset at run time, not at macro-expansion time, so every execution of the
        # expanded code (e.g. inside a function called repeatedly) starts from zero.
        callerCounter[+] = 0
        callerCounter[-] = 0
        callerCounter[*] = 0
        callerCounter[/] = 0
        local trace = Any[]
        local result = Cassette.@overdub(TraceCtx(metadata = trace), $(esc(ex)))
        println("Number of additions: ", get(callerCounter, +, 0))
        println("Number of subtractions: ", get(callerCounter, -, 0))
        println("Number of multiplications: ", get(callerCounter, *, 0))
        println("Number of divisions: ", get(callerCounter, /, 0))
        # Return the analyzed value, consistent with `@analyzeMyFloat`.
        result
    end
end
@analyze (macro with 1 method)
Now we define the estimateComplexity function, starting with the helper functions it uses.
"""
    loss(c, e, cnts)

Sum of squared residuals of the model `cnts[i] ≈ c * i^e` over `i = 1:size(cnts, 1)`:

    Σᵢ (cnts[i] - c * i^e)^2

This measures how well the complexity estimate `c * n^e` fits the observed counts.
Returns `0` when `cnts` is empty.
"""
function loss(c, e, cnts)
    residual_sq(i) = abs2(cnts[i] - c * i^e)
    # Left fold starting from the integer 0, accumulating in index order.
    return mapfoldl(residual_sq, +, 1:size(cnts, 1); init = 0)
end
loss (generic function with 1 method)
"""
    minimizer(e, cnts)

For a fixed exponent `e`, compute in closed form the scale

    c = Σᵢ i^e * cnts[i] / Σᵢ i^(2e)

that minimizes `loss(c, e, cnts)`, and return that minimal loss (not `c` itself).
"""
function minimizer(e, cnts)
    idx = 1:size(cnts, 1)
    numerator = mapfoldl(i -> i^e * cnts[i], +, idx; init = 0)
    denominator = mapfoldl(i -> i^(2 * e), +, idx; init = 0)
    return loss(numerator / denominator, e, cnts)
end
# Run `f(n)` under the `TraceCtx` context and return the total number of `+`, `-`,
# `*` and `/` calls tallied in `callerCounter` while doing so.
function _count_operations(f, n)
    callerCounter[+] = 0
    callerCounter[-] = 0
    callerCounter[*] = 0
    callerCounter[/] = 0
    trace = Any[]
    Cassette.@overdub(TraceCtx(metadata = trace), f(n))
    return callerCounter[+] + callerCounter[-] + callerCounter[*] + callerCounter[/]
end

"""
    estimateComplexity(f, terms=20)

Estimate an exponent `e` such that the number of arithmetic calls made by `f(n)`
grows like `c * n^e`.

`f` is called for `n = 1, ..., terms` and its `+`, `-`, `*`, `/` calls are counted.
Candidate exponents `0.1, 0.11, ..., 10.1` are then scanned, and the one with the
smallest `minimizer(e, counts)` (least-squares residual with the best `c`) is chosen.
Ties go to the smaller exponent.

The estimate is printed and also returned.
"""
function estimateComplexity(f, terms=20)
    counts = zeros(terms)
    for n in 1:terms
        counts[n] = _count_operations(f, n)
    end
    # We enumerate all values of e up to some granularity (step 0.01).
    beste = 0.1
    bestval = minimizer(beste, counts)
    for i in 0:1000
        e = i / 100.0 + 0.1
        val = minimizer(e, counts)
        if val < bestval
            bestval = val
            beste = e
        end
    end
    println(beste)
    return beste
end
estimateComplexity (generic function with 2 methods)
Packages: Cassette