diff --git a/frontends/PyCDE/integration_test/esi_test.py b/frontends/PyCDE/integration_test/esi_test.py index d0a5f9543f66..840912e0a384 100644 --- a/frontends/PyCDE/integration_test/esi_test.py +++ b/frontends/PyCDE/integration_test/esi_test.py @@ -7,11 +7,12 @@ import pycde from pycde import (AppID, Clock, Module, Reset, modparams, generator) from pycde.bsp import cosim -from pycde.common import Constant +from pycde.common import Constant, Input, Output from pycde.constructs import ControlReg, Reg, Wire from pycde.esi import ChannelService, FuncService, MMIO, MMIOReadWriteCmdType from pycde.types import (Bits, Channel, UInt) from pycde.behavioral import If, Else, EndIf +from pycde.handshake import Func import sys @@ -107,6 +108,28 @@ def construct(ports): ChannelService.to_host(AppID("const_producer"), ch) +class JoinFunc(Func): + a = Input(UInt(32)) + b = Input(UInt(32)) + x = Output(UInt(32)) + + @generator + def construct(ports): + ports.x = (ports.a + ports.b).as_uint(32) + + +class Join(Module): + clk = Clock() + rst = Reset() + + @generator + def construct(ports): + a = ChannelService.from_host(AppID("join_a"), UInt(32)) + b = ChannelService.from_host(AppID("join_b"), UInt(32)) + f = JoinFunc(clk=ports.clk, rst=ports.rst, a=a, b=b) + ChannelService.to_host(AppID("join_x"), f.x) + + class Top(Module): clk = Clock() rst = Reset() @@ -118,6 +141,7 @@ def construct(ports): MMIOClient(i)() MMIOReadWriteClient(clk=ports.clk, rst=ports.rst) ConstProducer(clk=ports.clk, rst=ports.rst) + Join(clk=ports.clk, rst=ports.rst) if __name__ == "__main__": diff --git a/frontends/PyCDE/integration_test/test_software/esi_test.py b/frontends/PyCDE/integration_test/test_software/esi_test.py index 1c7e26c53a66..8dc55f663719 100644 --- a/frontends/PyCDE/integration_test/test_software/esi_test.py +++ b/frontends/PyCDE/integration_test/test_software/esi_test.py @@ -143,3 +143,20 @@ def read_offset_check(i: int, add_amt: int): producer.disconnect() print(f"data: {data}") assert data == 42 + +################################################################################ +# Handshake Join +################################################################################ + +a = d.ports[esi.AppID("join_a")].write_port("data") +a.connect() +b = d.ports[esi.AppID("join_b")].write_port("data") +b.connect() +x = d.ports[esi.AppID("join_x")].read_port("data") +x.connect() + +a.write(15) +b.write(24) +xdata = x.read() +print(f"join: {xdata}") +assert xdata == 15 + 24 diff --git a/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp b/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp index a7404d1814cd..25173a9777d6 100644 --- a/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp +++ b/lib/Conversion/HandshakeToDC/HandshakeToDC.cpp @@ -762,10 +762,6 @@ class HandshakeToDCPass public: void runOnOperation() override { mlir::ModuleOp mod = getOperation(); - auto targetModifier = [](mlir::ConversionTarget &target) { - // target.addLegalDialect(); - }; - auto patternBuilder = [&](TypeConverter &typeConverter, handshaketodc::ConvertedOps &convertedOps, RewritePatternSet &patterns) { @@ -774,7 +770,7 @@ class HandshakeToDCPass patterns.add(typeConverter, mod.getContext()); }; - LogicalResult res = runHandshakeToDC(mod, patternBuilder, targetModifier); + LogicalResult res = runHandshakeToDC(mod, patternBuilder, nullptr); if (failed(res)) signalPassFailure(); }