
Module 4.4: A FIRRTL Transform Example


This AnalyzeCircuit Transform walks a firrtl.ir.Circuit and records the number of add ops it finds in each module.

Setup

Please run the following:

In [ ]:
val path = System.getProperty("user.dir") + "/source/load-ivy.sc"
interp.load.module(ammonite.ops.Path(java.nio.file.FileSystems.getDefault().getPath(path)))
In [ ]:
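// Compiler Infrastructure
import firrtl.{CircuitState, LowForm, Transform}
// Firrtl IR classes
import firrtl.ir.{Circuit, DefModule, DoPrim, Expression, Statement}
// Map functions
import firrtl.Mappers._
// Scala's mutable collections
import scala.collection.mutable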

Counting Adders Per Module

As described earlier, a FIRRTL circuit is represented as a tree:

  • A FIRRTL Circuit contains a sequence of DefModules.
  • A DefModule contains a sequence of Ports and, possibly, a Statement (its body).
  • A Statement can contain other Statements and/or Expressions.
  • An Expression can contain other Expressions.

To visit every FIRRTL IR node in a circuit, we write functions that recursively walk down this tree. To record statistics, we pass a Ledger instance along the walk and update it whenever we come across an add op:
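To make the recursion concrete without depending on the firrtl library, here is a minimal sketch over a hypothetical toy tree. The Toy* case classes are illustrative stand-ins, not firrtl.ir classes. This sketch counts adds per module with plain recursion that returns counts, whereas the transform in this section threads a Ledger through firrtl's map functions instead.

```scala
// Hypothetical toy stand-ins for firrtl.ir.{Circuit, DefModule, Statement, Expression}.
sealed trait ToyExpr
case class ToyRef(name: String) extends ToyExpr
case class ToyAdd(a: ToyExpr, b: ToyExpr) extends ToyExpr
case class ToyMul(a: ToyExpr, b: ToyExpr) extends ToyExpr

sealed trait ToyStmt
case class ToyConnect(sink: String, source: ToyExpr) extends ToyStmt
case class ToyBlock(stmts: Seq[ToyStmt]) extends ToyStmt

case class ToyModule(name: String, body: ToyStmt)
case class ToyCircuit(modules: Seq[ToyModule])

// An expression may contain other expressions, so recurse into its operands.
def countAddsExpr(e: ToyExpr): Int = e match {
  case ToyRef(_)    => 0
  case ToyAdd(a, b) => 1 + countAddsExpr(a) + countAddsExpr(b)
  case ToyMul(a, b) => countAddsExpr(a) + countAddsExpr(b)
}

// A statement may contain other statements or expressions.
def countAddsStmt(s: ToyStmt): Int = s match {
  case ToyConnect(_, source) => countAddsExpr(source)
  case ToyBlock(stmts)       => stmts.map(countAddsStmt).sum
}

// A circuit is a sequence of modules: report one count per module name.
def countAddsPerModule(c: ToyCircuit): Map[String, Int] =
  c.modules.map(m => m.name -> countAddsStmt(m.body)).toMap

val circuit = ToyCircuit(Seq(
  ToyModule("Top", ToyBlock(Seq(
    ToyConnect("x", ToyAdd(ToyRef("a"), ToyAdd(ToyRef("b"), ToyRef("c")))),
    ToyConnect("y", ToyMul(ToyRef("a"), ToyRef("b")))
  ))),
  ToyModule("Leaf", ToyConnect("z", ToyRef("a")))
))

println(countAddsPerModule(circuit)) // Map(Top -> 2, Leaf -> 0)
```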

In [ ]:
class Ledger {
  import firrtl.Utils
  // Module currently being walked; None until setModuleName is called
  private var moduleName: Option[String] = None
  // Every module visited, including those with no add ops
  private val modules = mutable.Set[String]()
  // Number of add ops found per module name
  private val moduleAddMap = mutable.Map[String, Int]()
  // Count one add op against the current module
  def foundAdd(): Unit = moduleName match {
    case None => sys.error("Module name not defined in Ledger!")
    case Some(name) => moduleAddMap(name) = moduleAddMap.getOrElse(name, 0) + 1
  }
  def getModuleName: String = moduleName match {
    case None => Utils.error("Module name not defined in Ledger!")
    case Some(name) => name
  }
  // Start counting against myName
  def setModuleName(myName: String): Unit = {
    modules += myName
    moduleName = Some(myName)
  }
  // One line per visited module (mutable.Set iteration order is unspecified)
  def serialize: String = {
    modules map { myName =>
      s"$myName => ${moduleAddMap.getOrElse(myName, 0)} add ops!"
    } mkString "\n"
  }
}
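At its core, the Ledger uses a common Scala counting idiom on a mutable.Map. Below is a minimal firrtl-free sketch of the same bookkeeping. ToyLedger is a hypothetical name, and it swaps the plain mutable.Set for mutable.LinkedHashSet (a choice the original does not make) so that modules are printed in the order they were first visited rather than in unspecified hash order.

```scala
import scala.collection.mutable

// Hypothetical, firrtl-free reduction of Ledger's bookkeeping.
class ToyLedger {
  private var current: Option[String] = None
  // LinkedHashSet iterates in insertion order, unlike a plain mutable.Set
  private val modules = mutable.LinkedHashSet[String]()
  private val addCounts = mutable.Map[String, Int]()

  def setModuleName(name: String): Unit = {
    modules += name
    current = Some(name)
  }

  // Counting-map idiom: read the old count (default 0) and store count + 1
  def foundAdd(): Unit = current match {
    case None       => sys.error("Module name not defined in ToyLedger!")
    case Some(name) => addCounts(name) = addCounts.getOrElse(name, 0) + 1
  }

  def serialize: String =
    modules.toSeq.map(m => s"$m => ${addCounts.getOrElse(m, 0)} add ops!").mkString("\n")
}

val ledger = new ToyLedger
ledger.setModuleName("Top")
ledger.foundAdd()
ledger.foundAdd()
ledger.setModuleName("Leaf")
println(ledger.serialize)
// Top => 2 add ops!
// Leaf => 0 add ops!
```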

Now, let's define a FIRRTL Transform that walks the circuit and updates our Ledger whenever it comes across an adder (DoPrim with op argument Add). Don't worry about inputForm or outputForm for now.

Take some time to understand how walkModule, walkStatement, and walkExpression enable traversing all DefModule, Statement, and Expression nodes in the FIRRTL AST.

Questions to answer:

  • Why doesn't walkModule call walkExpression?
  • Why does walkExpression do a post-order traversal?
  • Can you modify walkExpression to do a pre-order traversal of Expressions?
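When thinking about the second and third questions, it may help to compare the two orders on a small hypothetical tree first. Tree, Leaf, and Plus below are illustrative classes, not firrtl.ir classes. In a post-order walk a node is handled after all of its children. In a pre-order walk it is handled before them.

```scala
// Hypothetical toy expression tree used only to compare visit orders.
sealed trait Tree
case class Leaf(name: String) extends Tree
case class Plus(a: Tree, b: Tree) extends Tree

// A readable label for each node.
def label(t: Tree): String = t match {
  case Leaf(n)    => n
  case Plus(_, _) => "plus"
}

def children(t: Tree): Seq[Tree] = t match {
  case Leaf(_)    => Seq.empty
  case Plus(a, b) => Seq(a, b)
}

// Post-order: visit the children first, then the node itself.
def postOrder(t: Tree): Seq[String] = children(t).flatMap(postOrder) :+ label(t)

// Pre-order: visit the node itself first, then its children.
def preOrder(t: Tree): Seq[String] = label(t) +: children(t).flatMap(preOrder)

val tree = Plus(Plus(Leaf("a"), Leaf("b")), Leaf("c"))
println(postOrder(tree).mkString(", ")) // a, b, plus, c, plus
println(preOrder(tree).mkString(", "))  // plus, plus, a, b, c
```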
In [ ]:
class AnalyzeCircuit extends firrtl.Transform {
  import firrtl._
  import firrtl.ir._
  import firrtl.Mappers._
  import firrtl.PrimOps._
    
  // Requires the [[Circuit]] form to be "low"
  def inputForm = LowForm
  // Indicates the output [[Circuit]] form to be "low"
  def outputForm = LowForm

  // Called by [[Compiler]] to run your pass. [[CircuitState]] contains
  // the circuit and its form, as well as other related data.
  def execute(state: CircuitState): CircuitState = {
    val ledger = new Ledger()
    val circuit = state.circuit

    // Execute the function walkModule(ledger) on every [[DefModule]] in
    // circuit, returning a new [[Circuit]] with new [[Seq]] of [[DefModule]].
    //   - "higher order functions" - using a function as an object
    //   - "function currying" - partial argument notation
    //   - "infix notation" - fancy function calling syntax
    //   - "map" - classic functional programming concept
    //   - discard the returned new [[Circuit]] because circuit is unmodified
    circuit map walkModule(ledger)

    // Print our ledger
    println(ledger.serialize)

    // Return an unchanged [[CircuitState]]
    state
  }

  // Deeply visits every [[Statement]] in m.
  def walkModule(ledger: Ledger)(m: DefModule): DefModule = {
    // Set ledger to current module name
    ledger.setModuleName(m.name)

    // Execute the function walkStatement(ledger) on every [[Statement]] in m.
    //   - return the new [[DefModule]] (in this case, it's identical to m)
    //   - if m does not contain [[Statement]], map returns m.
    m map walkStatement(ledger)
  }

  // Deeply visits every [[Statement]] and [[Expression]] in s.
  def walkStatement(ledger: Ledger)(s: Statement): Statement = {

    // Execute the function walkExpression(ledger) on every [[Expression]] in s.
    //   - discard the new [[Statement]] (in this case, it's identical to s)
    //   - if s does not contain [[Expression]], map returns s.
    s map walkExpression(ledger)

    // Execute the function walkStatement(ledger) on every [[Statement]] in s.
    //   - return the new [[Statement]] (in this case, it's identical to s)
    //   - if s does not contain [[Statement]], map returns s.
    s map walkStatement(ledger)
  }

  // Deeply visits every [[Expression]] in e.
  //   - "post-order traversal" - handle e's children [[Expression]] before e
  def walkExpression(ledger: Ledger)(e: Expression): Expression = {

    // Execute the function walkExpression(ledger) on every [[Expression]] in e.
    //   - return the new [[Expression]] (in this case, it's identical to e)
    //   - if e does not contain [[Expression]], map returns e.
    val visited = e map walkExpression(ledger)

    visited match {
      // If the visited expression is an adder, increment our ledger and return it.
      case add @ DoPrim(Add, _, _, _) =>
        ledger.foundAdd()
        add
      // If it is not an adder, return it unchanged.
      case notadd => notadd
    }
  }
}
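The comments in execute mention higher-order functions, currying, infix notation, and map. Here is a minimal, firrtl-free sketch of that pattern. It assumes a hypothetical ToyNode class whose map applies a function to each child and rebuilds the node, similar in spirit to the map functions firrtl.Mappers provides for real IR nodes, and a hypothetical curried visit function standing in for walkStatement(ledger).

```scala
import scala.collection.mutable

// Hypothetical node whose `map` applies f to each child and rebuilds the node.
case class ToyNode(name: String, children: Seq[String]) {
  def map(f: String => String): ToyNode = ToyNode(name, children.map(f))
}

// Curried visitor: the first parameter list carries context (like Ledger),
// the second receives one child at a time.
def visit(seen: mutable.Buffer[String])(child: String): String = {
  seen += child // side effect: record the visit
  child         // return the child unchanged
}

val seen = mutable.Buffer[String]()
val top = ToyNode("top", Seq("x", "y"))

// visit(seen) is partially applied into a String => String function,
// and infix `top map f` is the same call as top.map(f).
val rebuilt = top map visit(seen)

println(seen.mkString(", ")) // x, y
println(rebuilt == top)      // true: every child was returned unchanged
```

As in execute, the rebuilt node could be discarded here, because the walk only exists for its side effect on the context object.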

Running our Transform

Now that we've defined it, let's run it on a Chisel design! First, let's define a Chisel module.

In [ ]:
// Chisel stuff
import chisel3._
import chisel3.util._

class AddMe(nInputs: Int, width: Int) extends Module {
  val io = IO(new Bundle {
    val in  = Input(Vec(nInputs, UInt(width.W)))
    val out = Output(UInt(width.W))
  })
  io.out := io.in.reduce(_ +& _)
}
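Because io.in.reduce(_ +& _) folds the nInputs elements of the Vec with a binary operator, elaborating AddMe(8, 4) should create nInputs - 1 = 7 expanding adds, and Chisel emits each +& as a FIRRTL add primop. The plain-Scala sketch below involves no Chisel. Its helper countReduceOps is hypothetical, and it only checks that reduce applies its operator n - 1 times. That is roughly the count we'd expect AnalyzeCircuit to report for this module, assuming no earlier compiler pass adds or removes add ops.

```scala
// Counts how many times reduce applies its binary operator to n elements.
def countReduceOps(n: Int): Int = {
  var ops = 0
  (1 to n).reduce { (a, b) =>
    ops += 1 // one operator application, i.e. one adder in AddMe
    a + b
  }
  ops
}

println(countReduceOps(8)) // 7
```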

Next, let's elaborate it and emit the resulting FIRRTL as a string.

In [ ]:
val firrtlSerialization = chisel3.Driver.emit(() => new AddMe(8, 4))

Finally, let's compile our FIRRTL into Verilog, including our custom transform in the compilation. Note that it prints out the number of add ops it found!

In [ ]:
val verilog = compileFIRRTL(firrtlSerialization, new firrtl.VerilogCompiler(), Seq(new AnalyzeCircuit()))

The compileFIRRTL function is defined only in this tutorial; in a future section, we will describe the process of inserting custom transforms.

That's it for this section!
