# A custom type that can be used to count the number of
# arithmetic operations of each kind.
#
# MyFloat64 wraps a single Float64 in the field `v`. The methods of +, -, *
# and / defined below compute on the wrapped value, wrap the result in a new
# MyFloat64, and (where enabled) increment a per-operation global counter.
struct MyFloat64
    v::Float64
end

# NOTE(review): `a` is not referenced anywhere in the visible code (every
# method below shadows it with its own parameter `a`). Kept because other
# code may rely on it.
const a = 1

import Base: +, *, -, /

# We redefine addition for this type, and increment
# the counter each time we do addition.
#
# Each operator has three methods: both operands wrapped, only the right
# operand wrapped, and only the left operand wrapped. The result is always
# a MyFloat64.
AdditionCounter = 0

function +(a::MyFloat64, b::MyFloat64)
    global AdditionCounter += 1
    return MyFloat64(a.v + b.v)
end

function +(a::Any, b::MyFloat64)
    global AdditionCounter += 1
    return MyFloat64(a + b.v)
end

function +(a::MyFloat64, b::Any)
    global AdditionCounter += 1
    return MyFloat64(a.v + b)
end

# Subtraction: same scheme as addition, counted in SubtractionCounter.
SubtractionCounter = 0

function -(a::MyFloat64, b::MyFloat64)
    global SubtractionCounter += 1
    return MyFloat64(a.v - b.v)
end

function -(a::Any, b::MyFloat64)
    global SubtractionCounter += 1
    return MyFloat64(a - b.v)
end

function -(a::MyFloat64, b::Any)
    global SubtractionCounter += 1
    return MyFloat64(a.v - b)
end

# Multiplication, counted in MultiplicationCounter.
#
# NOTE(review): only the MyFloat64 * MyFloat64 method increments the counter.
# The increment in the two mixed-type methods is commented out, unlike for
# the other three operators. It is left disabled here because that may be
# deliberate -- confirm before re-enabling.
MultiplicationCounter = 0

function *(a::MyFloat64, b::MyFloat64)
    global MultiplicationCounter += 1
    return MyFloat64(a.v * b.v)
end

function *(a::Any, b::MyFloat64)
    # global MultiplicationCounter += 1
    return MyFloat64(a * b.v)
end

function *(a::MyFloat64, b::Any)
    # global MultiplicationCounter += 1
    return MyFloat64(a.v * b)
end

# Division: same scheme as addition, counted in DivisionCounter.
DivisionCounter = 0

function /(a::MyFloat64, b::MyFloat64)
    global DivisionCounter += 1
    return MyFloat64(a.v / b.v)
end

function /(a::Any, b::MyFloat64)
    global DivisionCounter += 1
    return MyFloat64(a / b.v)
end

function /(a::MyFloat64, b::Any)
    global DivisionCounter += 1
    return MyFloat64(a.v / b)
end

# Set all four operation counters back to zero.
function resetMyFloatCounters()
    global AdditionCounter = 0
    global SubtractionCounter = 0
    global MultiplicationCounter = 0
    global DivisionCounter = 0
    return nothing
end

# Print the current value of the four operation counters to stdout.
function reportMyFloatCounters()
    println("Number of additions: ", AdditionCounter)
    println("Number of subtractions: ", SubtractionCounter)
    println("Number of multiplications: ", MultiplicationCounter)
    println("Number of divisions: ", DivisionCounter)
    return nothing
end

# Define the @analyzeMyFloat macro.
#
# `@analyzeMyFloat ex` resets the counters, evaluates `ex`, prints the four
# counters and evaluates to the value of `ex`.
#
# The macro returns an expression instead of calling `eval` while it is being
# expanded. `ex` is escaped so that it is evaluated in the caller's scope
# (local variables are visible), and the reset/evaluate/report sequence runs
# every time the surrounding code runs, not once at expansion time.
macro analyzeMyFloat(ex)
    return quote
        resetMyFloatCounters()
        local result = $(esc(ex))
        reportMyFloatCounters()
        result
    end
end

# Some simple tests.
#
# @analyzeMyFloat MyFloat64(1.0) + MyFloat64(2.0)
# @analyzeMyFloat MyFloat64(1.0) - MyFloat64(2.0)
# @analyzeMyFloat MyFloat64(1.0) * MyFloat64(2.0)
# @analyzeMyFloat MyFloat64(1.0) / MyFloat64(2.0)
# A = [MyFloat64(1) MyFloat64(2) MyFloat64(3);
#      MyFloat64(1) MyFloat64(2) MyFloat64(3);
#      MyFloat64(1) MyFloat64(2) MyFloat64(3)]
# B = [MyFloat64(1) MyFloat64(2) MyFloat64(3);
#      MyFloat64(1) MyFloat64(2) MyFloat64(3);
#      MyFloat64(1) MyFloat64(2) MyFloat64(3)]
# @analyzeMyFloat A*B
# @analyzeMyFloat A+B
# @analyzeMyFloat A-B
# B = [1 2 3; 4 5 6; 7 8 9]
# @analyzeMyFloat A+B
# @analyzeMyFloat A-B
# @analyzeMyFloat A*B
# @analyzeMyFloat B+A
# @analyzeMyFloat B*A
# @analyzeMyFloat B-A

