This AnalyzeCircuit transform walks a firrtl.ir.Circuit and records the number of add ops it finds, per module.
Please run the following:
val path = System.getProperty("user.dir") + "/source/load-ivy.sc"
interp.load.module(ammonite.ops.Path(java.nio.file.FileSystems.getDefault().getPath(path)))
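// Compiler Infrastructure
import firrtl.{CircuitState, LowForm, Transform, Utils}
// Firrtl IR classes
import firrtl.ir.{DefModule, Expression, Statement}
// Map functions
import firrtl.Mappers._
// Scala's mutable collections
import scala.collection.mutable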
As described earlier, a FIRRTL circuit is represented as a tree:
- Circuit contains a sequence of DefModules.
- DefModule contains a sequence of Ports, and possibly a Statement.
- Statement can contain other Statements, or Expressions.
- Expression can contain other Expressions.
To visit all FIRRTL IR nodes in a circuit, we write functions that recursively walk down this tree. To record statistics, we will pass along a Ledger class and use it whenever we come across an add op:
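class Ledger {
  import firrtl.Utils
  // Name of the module currently being walked (set by setModuleName)
  private var moduleName: Option[String] = None
  // Every module seen so far, including modules with zero add ops
  private val modules = mutable.Set[String]()
  // Module name -> number of add ops found in that module
  private val moduleAddMap = mutable.Map[String, Int]()
  // Record one add op against the current module
  def foundAdd(): Unit = moduleName match {
    case None => sys.error("Module name not defined in Ledger!")
    case Some(name) => moduleAddMap(name) = moduleAddMap.getOrElse(name, 0) + 1
  }
  def getModuleName: String = moduleName match {
    case None => Utils.error("Module name not defined in Ledger!")
    case Some(name) => name
  }
  // Switch the ledger to a new module and remember it for serialize
  def setModuleName(myName: String): Unit = {
    modules += myName
    moduleName = Some(myName)
  }
  // One line per module: "<name> => <count> add ops!" (order is unspecified, since modules is a Set)
  def serialize: String = {
    modules map { myName =>
      s"$myName => ${moduleAddMap.getOrElse(myName, 0)} add ops!"
    } mkString "\n"
  }
}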
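To see the tree shape and the ledger-style bookkeeping in isolation, here is a minimal, self-contained sketch that does not depend on FIRRTL. It assumes a hypothetical toy IR (`Circuit`, `Module`, `Block`, `Connect`, `Ref`, `Add`, `Mul`, all defined here and none of them `firrtl.ir` classes) that mirrors the Circuit → module → statement → expression nesting. A plain mutable map plays the role of the Ledger, and the hypothetical `countAdds` recurses through every level of the tree.

```scala
import scala.collection.mutable

// Hypothetical toy IR mirroring FIRRTL's nesting; these are NOT firrtl.ir classes.
sealed trait Expr
case class Ref(name: String) extends Expr
case class Add(a: Expr, b: Expr) extends Expr
case class Mul(a: Expr, b: Expr) extends Expr

sealed trait Stmt
case class Block(stmts: Seq[Stmt]) extends Stmt
case class Connect(loc: String, value: Expr) extends Stmt

case class Module(name: String, body: Stmt)
case class Circuit(modules: Seq[Module])

// Counts Add nodes per module (the mutable map stands in for the Ledger).
def countAdds(c: Circuit): Map[String, Int] = {
  val counts = mutable.Map[String, Int]()
  // Expressions can contain other expressions, so recurse into operands.
  def walkExpr(module: String, e: Expr): Unit = e match {
    case Add(a, b) =>
      counts(module) = counts.getOrElse(module, 0) + 1
      walkExpr(module, a); walkExpr(module, b)
    case Mul(a, b) => walkExpr(module, a); walkExpr(module, b)
    case Ref(_)    => ()
  }
  // Statements can contain other statements or expressions.
  def walkStmt(module: String, s: Stmt): Unit = s match {
    case Block(ss)         => ss.foreach(walkStmt(module, _))
    case Connect(_, value) => walkExpr(module, value)
  }
  c.modules.foreach { m =>
    counts.getOrElseUpdate(m.name, 0) // also report modules with zero adds
    walkStmt(m.name, m.body)
  }
  counts.toMap
}

// Example: "Top" holds two adds (one nested inside the other), "Leaf" holds none.
val circuit = Circuit(Seq(
  Module("Top", Block(Seq(
    Connect("x", Add(Ref("a"), Add(Ref("b"), Ref("c")))),
    Connect("y", Mul(Ref("a"), Ref("b")))))),
  Module("Leaf", Connect("z", Ref("a")))))
```

The nested `Add` is only found because `walkExpr` recurses into its operands. The FIRRTL pass relies on that same property.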
Now, let's define a FIRRTL Transform that walks the circuit and updates our Ledger whenever it comes across an adder (a DoPrim with op argument Add). Don't worry about inputForm or outputForm for now.
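Recognizing an adder is just a pattern match on the op field of a DoPrim. The sketch below uses hypothetical stand-ins for `firrtl.ir.DoPrim` and `firrtl.PrimOps`. FIRRTL's real DoPrim carries `(op, args, consts, tpe)`, and this toy drops the type field. The hypothetical `isAdder` shows the same shape of match that the transform uses.

```scala
// Hypothetical stand-ins for firrtl.PrimOps; NOT the real FIRRTL objects.
sealed trait PrimOp
case object Add extends PrimOp
case object Sub extends PrimOp
case object Mul extends PrimOp

// Hypothetical toy expressions; the real DoPrim also carries a type (tpe).
sealed trait Expr
case class Ref(name: String) extends Expr
case class DoPrim(op: PrimOp, args: Seq[Expr], consts: Seq[BigInt]) extends Expr

// Only the op field decides whether a node is an adder; args and consts are ignored.
def isAdder(e: Expr): Boolean = e match {
  case DoPrim(Add, _, _) => true // `Add` is a case object, so it is matched by identity
  case _                 => false
}
```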
Take some time to understand how walkModule, walkStatement, and walkExpression enable traversing all DefModule, Statement, and Expression nodes in the FIRRTL AST.
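The key detail is that `map` on an IR node only reaches that node's immediate children. The whole tree is walked because the function passed to `map` calls `map` again. The sketch below illustrates this with a hypothetical toy expression type. Its `mapExpr` method stands in for firrtl's `e map f` and is an illustrative assumption, not FIRRTL's implementation. `shallowCount` and `deepCount` are likewise hypothetical helpers.

```scala
// Hypothetical toy expression type; mapExpr mimics `e map f`, applying f
// only to the node's immediate child expressions (one level deep).
sealed trait Expr {
  def mapExpr(f: Expr => Expr): Expr = this match {
    case Add(a, b) => Add(f(a), f(b))
    case Lit(_)    => this
  }
}
case class Lit(v: Int) extends Expr
case class Add(a: Expr, b: Expr) extends Expr

// Shallow: f only ever sees the root's direct children.
def shallowCount(e: Expr): Int = {
  var n = 0
  e.mapExpr { child => if (child.isInstanceOf[Add]) n += 1; child }
  n
}

// Deep: the function given to mapExpr recurses, the way walkExpression
// passes walkExpression(ledger) to map, so every level is visited.
def deepCount(e: Expr): Int = {
  var n = 0
  def walk(x: Expr): Expr = {
    x.mapExpr(walk)                    // children first
    if (x.isInstanceOf[Add]) n += 1    // then this node
    x
  }
  walk(e)
  n
}

// (1 + (2 + 3)) + 4: three Add nodes in total, one of them a direct child of the root
val tree = Add(Add(Lit(1), Add(Lit(2), Lit(3))), Lit(4))
```

Here `shallowCount(tree)` finds only the one Add directly under the root, and `deepCount(tree)` finds all three.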
Questions to answer:
class AnalyzeCircuit extends firrtl.Transform {
  import firrtl._
  import firrtl.ir._
  import firrtl.Mappers._
  import firrtl.PrimOps._
  // Requires the [[Circuit]] form to be "low"
  def inputForm = LowForm
  // Indicates the output [[Circuit]] form to be "low"
  def outputForm = LowForm
  // Called by [[Compiler]] to run your pass. [[CircuitState]] contains
  // the circuit and its form, as well as other related data.
  def execute(state: CircuitState): CircuitState = {
    val ledger = new Ledger()
    val circuit = state.circuit
    // Execute the function walkModule(ledger) on every [[DefModule]] in
    // circuit, returning a new [[Circuit]] with new [[Seq]] of [[DefModule]].
    // - "higher order functions" - using a function as an object
    // - "function currying" - partial argument notation
    // - "infix notation" - fancy function calling syntax
    // - "map" - classic functional programming concept
    // - discard the returned new [[Circuit]] because circuit is unmodified
    circuit map walkModule(ledger)
    // Print our ledger
    println(ledger.serialize)
    // Return an unchanged [[CircuitState]]
    state
  }
  // Deeply visits every [[Statement]] in m.
  def walkModule(ledger: Ledger)(m: DefModule): DefModule = {
    // Set ledger to current module name
    ledger.setModuleName(m.name)
    // Execute the function walkStatement(ledger) on every [[Statement]] in m.
    // - return the new [[DefModule]] (in this case, it's identical to m)
    // - if m does not contain [[Statement]], map returns m.
    m map walkStatement(ledger)
  }
  // Deeply visits every [[Statement]] and [[Expression]] in s.
  def walkStatement(ledger: Ledger)(s: Statement): Statement = {
    // Execute the function walkExpression(ledger) on every [[Expression]] in s.
    // - discard the new [[Statement]] (in this case, it's identical to s)
    // - if s does not contain [[Expression]], map returns s.
    s map walkExpression(ledger)
    // Execute the function walkStatement(ledger) on every [[Statement]] in s.
    // - return the new [[Statement]] (in this case, it's identical to s)
    // - if s does not contain [[Statement]], map returns s.
    s map walkStatement(ledger)
  }
  // Deeply visits every [[Expression]] in e.
  // - "post-order traversal" - handle e's children [[Expression]] before e
  def walkExpression(ledger: Ledger)(e: Expression): Expression = {
    // Execute the function walkExpression(ledger) on every [[Expression]] in e.
    // - return the new [[Expression]] (in this case, it's identical to e)
    // - if e does not contain [[Expression]], map returns e.
    val visited = e map walkExpression(ledger)
    visited match {
      // If e is an adder, increment our ledger and return e.
      case DoPrim(Add, _, _, _) =>
        ledger.foundAdd()
        e
      // If e is not an adder, return e.
      case notadd => notadd
    }
  }
}
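walkExpression handles an expression's children before the expression itself. This is a post-order traversal. The sketch below contrasts it with pre-order on a hypothetical labelled tree with the shape of `(a + b) + c`. `Node`, `postOrder`, `preOrder`, and `order` are illustrative names, not FIRRTL API. For simply counting adds, both orders give the same total. Post-order matters when a pass rewrites a node based on its already-processed children.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical labelled tree (not firrtl.ir) used only to record visit order.
case class Node(label: String, children: List[Node] = Nil)

// Post-order: visit children first, then the node (the order walkExpression uses).
def postOrder(n: Node, out: ListBuffer[String]): Unit = {
  n.children.foreach(postOrder(_, out))
  out += n.label
}

// Pre-order: visit the node first, then its children.
def preOrder(n: Node, out: ListBuffer[String]): Unit = {
  out += n.label
  n.children.foreach(preOrder(_, out))
}

// Shape of (a + b) + c: add1 is the root, add0 is its left operand.
val expr = Node("add1", List(Node("add0", List(Node("a"), Node("b"))), Node("c")))

// Runs a traversal over expr and returns the labels in visit order.
def order(f: (Node, ListBuffer[String]) => Unit): List[String] = {
  val buf = ListBuffer[String]()
  f(expr, buf)
  buf.toList
}
```

Here `order(postOrder)` lists `add0` only after `a` and `b`, and lists `add1` last. `order(preOrder)` starts with `add1`.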
Now that we've defined it, let's run it on a Chisel design! First, let's define a Chisel module.
// Chisel stuff
import chisel3._
import chisel3.util._
class AddMe(nInputs: Int, width: Int) extends Module {
  val io = IO(new Bundle {
    val in = Input(Vec(nInputs, UInt(width.W)))
    val out = Output(UInt(width.W))
  })
  // Sum all inputs with +&, Chisel's width-expanding add
  io.out := io.in.reduce(_ +& _)
}
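Here is a quick sanity check on what the ledger should report for `AddMe(8, 4)`. `reduce` folds the 8-element Vec left to right and applies `+&` once per pair, so n operands take n - 1 binary operations. The plain-Scala sketch below counts those applications. It uses Int in place of UInt, and `reduceOpCount` is a hypothetical helper, not part of Chisel. Chisel emits each `+&` as a FIRRTL `add` primop. So one would expect roughly 7 add ops for AddMe, assuming no pass that runs before AnalyzeCircuit removes or merges any of them.

```scala
// Hypothetical helper: counts how many times reduce applies its binary operator
// to n operands, mirroring io.in.reduce(_ +& _) with Ints standing in for UInts.
def reduceOpCount(n: Int): Int = {
  var ops = 0
  (1 to n).reduce { (a: Int, b: Int) => ops += 1; a + b }
  ops
}
```

For the `AddMe(8, 4)` instance, `reduceOpCount(8)` gives 7. Note that `reduce` throws on an empty collection, so this helper assumes `n >= 1`.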
Next, let's elaborate it and emit its FIRRTL serialization.
val firrtlSerialization = chisel3.Driver.emit(() => new AddMe(8, 4))
Finally, let's compile our FIRRTL into Verilog, including our custom transform in the compilation. Note that it prints out the number of add ops it found!
val verilog = compileFIRRTL(firrtlSerialization, new firrtl.VerilogCompiler(), Seq(new AnalyzeCircuit()))
The compileFIRRTL function is defined only in this tutorial; in a future section, we will describe the process of inserting custom transforms.
That's it for this section!