Skip to content

Commit

Permalink
add nrTransformPrecode() and nrTransformDeprecode() (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
catkira authored Oct 31, 2024
1 parent 87eeb47 commit 982b8e8
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 0 deletions.
2 changes: 2 additions & 0 deletions py3gpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from .nrDLSCHInfo import nrDLSCHInfo
from .nrPDSCHIndices import nrPDSCHIndices
from .nrPDSCHMCSTables import nrPDSCHMCSTables
from .nrTransformPrecode import nrTransformPrecode
from .nrTransformDeprecode import nrTransformDeprecode

from .configs.nrCarrierConfig import nrCarrierConfig
from .configs.nrNumerologyConfig import nrNumerologyConfig
Expand Down
6 changes: 6 additions & 0 deletions py3gpp/nrTransformDeprecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import numpy as np

def nrTransformDeprecode(modSym, mrb):
mrb = int(mrb)
assert modSym.shape[0] % (mrb * 12) == 0, "input number of rows must be an integer multiple of mrb * 12"
return (np.fft.ifft(modSym.reshape(int(modSym.shape[0] / (mrb * 12)), mrb * 12)) * np.sqrt(mrb * 12)).ravel()
6 changes: 6 additions & 0 deletions py3gpp/nrTransformPrecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import numpy as np

def nrTransformPrecode(modSym, mrb):
mrb = int(mrb)
assert modSym.shape[0] % (mrb * 12) == 0, "input number of rows must be an integer multiple of mrb * 12"
return (np.fft.fft(modSym.reshape(int(modSym.shape[0] / (mrb * 12)), mrb * 12)) * 1/np.sqrt(mrb * 12)).ravel()
6 changes: 6 additions & 0 deletions tests/test_data/transformPrecode.py

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions tests/test_nrTransformPrecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import sys
import numpy as np
import pytest

from py3gpp.nrTransformPrecode import nrTransformPrecode
from py3gpp.nrSymbolModulate import nrSymbolModulate

sys.path.append("test_data")

from test_data.transformPrecode import cw, desired_result_2, desired_result_40

def test_run_nr_transform_precode_2():
modSym = nrSymbolModulate(cw, 'QPSK')
result_2 = nrTransformPrecode(modSym, 2)
assert np.array_equal(np.round(result_2, 8), np.round(desired_result_2, 8))

def test_run_nr_transform_precode_40():
modSym = nrSymbolModulate(cw, 'QPSK')
result_40 = nrTransformPrecode(modSym, 40)
assert np.array_equal(np.round(result_40, 8), np.round(desired_result_40, 8))

if __name__ == '__main__':
test_run_nr_transform_precode_2()
test_run_nr_transform_precode_40()
29 changes: 29 additions & 0 deletions tests/test_nrTransformPrecode_nrTransformDeprecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import sys
import numpy as np
import pytest

from py3gpp.nrTransformPrecode import nrTransformPrecode
from py3gpp.nrTransformDeprecode import nrTransformDeprecode
from py3gpp.nrSymbolModulate import nrSymbolModulate

sys.path.append("test_data")

from test_data.transformPrecode import cw, desired_result_2, desired_result_40

def test_run_nrTransformPrecode_nrTransformDeprecode_2():
modSym = nrSymbolModulate(cw, 'QPSK')
result = nrTransformPrecode(modSym, 2)
assert np.array_equal(np.round(result, 8), np.round(desired_result_2, 8))
x = nrTransformDeprecode(result, 2)
assert np.array_equal(np.round(x, 8), np.round(modSym, 8))

def test_run_nrTransformPrecode_nrTransformDeprecode_40():
modSym = nrSymbolModulate(cw, 'QPSK')
result = nrTransformPrecode(modSym, 40)
assert np.array_equal(np.round(result, 8), np.round(desired_result_40, 8))
x = nrTransformDeprecode(result, 40)
assert np.array_equal(np.round(x, 8), np.round(modSym, 8))

if __name__ == '__main__':
test_run_nrTransformPrecode_nrTransformDeprecode_2()
test_run_nrTransformPrecode_nrTransformDeprecode_40()

0 comments on commit 982b8e8

Please sign in to comment.