# We use Cassette library to track the calls to functions.
using Cassette

Cassette.@context TraceCtx

# Maps a tracked function (e.g. +) to the number of times it has been called
# under a TraceCtx. Only functions that are already keys are counted.
const callerCounter = Dict{Any,Int}()

# Overload the Cassette execute function for TraceCtx.
#
# `args[1]` is the function being called. If it is a key of `callerCounter`
# its count is incremented. The call is then overdubbed with a fresh context
# when Cassette can do so, and otherwise run through Cassette's fallback.
function Cassette.execute(ctx::TraceCtx, args...)
    callee = args[1]
    if haskey(callerCounter, callee)
        callerCounter[callee] += 1
    end
    if Cassette.canoverdub(ctx, args...)
        # Each nested call gets its own (empty) metadata vector.
        newctx = Cassette.similarcontext(ctx, metadata = Any[])
        return Cassette.overdub(newctx, args...)
    else
        return Cassette.fallback(ctx, args...)
    end
end

# Define the @analyze macro.
#
# `@analyze ex` runs `ex` under a TraceCtx, prints how many times +, -, *
# and / were called, and evaluates to the value of `ex`.
#
# The counters are zeroed inside the generated code, so they are reset every
# time the code runs rather than only once when the macro is expanded.
macro analyze(ex)
    return quote
        callerCounter[+] = 0
        callerCounter[-] = 0
        callerCounter[*] = 0
        callerCounter[/] = 0
        trace = Any[]
        result = Cassette.@overdub(TraceCtx(metadata = trace), $(esc(ex)))
        println("Number of additions: ", get(callerCounter, +, 0))
        println("Number of subtractions: ", get(callerCounter, -, 0))
        println("Number of multiplications: ", get(callerCounter, *, 0))
        println("Number of divisions: ", get(callerCounter, /, 0))
        result
    end
end

# This is the loss function we use to estimate the quality of an estimated complexity.
# It is defined by
#     \sum_{i} (cnts[i] - c * i^e)^2
#
# Arguments:
#   c    - scale factor of the model c * i^e
#   e    - exponent of the model
#   cnts - measured counts; cnts[i] is compared with the model value at i
#
# Returns the sum of squared residuals (0 when there is nothing to sum).
function loss(c, e, cnts)
    idx = axes(cnts, 1)
    isempty(idx) && return 0
    return sum(abs2(cnts[i] - c * i^e) for i in idx)
end

# Fix the exponent e, we can calculate the value of c that minimizes loss(c, e, cnts).
#
# The optimal scale is c = (\sum_i i^e * cnts[i]) / (\sum_i i^(2e)).
# Note that this function returns the loss obtained with that optimal c,
# not c itself.
function minimizer(e, cnts)
    idx = axes(cnts, 1)
    isempty(idx) && return 0
    num = sum(i^e * cnts[i] for i in idx)
    den = sum(i^(2 * e) for i in idx)
    return loss(num / den, e, cnts)
end

# Search the exponent grid 0.1, 0.11, ..., 10.1 and return the exponent whose
# best-fit loss (see `minimizer`) is smallest. Ties keep the smaller exponent.
function bestExponent(cnts)
    beste = 0.1
    bestval = minimizer(beste, cnts)
    for i in 1:1000
        e = i / 100.0 + 0.1
        val = minimizer(e, cnts)
        if val < bestval
            bestval = val
            beste = e
        end
    end
    return beste
end

# Estimate the exponent e such that the number of arithmetic operations
# performed by f(n) grows like c * n^e.
#
# Arguments:
#   f     - function called as f(n) for n = 1, ..., terms
#   terms - number of input sizes to measure (default 20)
#
# For each n, f(n) is run under a TraceCtx and the calls to +, -, * and /
# recorded in `callerCounter` are added up. The exponent that best fits these
# totals is printed and returned.
function estimateComplexity(f, terms=20)
    ops = (+, -, *, /)
    cnts = zeros(terms)
    for n in 1:terms
        # Start every measurement from zero.
        for op in ops
            callerCounter[op] = 0
        end
        trace = Any[]
        Cassette.@overdub(TraceCtx(metadata = trace), f(n))
        cnts[n] = sum(callerCounter[op] for op in ops)
    end
    # We enumerate all values of e up to some granularity.
    beste = bestExponent(cnts)
    println(beste)
    return beste
end