diff --git a/core/src/ops/change_axes.rs b/core/src/ops/change_axes.rs index 8d0951c3d9..6185f836eb 100644 --- a/core/src/ops/change_axes.rs +++ b/core/src/ops/change_axes.rs @@ -176,18 +176,14 @@ impl AxisOp { } (Rm(x), Move(from, to)) => { - // disabled these two as they kinda break axis tracking - // semantics if x == from { - None - // Some((Some(Rm(*to)), None)) + Some((Some(Rm(*to)), None)) } else if x < from.min(to) { Some((Some(self.clone()), Some(Move(from - 1, to - 1)))) } else if x > from.max(to) { Some((Some(self.clone()), Some(change.clone()))) } else if from + 1 == *to && x == to { - // Some((Some(Rm(*from)), None)) - None + Some((Some(Rm(*from)), None)) } else if from < to && x <= to { Some((Some(Rm(x - 1)), Some(Move(*from, *to - 1)))) } else { diff --git a/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected b/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected index a602af51d5..86d989fe1d 100644 --- a/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected +++ b/harness/pre-optimized-graphes/mdl-en-2019-Q3-librispeech/expected @@ -31,14 +31,14 @@ fragment scan_body_0( i"four_parts.split-over-1.0..256" = add(i"four_parts.W.concat-einsum-k.add-1.split-over-1.0..256", i"four_parts.split-1-over-1.0..256.slice"); i"peephole0.output" = add(i"peephole0.mul", i"four_parts.split-over-1.0..256"); i"peephole0.output.nolin" = sigmoid(i"peephole0.output"); - i"four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output" = squeeze(i"c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [2]); + i"four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output" = squeeze(i"c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [0]); i"four_parts.W.concat-einsum-k.256..384.split-over-1.512..768" = tract_core_einsum([r, i"four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"], expr = "ka,kn->bn", acc = "f32", output = ""); i"four_parts.W.concat-einsum-k.add-1.split-over-1.512..768" = add(i"four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output", i"four_parts.W.concat-einsum-k.256..384.split-over-1.512..768"); i"four_parts.split-over-1.512..768" = add(i"four_parts.W.concat-einsum-k.add-1.split-over-1.512..768", i"four_parts.split-1-over-1.512..768.slice"); i"four_parts.j.nolin" = tanh(i"four_parts.split-over-1.512..768"); c_update = mul(i"peephole0.output.nolin", i"four_parts.j.nolin"); i"peephole1.mul" = mul(i"peephole1.mul.fix-rank-0-1", c); - i"four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output" = squeeze(i"c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [2]); + i"four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output" = squeeze(i"c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [0]); i"four_parts.W.concat-einsum-k.256..384.split-over-1.256..512" = tract_core_einsum([r, i"four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"], expr = "ka,kn->bn", acc = "f32", output = ""); i"four_parts.W.concat-einsum-k.add-1.split-over-1.256..512" = add(i"four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output", i"four_parts.W.concat-einsum-k.256..384.split-over-1.256..512"); i"four_parts.split-over-1.256..512" = add(i"four_parts.W.concat-einsum-k.add-1.split-over-1.256..512", i"four_parts.split-1-over-1.256..512.slice"); @@ -48,7 +48,7 @@ fragment scan_body_0( c_new = add(c_update, c_prop); tanh_c = tanh(c_new); i"peephole2.mul" = mul(i"peephole2.mul.fix-rank-0-1", c_new); - i"four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output" = squeeze(i"c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [2]); + i"four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output" = squeeze(i"c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [0]); i"four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024" = tract_core_einsum([r, i"four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"], expr = "ka,kn->bn", acc = "f32", output = ""); i"four_parts.W.concat-einsum-k.add-1.split-over-1.768..1024" = add(i"four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output", i"four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024"); i"four_parts.split-over-1.768..1024" = add(i"four_parts.W.concat-einsum-k.add-1.split-over-1.768..1024", i"four_parts.split-1-over-1.768..1024.slice"); @@ -57,7 +57,7 @@ fragment scan_body_0( m = mul(tanh_c, i"peephole2.output.nolin"); i"h_new.W.split-over-1.0..128" = tract_core_einsum([m, i"h_new.W.split-1-over-1.0..128.slice"], expr = "bk,kn->na", acc = "f32", output = ""); i"h_new.split-over-1.0..128" = add(i"h_new.W.split-over-1.0..128", i"h_new.split-1-over-1.0..128.slice"); - i"h_new.W.split-over-1.128..256.prop_axis.a.input_0" = unsqueeze(m, axes = [2]); + i"h_new.W.split-over-1.128..256.prop_axis.a.input_0" = unsqueeze(m, axes = [0]); r_new = i"h_new.split-over-1.0..128"; } @@ -89,14 +89,14 @@ fragment scan_body_1( i"four_parts.split-over-1.0..256" = add(i"four_parts.W.concat-einsum-k.add-1.split-over-1.0..256", i"four_parts.split-1-over-1.0..256.slice"); i"peephole0.output" = add(i"peephole0.mul", i"four_parts.split-over-1.0..256"); i"peephole0.output.nolin" = sigmoid(i"peephole0.output"); - i"four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output" = squeeze(i"c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [2]); + i"four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output" = squeeze(i"c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", axes = [0]); i"four_parts.W.concat-einsum-k.256..384.split-over-1.512..768" = tract_core_einsum([r, i"four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); i"four_parts.W.concat-einsum-k.add-1.split-over-1.512..768" = add(i"four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.output", i"four_parts.W.concat-einsum-k.256..384.split-over-1.512..768"); i"four_parts.split-over-1.512..768" = add(i"four_parts.W.concat-einsum-k.add-1.split-over-1.512..768", i"four_parts.split-1-over-1.512..768.slice"); i"four_parts.j.nolin" = tanh(i"four_parts.split-over-1.512..768"); c_update = mul(i"peephole0.output.nolin", i"four_parts.j.nolin"); i"peephole1.mul" = mul(i"peephole1.mul.fix-rank-0-1", c); - i"four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output" = squeeze(i"c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [2]); + i"four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output" = squeeze(i"c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", axes = [0]); i"four_parts.W.concat-einsum-k.256..384.split-over-1.256..512" = tract_core_einsum([r, i"four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); i"four_parts.W.concat-einsum-k.add-1.split-over-1.256..512" = add(i"four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.output", i"four_parts.W.concat-einsum-k.256..384.split-over-1.256..512"); i"four_parts.split-over-1.256..512" = add(i"four_parts.W.concat-einsum-k.add-1.split-over-1.256..512", i"four_parts.split-1-over-1.256..512.slice"); @@ -106,7 +106,7 @@ fragment scan_body_1( c_new = add(c_update, c_prop); tanh_c = tanh(c_new); i"peephole2.mul" = mul(i"peephole2.mul.fix-rank-0-1", c_new); - i"four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output" = squeeze(i"c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [2]); + i"four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output" = squeeze(i"c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", axes = [0]); i"four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024" = tract_core_einsum([r, i"four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"], expr = "ak,kn->bn", acc = "f32", output = ""); i"four_parts.W.concat-einsum-k.add-1.split-over-1.768..1024" = add(i"four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.output", i"four_parts.W.concat-einsum-k.256..384.split-over-1.768..1024"); i"four_parts.split-over-1.768..1024" = add(i"four_parts.W.concat-einsum-k.add-1.split-over-1.768..1024", i"four_parts.split-1-over-1.768..1024.slice"); @@ -115,7 +115,7 @@ fragment scan_body_1( m = mul(tanh_c, i"peephole2.output.nolin"); i"h_new.W.split-over-1.0..128" = tract_core_einsum([m, i"h_new.W.split-1-over-1.0..128.slice"], expr = "bk,kn->an", acc = "f32", output = ""); i"h_new.split-over-1.0..128" = add(i"h_new.W.split-over-1.0..128", i"h_new.split-1-over-1.0..128.slice"); - i"h_new.W.split-over-1.128..256.prop_axis.a.input_0" = unsqueeze(m, axes = [2]); + i"h_new.W.split-over-1.128..256.prop_axis.a.input_0" = unsqueeze(m, axes = [0]); r_new = i"h_new.split-over-1.0..128"; } @@ -175,11 +175,11 @@ graph network(input) -> (output) { i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1", shape = [256, 256]); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = tract_core_einsum([i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1"], expr = "ka,kn->an", acc = "f32", output = ""); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.input_1" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.input_1", shape = [256, 256]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768" = tract_core_einsum([i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.input_1"], expr = "ka,kn->bna", acc = "f32", output = ""); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768" = tract_core_einsum([i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.input_1"], expr = "ka,kn->abn", acc = "f32", output = ""); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.input_1" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.input_1", shape = [256, 256]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512" = tract_core_einsum([i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.input_1"], expr = "ka,kn->bna", acc = "f32", output = ""); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512" = tract_core_einsum([i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.input_1"], expr = "ka,kn->abn", acc = "f32", output = ""); i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.input_1" = variable(label = "fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.input_1", shape = [256, 256]); - i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024" = tract_core_einsum([i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.input_1"], expr = "ka,kn->bna", acc = "f32", output = ""); + i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024" = tract_core_einsum([i"tdnn3.renorm.output", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.input_1"], expr = "ka,kn->abn", acc = "f32", output = ""); i"fastlstm1.c_final_state_init_0" = variable(label = "fastlstm1.c_final_state_init_0", shape = [1, 256]); i"fastlstm1.c_final_state_init_1" = variable(label = "fastlstm1.c_final_state_init_1", shape = [128, 1]); i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice" = variable(label = "fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", shape = [128, 256]); @@ -195,9 +195,9 @@ graph network(input) -> (output) { i"fastlstm1.peephole0.mul.fix-rank-0-1" = variable(label = "fastlstm1.peephole0.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm1.peephole1.mul.fix-rank-0-1" = variable(label = "fastlstm1.peephole1.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm1.peephole2.mul.fix-rank-0-1" = variable(label = "fastlstm1.peephole2.mul.fix-rank-0-1", shape = [1, 256]); - ( i"fastlstm1.c_final", i"fastlstm1.c_final_1" ) = tract_core_scan(body = "scan_body_0", scan = [("c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", 0, 1), ("c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", 2, 1), ("c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", 2, 1), ("c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", 2, 1)], full = [("four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("four_parts.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.split-1-over-1.0..256.slice"), ("four_parts.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.split-1-over-1.256..512.slice"), ("four_parts.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.split-1-over-1.512..768.slice"), ("four_parts.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.split-1-over-1.768..1024.slice"), ("h_new.W.split-1-over-1.0..128.slice", i"fastlstm1.h_new.W.split-1-over-1.0..128.slice"), ("h_new.split-1-over-1.0..128.slice", i"fastlstm1.h_new.split-1-over-1.0..128.slice"), ("peephole0.mul.fix-rank-0-1", i"fastlstm1.peephole0.mul.fix-rank-0-1"), ("peephole1.mul.fix-rank-0-1", i"fastlstm1.peephole1.mul.fix-rank-0-1"), ("peephole2.mul.fix-rank-0-1", i"fastlstm1.peephole2.mul.fix-rank-0-1")], state = [("c", i"fastlstm1.c_final_state_init_0", "c_new"), ("r", i"fastlstm1.c_final_state_init_1", "r_new")], output = [("r_new", "full", 1, 1), ("h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 2, 1)], skip = 2); + ( i"fastlstm1.c_final", i"fastlstm1.c_final_1" ) = tract_core_scan(body = "scan_body_0", scan = [("c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", 0, 1), ("c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", 0, 1), ("c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", 0, 1), ("c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm1.c_final.extracted.fastlstm1.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", 0, 1)], full = [("four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("four_parts.split-1-over-1.0..256.slice", i"fastlstm1.four_parts.split-1-over-1.0..256.slice"), ("four_parts.split-1-over-1.256..512.slice", i"fastlstm1.four_parts.split-1-over-1.256..512.slice"), ("four_parts.split-1-over-1.512..768.slice", i"fastlstm1.four_parts.split-1-over-1.512..768.slice"), ("four_parts.split-1-over-1.768..1024.slice", i"fastlstm1.four_parts.split-1-over-1.768..1024.slice"), ("h_new.W.split-1-over-1.0..128.slice", i"fastlstm1.h_new.W.split-1-over-1.0..128.slice"), ("h_new.split-1-over-1.0..128.slice", i"fastlstm1.h_new.split-1-over-1.0..128.slice"), ("peephole0.mul.fix-rank-0-1", i"fastlstm1.peephole0.mul.fix-rank-0-1"), ("peephole1.mul.fix-rank-0-1", i"fastlstm1.peephole1.mul.fix-rank-0-1"), ("peephole2.mul.fix-rank-0-1", i"fastlstm1.peephole2.mul.fix-rank-0-1")], state = [("c", i"fastlstm1.c_final_state_init_0", "c_new"), ("r", i"fastlstm1.c_final_state_init_1", "r_new")], output = [("r_new", "full", 1, 1), ("h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 2); i"fastlstm1.c_final.fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_1" = variable(label = "fastlstm1.c_final.fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_1", shape = [256, 128]); - i"fastlstm1.h_new.W.split-over-1.128..256" = tract_core_einsum([i"fastlstm1.c_final_1", i"fastlstm1.c_final.fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_1"], expr = "bka,kn->na", acc = "f32", output = ""); + i"fastlstm1.h_new.W.split-over-1.128..256" = tract_core_einsum([i"fastlstm1.c_final_1", i"fastlstm1.c_final.fastlstm1.h_new.W.split-over-1.128..256.prop_axis.a.input_1"], expr = "abk,kn->na", acc = "f32", output = ""); i"fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice" = variable(label = "fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice", shape = [128, 1]); i"fastlstm1.h_new.split-over-1.128..256" = add(i"fastlstm1.h_new.W.split-over-1.128..256", i"fastlstm1.c_final.fastlstm1.h_new.split-1-over-1.128..256.slice"); i"fastlstm1.h_new.concat-1" = concat([i"fastlstm1.c_final", i"fastlstm1.h_new.split-over-1.128..256"], axis = 0); @@ -232,11 +232,11 @@ graph network(input) -> (output) { i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1", shape = [256, 256]); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256" = tract_core_einsum([i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256.prop_axis.a.input_1"], expr = "ka,kn->an", acc = "f32", output = ""); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.input_1" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.input_1", shape = [256, 256]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768" = tract_core_einsum([i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.input_1"], expr = "ka,kn->bna", acc = "f32", output = ""); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768" = tract_core_einsum([i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768.prop_axis.a.input_1"], expr = "ka,kn->abn", acc = "f32", output = ""); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.input_1" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.input_1", shape = [256, 256]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512" = tract_core_einsum([i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.input_1"], expr = "ka,kn->bna", acc = "f32", output = ""); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512" = tract_core_einsum([i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512.prop_axis.a.input_1"], expr = "ka,kn->abn", acc = "f32", output = ""); i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.input_1" = variable(label = "fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.input_1", shape = [256, 256]); - i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024" = tract_core_einsum([i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.input_1"], expr = "ka,kn->bna", acc = "f32", output = ""); + i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024" = tract_core_einsum([i"tdnn5.renorm.output", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024.prop_axis.a.input_1"], expr = "ka,kn->abn", acc = "f32", output = ""); i"fastlstm2.c_final_state_init_0" = variable(label = "fastlstm2.c_final_state_init_0", shape = [1, 256]); i"fastlstm2.c_final_state_init_1" = variable(label = "fastlstm2.c_final_state_init_1", shape = [1, 128]); i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice" = variable(label = "fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", shape = [128, 256]); @@ -252,11 +252,11 @@ graph network(input) -> (output) { i"fastlstm2.peephole0.mul.fix-rank-0-1" = variable(label = "fastlstm2.peephole0.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm2.peephole1.mul.fix-rank-0-1" = variable(label = "fastlstm2.peephole1.mul.fix-rank-0-1", shape = [1, 256]); i"fastlstm2.peephole2.mul.fix-rank-0-1" = variable(label = "fastlstm2.peephole2.mul.fix-rank-0-1", shape = [1, 256]); - ( i"fastlstm2.c_final", i"fastlstm2.c_final_1" ) = tract_core_scan(body = "scan_body_1", scan = [("c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", 0, 1), ("c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", 2, 1), ("c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", 2, 1), ("c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", 2, 1)], full = [("four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("four_parts.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.split-1-over-1.0..256.slice"), ("four_parts.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.split-1-over-1.256..512.slice"), ("four_parts.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.split-1-over-1.512..768.slice"), ("four_parts.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.split-1-over-1.768..1024.slice"), ("h_new.W.split-1-over-1.0..128.slice", i"fastlstm2.h_new.W.split-1-over-1.0..128.slice"), ("h_new.split-1-over-1.0..128.slice", i"fastlstm2.h_new.split-1-over-1.0..128.slice"), ("peephole0.mul.fix-rank-0-1", i"fastlstm2.peephole0.mul.fix-rank-0-1"), ("peephole1.mul.fix-rank-0-1", i"fastlstm2.peephole1.mul.fix-rank-0-1"), ("peephole2.mul.fix-rank-0-1", i"fastlstm2.peephole2.mul.fix-rank-0-1")], state = [("c", i"fastlstm2.c_final_state_init_0", "c_new"), ("r", i"fastlstm2.c_final_state_init_1", "r_new")], output = [("r_new", "full", 0, 1), ("h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 2, 1)], skip = 6); + ( i"fastlstm2.c_final", i"fastlstm2.c_final_1" ) = tract_core_scan(body = "scan_body_1", scan = [("c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.0..256", 0, 1), ("c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.512..768", 0, 1), ("c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.256..512", 0, 1), ("c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", i"fastlstm2.c_final.extracted.fastlstm2.four_parts.W.concat-einsum-k.0..256.split-over-1.768..1024", 0, 1)], full = [("four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.0..256.slice"), ("four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.256..512.slice"), ("four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.512..768.slice"), ("four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.W.concat-einsum-k.256..384.split-1-over-1.768..1024.slice"), ("four_parts.split-1-over-1.0..256.slice", i"fastlstm2.four_parts.split-1-over-1.0..256.slice"), ("four_parts.split-1-over-1.256..512.slice", i"fastlstm2.four_parts.split-1-over-1.256..512.slice"), ("four_parts.split-1-over-1.512..768.slice", i"fastlstm2.four_parts.split-1-over-1.512..768.slice"), ("four_parts.split-1-over-1.768..1024.slice", i"fastlstm2.four_parts.split-1-over-1.768..1024.slice"), ("h_new.W.split-1-over-1.0..128.slice", i"fastlstm2.h_new.W.split-1-over-1.0..128.slice"), ("h_new.split-1-over-1.0..128.slice", i"fastlstm2.h_new.split-1-over-1.0..128.slice"), ("peephole0.mul.fix-rank-0-1", i"fastlstm2.peephole0.mul.fix-rank-0-1"), ("peephole1.mul.fix-rank-0-1", i"fastlstm2.peephole1.mul.fix-rank-0-1"), ("peephole2.mul.fix-rank-0-1", i"fastlstm2.peephole2.mul.fix-rank-0-1")], state = [("c", i"fastlstm2.c_final_state_init_0", "c_new"), ("r", i"fastlstm2.c_final_state_init_1", "r_new")], output = [("r_new", "full", 0, 1), ("h_new.W.split-over-1.128..256.prop_axis.a.input_0", "full", 0, 1)], skip = 6); i"output.affine.output.W.concat-einsum-slice-k.0.0..128" = variable(label = "output.affine.output.W.concat-einsum-slice-k.0.0..128", shape = [1690, 128]); i"output.affine.output.W.concat-einsum-k.0..128" = tract_core_einsum([i"output.affine.output.W.concat-einsum-slice-k.0.0..128", i"fastlstm2.c_final"], expr = "mk,nk->nm", acc = "f32", output = ""); i"fastlstm2.c_final.fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_1" = variable(label = "fastlstm2.c_final.fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_1", shape = [256, 128]); - i"fastlstm2.h_new.W.split-over-1.128..256" = tract_core_einsum([i"fastlstm2.c_final_1", i"fastlstm2.c_final.fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_1"], expr = "bka,kn->na", acc = "f32", output = ""); + i"fastlstm2.h_new.W.split-over-1.128..256" = tract_core_einsum([i"fastlstm2.c_final_1", i"fastlstm2.c_final.fastlstm2.h_new.W.split-over-1.128..256.prop_axis.a.input_1"], expr = "abk,kn->na", acc = "f32", output = ""); i"fastlstm2.c_final.fastlstm2.h_new.split-1-over-1.128..256.slice" = variable(label = "fastlstm2.c_final.fastlstm2.h_new.split-1-over-1.128..256.slice", shape = [128, 1]); i"fastlstm2.h_new.split-over-1.128..256" = add(i"fastlstm2.h_new.W.split-over-1.128..256", i"fastlstm2.c_final.fastlstm2.h_new.split-1-over-1.128..256.slice"); i"output.affine.output.W.concat-einsum-slice-k.0.128..256" = variable(label = "output.affine.output.W.concat-einsum-slice-k.0.128..256", shape = [1690, 128]);